# Cricket-Commentary / inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import os
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import clip
from dotenv import load_dotenv
from groq import Groq
from moviepy.editor import VideoFileClip, AudioFileClip
from pydub import AudioSegment
import gradio as gr
from huggingface_hub import hf_hub_download
from TTS.api import TTS
# Load .env first so GROQ_API_KEY is available at import time.
load_dotenv()
groq_key = os.environ["GROQ_API_KEY"]
class TemporalTransformerEncoder(nn.Module):
    """Pools per-frame CLIP embeddings into a clip-level vector via a learnable
    [CLS] token and a small transformer encoder."""

    def __init__(self, embed_dim, num_heads, num_layers, num_frames, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_frames = num_frames
        # Learnable [CLS] token; its output state summarizes the clip.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # Learned positional embeddings for [CLS] + num_frames positions.
        self.position_embed = nn.Parameter(torch.zeros(1, num_frames + 1, embed_dim))
        nn.init.trunc_normal_(self.position_embed, std=0.02)
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=4 * embed_dim,
dropout=dropout,
activation='gelu',
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    def forward(self, x):
        # x: [B, num_frames, embed_dim] per-frame CLIP features
        B = x.size(0)
cls_token = self.cls_token.expand(B, 1, -1)
x = torch.cat([cls_token, x], dim=1)
x = x + self.position_embed[:, :x.size(1)]
x = self.transformer(x)
return {
"cls": x[:, 0],
"tokens": x[:, 1:]
}
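
# Shape sanity check for the temporal encoder (an illustrative sketch, not part
# of the pipeline): two clips of 16 CLIP-sized frame embeddings go in, one
# pooled vector per clip comes out.
def _temporal_encoder_shape_check():
    enc = TemporalTransformerEncoder(embed_dim=512, num_heads=8, num_layers=3, num_frames=16)
    out = enc(torch.randn(2, 16, 512))
    assert out["cls"].shape == (2, 512)
    assert out["tokens"].shape == (2, 16, 512)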
class CricketCommentator(nn.Module):
def __init__(self, train_mode=False, num_frames=16, train_layers=2):
super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.num_frames = num_frames
        self.clip, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.clip = self.clip.float()
        if train_mode:
            # CLIP stays frozen; only the temporal encoder, projection, and
            # the last LLM blocks are trained.
            for param in self.clip.parameters():
                param.requires_grad = False
self.temporal_encoder = TemporalTransformerEncoder(
embed_dim=512,
num_heads=8,
num_layers=3,
num_frames=num_frames,
dropout=0.1
).to(self.device).float()
        # Project 512-dim CLIP features to DeepSeek's 2048-dim hidden size
self.projection = nn.Sequential(
nn.Linear(512, 2048),
nn.GELU(),
nn.LayerNorm(2048),
nn.Dropout(0.1),
nn.Linear(2048, 2048),
nn.Tanh()
).to(self.device).float()
# DeepSeek model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct")
self.model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct").to(self.device).float()
self.tokenizer.pad_token = self.tokenizer.eos_token
# Freeze all parameters initially
for param in self.model.parameters():
param.requires_grad = False
# Unfreeze last N layers if training
if train_mode and train_layers > 0:
# Unfreeze last transformer blocks
for block in self.model.model.layers[-train_layers:]:
for param in block.parameters():
param.requires_grad = True
# Unfreeze final norm and head
for param in self.model.model.norm.parameters():
param.requires_grad = True
for param in self.model.lm_head.parameters():
param.requires_grad = True
    def forward(self, frames):
        # frames: [B, num_frames, 3, 224, 224] preprocessed clip frames
        batch_size = frames.shape[0]
        frames = frames.view(-1, 3, 224, 224)
with torch.no_grad():
frame_features = self.clip.encode_image(frames.to(self.device))
frame_features = frame_features.view(batch_size, self.num_frames, -1).float()
frame_features = F.normalize(frame_features, p=2, dim=-1)
temporal_out = self.temporal_encoder(frame_features)
        visual_embeds = self.projection(temporal_out["cls"])
        # One normalized visual embedding per clip, shaped as a single soft token
        return F.normalize(visual_embeds, p=2, dim=-1).unsqueeze(1)
def extract_frames(self, video_path):
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
stride = max(1, total_frames // self.num_frames)
frames = []
for i in range(0, total_frames, stride):
cap.set(cv2.CAP_PROP_POS_FRAMES, i)
ret, frame = cap.read()
if ret:
                h, w, _ = frame.shape
                # Center-crop half the short side (keeps the pitch area),
                # then convert BGR -> RGB for CLIP preprocessing.
                crop_size = min(h, w) // 2
                y, x = (h - crop_size) // 2, (w - crop_size) // 2
                cropped = cv2.cvtColor(frame[y:y+crop_size, x:x+crop_size], cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(cropped)
frames.append(self.preprocess(pil_image))
if len(frames) >= self.num_frames:
break
else:
break
cap.release()
        if len(frames) < self.num_frames:
            # Pad short clips with black frames to keep a fixed tensor shape.
            frames.extend([torch.zeros(3, 224, 224)] * (self.num_frames - len(frames)))
return torch.stack(frames)
def generate_commentary(self, video_path):
        frames = self.extract_frames(video_path).unsqueeze(0).to(self.device)
        visual_embeds = self.forward(frames)  # Shape: [1, 1, 2048]
        # Prepare text prompt
        prompt = ("USER: <video> Provide a sequential description of the cricket delivery in the video. "
                  "Start with the bowler's run-up, then describe the delivery, the batsman's action, "
                  "and finally the outcome of the ball. Keep it concise (no more than two lines) "
                  "and use a professional tone. ASSISTANT:")
# Tokenize text prompt
inputs = self.tokenizer(prompt, return_tensors="pt",
truncation=True, max_length=512).to(self.device)
        # Get token embeddings and prepend the visual embedding as a single
        # soft prompt token
        token_embeds = self.model.model.embed_tokens(inputs['input_ids'])
        inputs_embeds = torch.cat([visual_embeds, token_embeds], dim=1)
# Create attention mask (1 for visual token + text tokens)
attention_mask = torch.cat([
torch.ones(visual_embeds.shape[:2], dtype=torch.long).to(self.device),
inputs['attention_mask']
], dim=1)
# Generate commentary
outputs = self.model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
max_new_tokens=200,
min_new_tokens=100,
do_sample=True,
temperature=0.8,
top_k=40,
top_p=0.9,
repetition_penalty=1.15,
no_repeat_ngram_size=3,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.eos_token_id
)
        # Extract and clean generated text; with inputs_embeds, generate()
        # returns only the new tokens, so the split is a safety net.
        full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        commentary = full_text.split("ASSISTANT:")[-1].strip()
print(commentary)
return commentary
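
# Minimal sketch for running just the captioner (an illustration, not called by
# the pipeline; the weights and clip paths are placeholders):
def _caption_only(video_path, weights_path):
    model = CricketCommentator(train_mode=False)
    model.load_state_dict(torch.load(weights_path, map_location=model.device))
    model.eval()
    return model.generate_commentary(video_path)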
# -------------------- PIPELINE --------------------
def summarize_commentary(commentary, client, video_duration, tts_speed):
    prompt = f"""
You are a professional cricket commentary editor.
Task:
- Rewrite the input commentary into a concise, broadcast-style commentary.
- Focus only on the action and the result, with minimal exaggeration or filler.
- DO NOT change the original event: if it's a four, six, or wicket (out), keep it exactly the same.
- If the input says "four", your output must say "four". Same for "six" or "out".
- Ensure the sentence fits within {video_duration} seconds at {tts_speed}x speech rate.
- Use correct grammar and punctuation for smooth TTS (Text-to-Speech) delivery.
Only output the cleaned commentary. Do not add any explanations.
Input:
{commentary}
Output:
"""
chat_completion = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
model="llama-3.1-8b-instant"
)
final = chat_completion.choices[0].message.content.strip()
print("="*50)
print(final)
print("="*50)
return final
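
# Example call (a sketch; assumes GROQ_API_KEY is set and the input line is
# invented):
#   client = Groq(api_key=groq_key)
#   summarize_commentary("Bowler steams in... driven through the covers for four!",
#                        client, video_duration=6.0, tts_speed=1.11)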
def text_to_speech(text, output_path, speed):
    raw_path = "raw_commentary.wav"
    # Load multilingual multi-speaker TTS model (YourTTS)
    tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=True, gpu=False)
    language = "en"
    # Choose a male speaker; the trailing "\n" is part of this model's
    # speaker id as it appears in tts.speakers.
    male_speaker = "male-en-2\n"
    # Generate TTS to file with the chosen speaker
    tts.tts_to_file(text=text, speaker=male_speaker, language=language, file_path=raw_path)
    # Speed up playback with ffmpeg's atempo filter (valid range 0.5-2.0)
    os.system(f'ffmpeg -y -i "{raw_path}" -filter:a atempo={speed} "{output_path}"')
    os.remove(raw_path)
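
# Subprocess-based variant of the ffmpeg call above (a sketch, not used by the
# pipeline; avoids shell-quoting issues with paths that contain spaces):
def _speed_up_audio(in_path, out_path, speed):
    import subprocess
    subprocess.run(
        ["ffmpeg", "-y", "-i", in_path, "-filter:a", f"atempo={speed}", out_path],
        check=True,
    )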
def mix_audio(video_path, voice_path, crowd_path, output_path):
    video = VideoFileClip(video_path)
    video_duration_ms = video.duration * 1000
    # Trim the voice to just under the video length and bed it on quieter
    # crowd noise (-10 dB), looping the crowd track to cover the voice.
    voice = AudioSegment.from_file(voice_path)[:int(video_duration_ms - 100)]
    crowd = AudioSegment.from_file(crowd_path) - 10
    while len(crowd) < len(voice):
        crowd += crowd
    crowd = crowd[:len(voice)]
    mixed = crowd.overlay(voice)
    # Fill the head of the video with even quieter crowd noise (-15 dB);
    # clamp to zero in case the mix already covers the whole video.
    head_ms = max(0, int(video_duration_ms - len(mixed)))
    crowd_head = AudioSegment.from_file(crowd_path) - 15
    while len(crowd_head) < head_ms:
        crowd_head += crowd_head
    crowd_head = crowd_head[:head_ms]
    final_audio = crowd_head + mixed
    temp_audio_path = "temp_mixed_audio.mp3"
    final_audio.export(temp_audio_path, format="mp3")
    final_video = video.set_audio(AudioFileClip(temp_audio_path))
    final_video.write_videofile(output_path, codec="libx264", audio_codec="aac")
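
# pydub operations used in mix_audio, in miniature:
#   seg - 10       -> same audio, gain reduced by 10 dB
#   a + b          -> b concatenated after a
#   a[:n]          -> first n milliseconds of a
#   a.overlay(b)   -> b mixed on top of a, cropped to len(a)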
def main(video_path):
    # Fetch fine-tuned weights from the Hub
    model_weights_path = hf_hub_download(repo_id="switin06/Deepseek_Cricket_commentator", filename="best_model_1.pth")
crowd_path = "assets/Stadium_Ambience.mp3"
# Load model
model = CricketCommentator(train_mode=False)
model.load_state_dict(torch.load(model_weights_path, map_location=model.device))
model.eval()
# Generate raw commentary
raw_commentary = model.generate_commentary(video_path)
# Summarize using Groq API
client = Groq(api_key=groq_key)
video = VideoFileClip(video_path)
video_duration = video.duration # in seconds
    tts_speed = 1.11  # ffmpeg atempo factor; adjust as needed
clean_commentary = summarize_commentary(raw_commentary, client, video_duration, tts_speed)
# Text to speech
tts_path = "commentary_final.mp3"
text_to_speech(clean_commentary, tts_path, tts_speed)
    # Keep only the first 3 seconds of the voice-over for the final mix
    short_audio_path = "pro_audio3.mp3"
    os.system(f'ffmpeg -y -i "{tts_path}" -ss 0 -t 3 "{short_audio_path}"')
# Final video output
output_video_path = "final_video.mp4"
mix_audio(video_path, short_audio_path, crowd_path, output_video_path)
return output_video_path
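
# Optional Gradio front-end: a minimal sketch (an assumption; gradio is
# imported above but never wired up) exposing main() as an upload-a-clip demo.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=main,  # video in -> commentated video out
        inputs=gr.Video(label="Cricket delivery clip"),
        outputs=gr.Video(label="Clip with generated commentary"),
        title="Cricket Commentary Generator",
    )
    demo.launch()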