import gradio as gr
import json
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import json
import time
import os
from huggingface_hub import CommitScheduler
from functools import partial
import pandas as pd
import numpy as np
from huggingface_hub import snapshot_download

def enable_buttons_side_by_side():
    """Return six updates that make every side-by-side button visible and clickable.

    A separate update object is built for each of the six outputs.
    """
    updates = [gr.update(visible=True, interactive=True) for _ in range(6)]
    return tuple(updates)

def disable_buttons_side_by_side():
    """Return six updates that lock the side-by-side buttons.

    Every button becomes non-interactive. The first four (positions 0-3) are
    hidden; the last two (positions 4 and 5) stay visible.
    """
    hidden = [gr.update(visible=False, interactive=False) for _ in range(4)]
    shown = [gr.update(visible=True, interactive=False) for _ in range(2)]
    return tuple(hidden + shown)


os.makedirs('data', exist_ok = True)
# Take a single timestamp for the whole run. Calling datetime.now() once per
# file gave the log and flag files slightly different suffixes, which makes
# pairing the two files of one run harder.
_RUN_STAMP = datetime.now().isoformat()
# Both files live in data/, the folder the CommitScheduler below uploads.
LOG_FILENAME = os.path.join('data', f'log_{_RUN_STAMP}.json')
FLAG_FILENAME = os.path.join('data', f'flagged_{_RUN_STAMP}.json')

# Reusable button updates shared by the event handlers.
# Show a button and make it clickable.
enable_btn = gr.update(visible=True, interactive=True)
# Grey a button out; its visibility is left as-is.
disable_btn = gr.update(interactive=False)
# Hide a button and make it non-interactive.
invisible_btn = gr.update(visible=False, interactive=False)
# NOTE(review): despite the name, this relabels the button to "No Change"
# rather than leaving it untouched - confirm that is intended.
no_change_btn = gr.update(value="No Change", visible=True, interactive=True)

# Runtime configuration from environment variables (each is None when unset).
DS_ID = os.environ.get('DS_ID')              # dataset repo the scheduler commits logs to
TOKEN = os.environ.get('TOKEN')              # Hub token used for download and upload
SONG_SOURCE = os.environ.get("SONG_SOURCE")  # dataset repo holding the songs
LOCAL_DIR = './'

# Mirror the SONG_SOURCE dataset into LOCAL_DIR (the working directory).
# NOTE(review): downloading into './' may overwrite app files if the dataset
# contains files with the same names - consider a dedicated subdirectory.
snapshot_download(
    repo_id=SONG_SOURCE,
    repo_type="dataset",
    token=TOKEN,
    local_dir=LOCAL_DIR,
)
# Debug aid: list what ended up in LOCAL_DIR.
print(os.listdir(LOCAL_DIR))


# Background uploader: periodically commits the local folder holding
# LOG_FILENAME (data/) to the DS_ID dataset repo under data/. Code that appends
# to those files takes scheduler.lock while writing. `every` is the commit
# interval (huggingface_hub documents it in minutes).
scheduler = CommitScheduler(
    repo_id=DS_ID,
    repo_type="dataset",
    token=TOKEN,
    folder_path=os.path.dirname(LOG_FILENAME),
    path_in_repo="data",
    every=10,
)

# Song catalogue, read from LOCAL_DIR (presumably shipped in the SONG_SOURCE
# snapshot). Force 'filename' to str: if pandas infers a numeric dtype, leading
# zeros are lost and the string concatenation below raises or builds wrong paths.
df = pd.read_csv(
    os.path.join(LOCAL_DIR, 'singfake_english.csv'),
    dtype={'filename': str},
)
# Turn each stem into '<LOCAL_DIR>/Songs/<stem>.mp3'. Use item assignment:
# pandas discourages setting columns through attribute access.
df['filename'] = os.path.join(LOCAL_DIR, 'Songs') + '/' + df['filename'] + '.mp3'


# Pool of row indices not yet served in the current round, plus the master
# list used to refill it.
indices = list(df.index)
main_indices = indices.copy()

