from flask import Flask, request, render_template, send_file, jsonify
import os
# from transformers import AutoTokenizer, AutoModel
import anvil.server
import pathlib
import textwrap
import import_ipynb  # lets .ipynb files (e.g. library, background_service) be imported like regular modules
from library import call_gpt, call_gemini, compress_bool_list, uncompress_bool_list
from background_service import BackgroundTaskService
import numpy as np
from keys_min import server_uplink  # Anvil uplink key

# Connect this process to the Anvil app as an uplink server. Printing the key
# helps when debugging the connection, but it does leak a secret into the logs.
print(server_uplink)
anvil.server.connect(server_uplink)

# from sentence_transformers import SentenceTransformer
# from sentence_transformers.util import cos_sim
# model = SentenceTransformer('thenlper/gte-large')
# model = SentenceTransformer('BAAI/bge-large-en')

# @anvil.server.callable
# def encode(sentence = None):
#     vec = model.encode(sentence)
#     return [float(val) if isinstance(val, (int, float, np.float32)) else 0.0 for val in vec]

app = Flask(__name__)
MESSAGED = {'title': 'API Server for ICAPP',
            'messageL': ['published server functions:', 'encode_anvil(text)', 'encode(sentence)',
                         'call_gemini(text,key)', 'call_gpt(text,key,model)',
                         'task_id<=launch(func_name,*args)', 'poll(task_id)']}

# tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
# encoder = AutoModel.from_pretrained('allenai/specter')

# Register the imported library functions as Anvil server callables
# (equivalent to decorating them with @anvil.server.callable).
anvil.server.callable(call_gpt)
anvil.server.callable(call_gemini)
anvil.server.callable(compress_bool_list)
anvil.server.callable(uncompress_bool_list)

# Background task runner for long-running LLM calls; registered functions can
# be launched by name via launch() below.
service = BackgroundTaskService(max_tasks=10)
service.register(call_gpt)
service.register(call_gemini)
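
# The BackgroundTaskService interface assumed above (register, launch_task,
# get_result, results) might look roughly like the sketch below. This is a
# hypothetical illustration, not the actual background_service implementation.
#
# import threading, uuid
#
# class BackgroundTaskService:
#     def __init__(self, max_tasks=10):
#         self.max_tasks = max_tasks
#         self.funcs = {}      # name -> registered callable
#         self.results = {}    # task_id -> 'In Progress' or finished result
#
#     def register(self, func):
#         self.funcs[func.__name__] = func
#
#     def launch_task(self, func_name, *args):
#         task_id = str(uuid.uuid4())
#         self.results[task_id] = 'In Progress'
#         def run():
#             self.results[task_id] = self.funcs[func_name](*args)
#         threading.Thread(target=run, daemon=True).start()
#         return task_id
#
#     def get_result(self, task_id):
#         return self.results.get(task_id, 'No such task')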

@anvil.server.callable
def launch(func_name, *args):
    """Launch a registered function as a background task and return its task ID."""
    task_id = service.launch_task(func_name, *args)
    print(f"Task launched with ID: {task_id}")
    return task_id
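
# Example client-side call from an Anvil app (illustrative; prompt, key and
# model are placeholder variables, not values defined in this file):
#
#   task_id = anvil.server.call('launch', 'call_gpt', prompt, key, model)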

@anvil.server.callable
def poll(task_id):
    """Return 'In Progress' until the task completes, then return (and clear) its result."""
    result = service.get_result(task_id)
    if result == 'No such task':
        return str(result)
    elif result != 'In Progress':
        # Task finished: drop the stored result and return it in a
        # transferable form (stringify anything Anvil cannot serialise).
        del service.results[task_id]
        if isinstance(result, (int, float, str, list, dict, tuple)):
            return result
        else:
            print(str(result))
            return str(result)
    else:
        return str(result)
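
# Example client-side polling loop (illustrative sketch):
#
#   import time
#   result = anvil.server.call('poll', task_id)
#   while result == 'In Progress':
#       time.sleep(1)
#       result = anvil.server.call('poll', task_id)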

# @anvil.server.callable
# def encode_anvil(text):
#     inputs = tokenizer(text, padding=True, truncation=True, 
#                        return_tensors="pt", max_length=512)
#     result = encoder(**inputs)
#     embeddings = result.last_hidden_state[:, 0, :]
#     emb_array = embeddings.detach().numpy()
#     embedding=emb_array.tolist()
#     return embedding

@anvil.server.callable
def reset_service():
    """Discard all pending and finished tasks by rebuilding the background task service."""
    global service
    service = BackgroundTaskService(max_tasks=10)
    service.register(call_gpt)
    service.register(call_gemini)

@anvil.server.callable
def print_results_table():
    """Return the current task_id -> status/result table."""
    return service.results

# @app.route('/encode',methods=['GET','POST'])
# def encode():
#     print(request)
#     if request.method=='GET':
#         text=request.args.get('text')
#     elif request.method=='POST':
#         data=request.get_json()
#         if 'text' in data: text=data["text"]
#     if text=='' or text is None: return -1
#     inputs = tokenizer(text, padding=True, truncation=True, 
#                        return_tensors="pt", max_length=512)
#     result = encoder(**inputs)
#     embeddings = result.last_hidden_state[:, 0, :]
#     emb_array = embeddings.detach().numpy()
#     embedding=emb_array.tolist()
#     return jsonify({'embedding': embedding})

@app.route('/', methods=['GET', 'POST'])
def home():
    return render_template('home.html', messageD=MESSAGED)

if __name__ == '__main__':
    # app.run() blocks, which also keeps the Anvil uplink connection alive.
    app.run(host="0.0.0.0", port=7860)