import gradio as gr
from gradio_client import Client, handle_file
import os
import random
import json
import re
import numpy as np
from moviepy.editor import VideoFileClip
from moviepy.audio.AudioClip import AudioClip

# Hugging Face token read from the HF_TKN environment variable (None if unset);
# passed to the gradio_client Client constructors for the Spaces called below.
hf_token = os.environ.get("HF_TKN")
# Largest signed 32-bit integer; used as the upper bound for the random seed
# drawn in get_audioldm.
MAX_SEED = np.iinfo(np.int32).max

def extract_audio(video_in, output_audio='audio.wav'):
    """Extract the audio track of *video_in* into a WAV file.

    Args:
        video_in: Path to the input video file.
        output_audio: Destination path for the WAV file. Defaults to
            ``'audio.wav'`` (relative to the current working directory),
            which matches the previous hard-coded behavior.

    Returns:
        The path the audio was written to (``output_audio``).

    Raises:
        ValueError: If the video has no audio track.
    """
    video_clip = VideoFileClip(video_in)
    try:
        audio_clip = video_clip.audio
        # moviepy exposes ``audio`` as None for videos without sound; fail
        # with a clear message instead of an AttributeError.
        if audio_clip is None:
            raise ValueError(f"Video has no audio track: {video_in}")
        # 44100 Hz is used as the sample rate for the .wav output.
        audio_clip.write_audiofile(output_audio, fps=44100)
    finally:
        # Release the readers held by the clip even if writing fails.
        video_clip.close()

    print("Audio extraction complete.")
    return output_audio

def get_caption_from_kosmos(image_in):
    """Return a detailed caption for *image_in* from the Kosmos-2 Space.

    Args:
        image_in: Path to the image file to caption.

    Returns:
        The caption text with surrounding whitespace stripped.

    Raises:
        gr.Error: If the Space returns no caption tokens, because an empty
            caption would otherwise be sent on as an empty audio prompt.
    """
    kosmos2_client = Client("fffiloni/Kosmos-2-API", hf_token=hf_token)
    kosmos2_result = kosmos2_client.predict(
        image_input=handle_file(image_in),
        text_input="Detailed",
        api_name="/generate_predictions",
    )
    print(f"KOSMOS2 RETURNS: {kosmos2_result}")

    # The second element of the response is a sequence of dicts with a
    # "token" key. The first entry is skipped (presumably a prompt/prefix
    # token; confirm against the Space's output), and the remaining tokens
    # are concatenated into the caption.
    tokens = kosmos2_result[1]
    caption = "".join(item["token"] for item in tokens[1:]).strip()

    if not caption:
        raise gr.Error("Could not generate a caption for this image, please try another one.")
    return caption

def get_caption(image_in):
    """Ask the moondream1 Space for a one-sentence description of an image.

    Args:
        image_in: Path to the image file to describe.

    Returns:
        The Space's ``/predict`` response, unchanged (it is also printed).
    """
    moondream = Client("fffiloni/moondream1", hf_token=hf_token)
    caption = moondream.predict(
        image=handle_file(image_in),
        question="Describe precisely the image in one sentence.",
        api_name="/predict",
    )
    print(caption)
    return caption

def get_magnet(prompt):
    """Generate a sound effect for *prompt* with the MAGNet Space.

    Calls the ``/predict_full`` endpoint of https://fffiloni-magnet.hf.space/
    using the ``facebook/audio-magnet-medium`` model and fixed sampling
    settings.

    Args:
        prompt: Text description of the sound to generate (converted to str).

    Returns:
        The second element (``result[1]``) of the Space's response; presumably
        the generated audio file path (confirm against the Space's API).

    Raises:
        gr.Error: If the Space call fails or the response has an unexpected
            shape. The original exception is chained as ``__cause__``.
    """
    prompt_text = str(prompt)
    print(prompt_text)
    try:
        client = Client("https://fffiloni-magnet.hf.space/")
        result = client.predict(
            "facebook/audio-magnet-medium",  # 'Model' radio
            "",  # 'Model Path (custom models)' textbox, left empty
            prompt_text,  # 'Input Text'
            3,  # 'Temperature'
            0.9,  # 'Top-p'
            10,  # 'Max CFG coefficient'
            1,  # 'Min CFG coefficient'
            20,  # 'Decoding Steps (stage 1)'
            10,  # 'Decoding Steps (stage 2)'
            10,  # 'Decoding Steps (stage 3)'
            10,  # 'Decoding Steps (stage 4)'
            "prod-stride1 (new!)",  # 'Span Scoring' radio
            api_name="/predict_full",
        )
        print(result)
        return result[1]
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "MAGNet space API is not ready, please try again in few minutes "
        ) from err