def init_indices():
    """Refill the pool of not-yet-served row indices from ``main_indices``.

    A fresh copy is taken. ``pick_and_remove_two`` shuffles ``indices`` in
    place, so binding ``indices`` to the master list itself would let that
    shuffle mutate ``main_indices`` as well.
    """
    global indices
    indices = list(main_indices)



def pick_and_remove_two():
    """Draw two random row indices and remove them from the shared pool.

    If fewer than two indices remain, the pool is first replaced by a refill
    from ``init_indices``. The pool is a module-level global, so every session
    draws from the same pool.

    Returns:
        A list holding the drawn indices (two whenever the refilled pool has
        at least two entries).
    """
    global indices
    if len(indices) < 2:  # not enough left for a pair: start a new round
        init_indices()
    pool = indices
    np.random.shuffle(pool)  # in place, so the global list is reordered too
    indices = pool[2:]
    return pool[:2]




def vote_last_response(state, vote_type, request: gr.Request):
    """Append one event record to LOG_FILENAME as a single JSON line.

    Args:
        state: Sequence of objects exposing ``.dict()`` (AudioStateIG here).
            Entry ``i`` is stored under the key ``"state{i}"``. The two-clip
            vote handlers pass ``[state0, state1]``, which yields the original
            ``state0``/``state1`` keys.
        vote_type: Event label written to the ``"type"`` field.
        request: Incoming Gradio request; only used to log the client IP.
    """
    data = {"tstamp": round(time.time(), 4), "type": vote_type}
    # Accept any number of states. Hard-coding state[0]/state[1] raised
    # IndexError for the single-state callers.
    for i, clip_state in enumerate(state):
        data[f"state{i}"] = clip_state.dict()
    data["ip"] = get_ip(request)
    line = json.dumps(data) + "\n"

    # Hold the scheduler's lock only around the append so the background
    # commit does not read the file while it is being written.
    with scheduler.lock, open(LOG_FILENAME, "a", encoding="utf-8") as fout:
        fout.write(line)

def flag_last_response(state, vote_type, request: gr.Request):
    """Append a flag record for one clip to FLAG_FILENAME as a JSON line.

    Args:
        state: A single object exposing ``.dict()`` (AudioStateIG here).
        vote_type: Label written to the ``"type"`` field.
        request: Incoming Gradio request; only used to log the client IP.
    """
    # The scheduler's lock is held for the whole append so the periodic
    # commit never reads the file mid-write.
    with scheduler.lock, open(FLAG_FILENAME, "a") as fout:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(record) + "\n")


class AudioStateIG:
    """State for one clip: a random id plus the clip's label.

    ``model_name`` is whatever label the caller supplies; in this file,
    get_song passes the row's 'Bonafide Or Spoof' value.
    """

    def __init__(self, model_name):
        # 32-character hex string from a random UUID4.
        self.conv_id = uuid4().hex
        self.model_name = model_name

    def dict(self):
        """Return a plain dict with ``conv_id`` followed by ``model_name``."""
        return {"conv_id": self.conv_id, "model_name": self.model_name}

def get_ip(request: gr.Request):
    """Best-effort client IP for logging.

    Prefers the ``cf-connecting-ip`` header (presumably set by a Cloudflare
    proxy) and falls back to the peer address in ``request.client``.

    Args:
        request: The incoming Gradio request; may be None/falsy.

    Returns:
        The IP string, or None when there is no request or no known client
        address.
    """
    if not request:
        return None
    # NOTE: a client can send this header itself unless the app really sits
    # behind the proxy, so treat the value as informational only.
    forwarded = request.headers.get("cf-connecting-ip")
    if forwarded:
        return forwarded
    # ``client`` can be None (or absent on a Gradio Request built without an
    # underlying HTTP request); avoid crashing the logging path.
    client = getattr(request, "client", None)
    return client.host if client is not None else None


