import torch
import math
import cv2
import json
import time
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import clip
import numpy as np
from tqdm import tqdm
import os
from dotenv import load_dotenv
from IPython.display import Audio
import re
from groq import Groq
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
from pydub import AudioSegment
import shutil
import gradio as gr
from huggingface_hub import hf_hub_download
from TTS.api import TTS

# Load .env before reading the key; otherwise os.environ raises KeyError when
# the key is only defined in a local .env file.
load_dotenv()
groq_key = os.environ["GROQ_API_KEY"]
class TemporalTransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, num_frames, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_frames = num_frames
        # Learnable [CLS] token that aggregates the clip-level representation
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # Learnable positional embeddings: one slot for [CLS] plus one per frame
        self.position_embed = nn.Parameter(torch.zeros(1, num_frames + 1, embed_dim))
        nn.init.trunc_normal_(self.position_embed, std=0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        B = x.size(0)
        cls_token = self.cls_token.expand(B, 1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.position_embed[:, :x.size(1)]
        x = self.transformer(x)
        return {
            "cls": x[:, 0],      # clip-level summary vector
            "tokens": x[:, 1:]   # per-frame tokens
        }
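# Shape sanity check (illustrative sketch, not part of the pipeline): with a
# batch of 2 clips, each with 16 frames of 512-dim CLIP features:
#   enc = TemporalTransformerEncoder(embed_dim=512, num_heads=8, num_layers=3, num_frames=16)
#   out = enc(torch.randn(2, 16, 512))
#   out["cls"].shape     # torch.Size([2, 512])
#   out["tokens"].shape  # torch.Size([2, 16, 512])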
class CricketCommentator(nn.Module):
    def __init__(self, train_mode=False, num_frames=16, train_layers=2):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.num_frames = num_frames
        # CLIP vision backbone for per-frame features
        self.clip, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.clip = self.clip.float()
        if train_mode:
            # CLIP stays frozen even during training; only the temporal encoder,
            # projection, and the last LLM layers receive gradients.
            for param in self.clip.parameters():
                param.requires_grad = False
        self.temporal_encoder = TemporalTransformerEncoder(
            embed_dim=512,
            num_heads=8,
            num_layers=3,
            num_frames=num_frames,
            dropout=0.1
        ).to(self.device).float()
        # Projection from CLIP space (512) into DeepSeek's hidden size (2048)
        self.projection = nn.Sequential(
            nn.Linear(512, 2048),
            nn.GELU(),
            nn.LayerNorm(2048),
            nn.Dropout(0.1),
            nn.Linear(2048, 2048),
            nn.Tanh()
        ).to(self.device).float()
        # DeepSeek model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct")
        self.model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct").to(self.device).float()
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Freeze all LLM parameters initially
        for param in self.model.parameters():
            param.requires_grad = False
        # When fine-tuning, unfreeze the last N transformer blocks plus the
        # final norm and LM head
        if train_mode and train_layers > 0:
            for block in self.model.model.layers[-train_layers:]:
                for param in block.parameters():
                    param.requires_grad = True
            for param in self.model.model.norm.parameters():
                param.requires_grad = True
            for param in self.model.lm_head.parameters():
                param.requires_grad = True
    def forward(self, frames):
        # frames: [B, num_frames, 3, 224, 224]
        batch_size = frames.shape[0]
        frames = frames.view(-1, 3, 224, 224)
        with torch.no_grad():
            frame_features = self.clip.encode_image(frames.to(self.device))
        frame_features = frame_features.view(batch_size, self.num_frames, -1).float()
        frame_features = F.normalize(frame_features, p=2, dim=-1)
        temporal_out = self.temporal_encoder(frame_features)
        visual_embeds = self.projection(temporal_out["cls"])
        # One visual "token" per clip: [B, 1, 2048]
        return F.normalize(visual_embeds, p=2, dim=-1).unsqueeze(1)
    def extract_frames(self, video_path):
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        stride = max(1, total_frames // self.num_frames)
        frames = []
        for i in range(0, total_frames, stride):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                # Center square crop at half the short side, then BGR -> RGB
                h, w, _ = frame.shape
                crop_size = min(h, w) // 2
                y, x = (h - crop_size) // 2, (w - crop_size) // 2
                cropped = cv2.cvtColor(frame[y:y+crop_size, x:x+crop_size], cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(cropped)
                frames.append(self.preprocess(pil_image))
                if len(frames) >= self.num_frames:
                    break
            else:
                break
        cap.release()
        # Pad short videos with black frames so the output is always fixed-size
        if len(frames) < self.num_frames:
            frames.extend([torch.zeros(3, 224, 224)] * (self.num_frames - len(frames)))
        return torch.stack(frames)  # [num_frames, 3, 224, 224]
    def generate_commentary(self, video_path):
        frames = self.extract_frames(video_path).unsqueeze(0).to(self.device)
        visual_embeds = self.forward(frames)  # Shape: [1, 1, 2048]
        # Prepare text prompt
        prompt = ("USER: <video> Provide a sequential description of the cricket delivery "
                  "in the video. Start with the bowler's run-up, then describe the delivery, "
                  "the batsman's action, and finally the outcome of the ball. Keep it concise, "
                  "do not exceed two lines, and keep the tone professional. ASSISTANT:")
        # Tokenize text prompt
        inputs = self.tokenizer(prompt, return_tensors="pt",
                                truncation=True, max_length=512).to(self.device)
        # Get token embeddings
        token_embeds = self.model.model.embed_tokens(inputs['input_ids'])
        # Prepend the visual embedding to the text embeddings
        inputs_embeds = torch.cat([visual_embeds, token_embeds], dim=1)
        # Attention mask covering the visual token plus all text tokens
        attention_mask = torch.cat([
            torch.ones(visual_embeds.shape[:2], dtype=torch.long).to(self.device),
            inputs['attention_mask']
        ], dim=1)
        # Generate commentary
        outputs = self.model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=200,
            min_new_tokens=100,
            do_sample=True,
            temperature=0.8,
            top_k=40,
            top_p=0.9,
            repetition_penalty=1.15,
            no_repeat_ngram_size=3,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Extract and clean the generated text
        full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        commentary = full_text.split("ASSISTANT:")[-1].strip()
        print(commentary)
        return commentary
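# Standalone usage sketch (assumes the fine-tuned weights are already on disk;
# main() below downloads them from the Hub instead):
#   model = CricketCommentator(train_mode=False)
#   model.load_state_dict(torch.load("best_model_1.pth", map_location=model.device))
#   model.eval()
#   commentary = model.generate_commentary("sample_delivery.mp4")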
# -------------------- PIPELINE --------------------
def summarize_commentary(commentary, client, video_duration, tts_speed):
    prompt = f"""
You are a professional cricket commentary editor.
Task:
- Rewrite the input commentary into a concise, broadcast-style commentary.
- Focus only on the action and result, with minimal exaggeration or filler.
- DO NOT change the original event: if it's a four, six, or wicket (out), keep it exactly the same.
- If the input says "four", your output must say "four". Same for "six" or "out".
- Ensure the sentence fits within {video_duration} seconds at {tts_speed}x speech rate.
- Use correct grammar and punctuation for smooth TTS (Text-to-Speech) delivery.
Only output the cleaned commentary. Do not add any explanations.
Input:
{commentary}
Output:
"""
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.1-8b-instant"
    )
    final = chat_completion.choices[0].message.content.strip()
    print("=" * 50)
    print(final)
    print("=" * 50)
    return final
def text_to_speech(text, output_path, speed):
    raw_path = "raw_commentary.wav"
    # Load the multilingual multi-speaker YourTTS model
    tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=True, gpu=False)
    language = "en"
    # Choose a male speaker (the trailing "\n" is part of the speaker ID as it
    # appears in this model's speaker list)
    male_speaker = "male-en-2\n"
    # Generate TTS to file with the chosen speaker
    tts.tts_to_file(text=text, speaker=male_speaker, language=language, file_path=raw_path)
    # Speed up with ffmpeg (the atempo filter accepts factors from 0.5 to 2.0)
    os.system(f"ffmpeg -y -i {raw_path} -filter:a atempo={speed} {output_path}")
    os.remove(raw_path)
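# Usage sketch (illustrative text and values, matching the 1.11x speed used in
# main() below):
#   text_to_speech("Full and straight, driven beautifully for four!", "commentary_final.mp3", 1.11)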
def mix_audio(video_path, voice_path, crowd_path, output_path):
    video = VideoFileClip(video_path)
    video_duration_ms = video.duration * 1000
    # Trim the voice so it never overruns the video (100 ms safety margin)
    voice = AudioSegment.from_file(voice_path)[:int(video_duration_ms - 100)]
    # Crowd bed under the voice: attenuated by 10 dB and looped to length
    crowd = AudioSegment.from_file(crowd_path) - 10
    while len(crowd) < len(voice):
        crowd += crowd
    crowd = crowd[:len(voice)]
    mixed = crowd.overlay(voice)
    # A quieter crowd-only segment (-15 dB) fills the head of the video, so the
    # commentary plays over the final seconds of the clip
    crowd_head = AudioSegment.from_file(crowd_path) - 15
    while len(crowd_head) < (video_duration_ms - len(mixed)):
        crowd_head += crowd_head
    crowd_head = crowd_head[:int(video_duration_ms - len(mixed))]
    final_audio = crowd_head + mixed
    temp_audio_path = "temp_mixed_audio.mp3"
    final_audio.export(temp_audio_path, format="mp3")
    final_video = video.set_audio(AudioFileClip(temp_audio_path))
    final_video.write_videofile(output_path, codec="libx264", audio_codec="aac")
def main(video_path):
    # .env is loaded once at import time, so groq_key is already populated
    model_weights_path = hf_hub_download(repo_id="switin06/Deepseek_Cricket_commentator", filename="best_model_1.pth")
    crowd_path = "assets/Stadium_Ambience.mp3"
    # Load model
    model = CricketCommentator(train_mode=False)
    model.load_state_dict(torch.load(model_weights_path, map_location=model.device))
    model.eval()
    # Generate raw commentary
    raw_commentary = model.generate_commentary(video_path)
    # Summarize using the Groq API
    client = Groq(api_key=groq_key)
    video = VideoFileClip(video_path)
    video_duration = video.duration  # in seconds
    tts_speed = 1.11  # adjust as needed
    clean_commentary = summarize_commentary(raw_commentary, client, video_duration, tts_speed)
    # Text to speech
    tts_path = "commentary_final.mp3"
    text_to_speech(clean_commentary, tts_path, tts_speed)
    # Keep only the first 3 seconds of speech for the overlay
    short_audio_path = "pro_audio3.mp3"
    os.system(f"ffmpeg -y -i {tts_path} -ss 0 -t 3 {short_audio_path}")
    # Final video output
    output_video_path = "final_video.mp4"
    mix_audio(video_path, short_audio_path, crowd_path, output_video_path)
    return output_video_path
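# Minimal entry point (illustrative sketch): gradio is imported above, so the
# full Space presumably wires main() into a UI along these lines. The component
# labels, title, and launch options here are assumptions, not taken from the
# original file.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=main,
        inputs=gr.Video(label="Cricket delivery clip"),
        outputs=gr.Video(label="Clip with AI commentary"),
        title="DeepSeek Cricket Commentator",
    )
    demo.launch()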