def get_audioldm(prompt):
    """Generate a sound effect for *prompt* with the AudioLDM-2 Space.

    A fresh random seed in ``[0, MAX_SEED]`` is drawn on every call, so
    repeated calls with the same prompt are not seeded identically.

    Args:
        prompt: Text description of the sound to generate.

    Returns:
        The Space's ``/text2audio`` response, unchanged.

    Raises:
        gr.Error: If the Space call fails. The original exception is chained
            as ``__cause__``.
    """
    try:
        client = Client("fffiloni/audioldm2-text2audio-text2music-API", hf_token=hf_token)
        seed = random.randint(0, MAX_SEED)
        result = client.predict(
            text=prompt,
            negative_prompt="Low quality. Music.",
            duration=10,  # seconds; the Space accepts 5-15
            guidance_scale=6.5,  # the Space accepts 0-7
            random_seed=seed,
            n_candidates=3,  # waveforms to generate; the Space accepts 1-5
            api_name="/text2audio",
        )
        print(result)
        return result
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "AudioLDM space API is not ready, please try again in few minutes "
        ) from err

def get_audiogen(prompt):
    """Generate a sound effect for *prompt* with the AudioGen Space.

    Args:
        prompt: Text description of the sound to generate.

    Returns:
        The Space's ``/infer`` response, unchanged.

    Raises:
        gr.Error: If the Space call fails. The original exception is chained
            as ``__cause__``.
    """
    try:
        client = Client("https://fffiloni-audiogen.hf.space/")
        result = client.predict(
            prompt,
            10,  # presumably the clip duration in seconds; confirm with the Space
            api_name="/infer",
        )
        return result
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "AudioGen space API is not ready, please try again in few minutes "
        ) from err

def get_tango(prompt):
    """Generate a sound effect for *prompt* with the Tango Space.

    Args:
        prompt: Text description of the sound to generate.

    Returns:
        The Space's ``/predict`` response, unchanged.

    Raises:
        gr.Error: If the Space call fails. The original exception is chained
            as ``__cause__``.
    """
    try:
        client = Client("fffiloni/tango", hf_token=hf_token)
        result = client.predict(
            prompt=prompt,
            steps=100,
            guidance=3,
            api_name="/predict",
        )
        print(result)
        return result
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "Tango space API is not ready, please try again in few minutes "
        ) from err
    
    

def get_tango2(prompt):
    """Generate a sound effect for *prompt* with the declare-lab Tango 2 Space.

    Args:
        prompt: Text description of the sound to generate.

    Returns:
        The Space's ``/predict`` response (requested as WAV), unchanged.

    Raises:
        gr.Error: If the Space call fails. The original exception is chained
            as ``__cause__``.
    """
    try:
        client = Client("declare-lab/tango2")
        result = client.predict(
            prompt=prompt,
            output_format="wav",
            steps=100,
            guidance=3,
            api_name="/predict",
        )
        print(result)
        return result
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "Tango2 space API is not ready, please try again in few minutes "
        ) from err
    
    

def get_stable_audio_open(prompt):
    """Generate audio for *prompt* with the Stable Audio Open Space.

    Args:
        prompt: Text description of the sound to generate.

    Returns:
        The Space's ``/predict`` response, unchanged.

    Raises:
        gr.Error: If the Space call fails. The original exception is chained
            as ``__cause__``.
    """
    try:
        client = Client("fffiloni/Stable-Audio-Open-A10", hf_token=hf_token)
        result = client.predict(
            prompt=prompt,
            seconds_total=30,
            steps=100,
            cfg_scale=7,
            api_name="/predict",
        )
        print(result)
        return result
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "Stable Audio Open space API is not ready, please try again in few minutes "
        ) from err
    