def get_song(idx, df=df):
    """Return ``(state, audio_path)`` for row ``idx`` of ``df``.

    The row's 'Bonafide Or Spoof' value becomes the state's ``model_name``,
    and its 'filename' value is returned as the audio path. ``df`` defaults to
    the module-level dataframe as it was bound at definition time.
    """
    record = df.loc[idx]
    return AudioStateIG(record['Bonafide Or Spoof']), record['filename']

def generate_songs(state0, state1):
    """Draw a new pair of songs and reset both reveal labels.

    ``state0``/``state1`` are not read. Presumably they are accepted so the
    Gradio event's inputs line up (confirm against the UI wiring). Fresh
    states come from get_song.

    Returns:
        ``(state_a, audio_a, state_b, audio_b, label_a, label_b)``
    """
    first, second = pick_and_remove_two()
    state_a, clip_a = get_song(first)
    state_b, clip_b = get_song(second)
    return (
        state_a, clip_a,
        state_b, clip_b,
        "Model A: Vote to Reveal", "Model B: Vote to Reveal",
    )

def random_sample_button(prompt):
    """Return the same fixed clip path for both players; ``prompt`` is ignored.

    NOTE(review): looks like a placeholder - nothing here is random.
    """
    placeholder = "marine.mp3"
    return placeholder, placeholder

def _record_vote_and_reveal(vote_type, state0, state1, request):
    """Log a two-clip vote, then disable the six buttons and reveal both labels."""
    vote_last_response([state0, state1], vote_type, request)
    reveal_a = gr.Markdown(f"### Model A: {state0.model_name}", visible=True)
    reveal_b = gr.Markdown(f"### Model B: {state1.model_name}", visible=True)
    return (disable_btn,) * 6 + (reveal_a, reveal_b)


def leftvote_last_response(state0, state1, request: gr.Request):
    """Record a "leftvote" event and reveal both clips' labels."""
    return _record_vote_and_reveal("leftvote", state0, state1, request)


def rightvote_last_response(state0, state1, request: gr.Request):
    """Record a "rightvote" event and reveal both clips' labels."""
    return _record_vote_and_reveal("rightvote", state0, state1, request)


def tievote_last_response(state0, state1, request: gr.Request):
    """Record a "tievote" event and reveal both clips' labels."""
    return _record_vote_and_reveal("tievote", state0, state1, request)


def bothbadvote_last_response(state0, state1, request: gr.Request):
    """Record a "bothbadvote" event and reveal both clips' labels."""
    return _record_vote_and_reveal("bothbadvote", state0, state1, request)

def _log_single_clip_event(state, event_type, request):
    """Append one JSON line describing an event on a single clip to LOG_FILENAME.

    Uses the same fields as the vote records but stores only ``state0``, so a
    one-clip event never needs a second state.
    """
    record = {
        "tstamp": round(time.time(), 4),
        "type": event_type,
        "state0": state.dict(),
        "ip": get_ip(request),
    }
    with scheduler.lock, open(LOG_FILENAME, "a") as fout:
        fout.write(json.dumps(record) + "\n")


def _unchanged_outputs():
    """Eight no-op updates, matching the six-button + two-Markdown output layout."""
    return tuple(gr.update() for _ in range(8))


def leftheard_last_response(
    state, request: gr.Request
):
    """Log a "leftheard" event for the left clip's state; leave the UI unchanged.

    NOTE(review): the previous body always crashed. It passed a one-element
    list to vote_last_response, which read ``state[1]``, and it built its
    return value from undefined ``state0``/``state1``. That copied return would
    also have disabled every button and revealed both labels, which would
    un-blind the comparison on mere playback, so it was not carried over.
    Confirm the intended UI behaviour.
    """
    _log_single_clip_event(state, "leftheard", request)
    return _unchanged_outputs()


def rightheard_last_response(
    state, request: gr.Request
):
    """Log a "rightheard" event for the right clip's state; leave the UI unchanged.

    NOTE(review): same history as leftheard_last_response - the old body
    always raised before returning. Confirm the intended UI behaviour.
    """
    _log_single_clip_event(state, "rightheard", request)
    return _unchanged_outputs()