def get_ezaudio(prompt):
    """Generate a sound effect for *prompt* with the EzAudio Space.

    The Space is asked to randomize its seed (``randomize_seed=True``), so
    the ``random_seed=0`` value is presumably ignored; confirm with the Space.

    Args:
        prompt: Text description of the sound to generate.

    Returns:
        The Space's ``/generate_audio`` response, unchanged.

    Raises:
        gr.Error: If the Space call fails. The original exception is chained
            as ``__cause__``.
    """
    try:
        client = Client("OpenSound/EzAudio")
        result = client.predict(
            text=prompt,
            length=10,
            guidance_scale=5,
            guidance_rescale=0.75,
            ddim_steps=50,
            eta=1,
            random_seed=0,
            randomize_seed=True,
            api_name="/generate_audio",
        )
        print(result)
        return result
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; chain the cause so the real failure is debuggable.
        raise gr.Error(
            "EzAudio space API is not ready, please try again in few minutes "
        ) from err
    
def infer(image_in, chosen_model):
    """Caption *image_in* with Kosmos-2 and turn the caption into audio.

    Args:
        image_in: Filepath of the uploaded image, or ``None`` when nothing
            was uploaded (the UI's ``gr.Image`` uses ``type="filepath"``).
        chosen_model: Display name of the audio generator; must be one of the
            keys of the dispatch table below.

    Returns:
        Whatever the selected ``get_*`` helper returns, unchanged.

    Raises:
        gr.Error: If no image was provided or the model name is unknown.
            Errors raised by the captioning call or the generator helper
            propagate unchanged.
    """
    generators = {
        "MAGNet": get_magnet,
        "AudioLDM-2": get_audioldm,
        "AudioGen": get_audiogen,
        "Tango": get_tango,
        "Tango 2": get_tango2,
        "Stable Audio Open": get_stable_audio_open,
        "EzAudio": get_ezaudio,
    }

    # Validate inputs before spending a slow remote captioning call. An
    # unknown model previously fell through every branch and returned None.
    if image_in is None:
        raise gr.Error("Please upload an image first.")
    generate = generators.get(chosen_model)
    if generate is None:
        raise gr.Error(f"Unknown model: {chosen_model!r}")

    caption = get_caption_from_kosmos(image_in)
    return generate(caption)

# Custom CSS for the Blocks app: centers the element with
# elem_id="col-container" and caps its width at 800px.
css="""
#col-container{
    margin: 0 auto;
    max-width: 800px;
}
"""

# Build the UI: image upload and model picker on top, generated audio below.
# Components render in creation order, so statement order here matters.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h2 style="text-align: center;">
            Image to SFX
        </h2>
        <p style="text-align: center;">
            Compare sound effects generation models from image caption.
        </p>
        """)
        
        with gr.Column():
            # type="filepath" makes the component pass a path string (or None) to infer.
            image_in = gr.Image(sources=["upload"], type="filepath", label="Image input")
            with gr.Row():
                # "MAGNet" and "AudioGen" are commented out of the choices even
                # though infer can still dispatch them. Presumably their Spaces
                # are unavailable; confirm before re-enabling.
                chosen_model = gr.Dropdown(label="Choose a model", choices=[
                    #"MAGNet", 
                    "AudioLDM-2", 
                    #"AudioGen", 
                    "Tango", 
                    "Tango 2", 
                    "Stable Audio Open", 
                    "EzAudio"
                ], value="AudioLDM-2")
                submit_btn = gr.Button("Submit")
        with gr.Column():
            audio_o = gr.Audio(label="Audio output")

        # Example row assumes oiseau.png is available next to this script (confirm).
        gr.Examples(
            examples = [["oiseau.png", "AudioLDM-2"]],
            inputs = [image_in, chosen_model]
        )
    
    # Run the caption -> audio pipeline when Submit is clicked.
    submit_btn.click(
        fn=infer,
        inputs=[image_in, chosen_model],
        outputs=[audio_o],
    )

# Launch only when run as a script (as Hugging Face Spaces does with
# `python app.py`), so importing this module for tests or tooling does not
# start a blocking server. max_size=10 bounds how many requests may be queued.
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True, show_error=True)