import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import GeoLMModel
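# NOTE: GeoLMModel is assumed to be provided by the GeoLM authors' build of transformers;
# it is not part of the upstream transformers release.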
import requests
import numpy as np
import pandas as pd
import scipy.spatial as sp
import streamlit as st
import folium
from streamlit.components.v1 import html


from haversine import haversine, Unit


dataset = None  # global gazetteer DataFrame, loaded from geohash.csv in __main__



def generate_human_readable(tokens, labels):
    """Merge WordPiece pieces and I-Topo continuations back into readable place names."""
    ret = []
    for t, lab in zip(tokens, labels):
        if t == '[SEP]':
            continue

        if t.startswith("##"):
            # WordPiece continuation: glue onto the previous token
            assert len(ret) > 0
            ret[-1] = ret[-1] + t[2:]

        elif lab == 2:
            # I-Topo: continue the current toponym with a space
            assert len(ret) > 0
            ret[-1] = ret[-1] + " " + t

        else:
            ret.append(t)

    return ret

def getSlice(tensor):
    """Group consecutive token indices labeled 1 (B-Topo) or 2 (I-Topo) into spans.

    Example: labels [0, 1, 2, 2, 0, 1, 0] -> [[1, 2, 3], [5]].
    """
    result = []
    curr = []
    for index, value in enumerate(tensor[0]):
        if value == 1 or value == 2:
            curr.append(index)

        if value == 0 and curr != []:
            result.append(curr)
            curr = []

    # flush a span that runs up to the last token
    if curr:
        result.append(curr)

    return result

def getIndex(input_sentence):
    """Run the toponym-recognition model and return the index spans of detected toponyms."""

    tokenizer, model = getModel1()

    # Tokenize input sentence
    tokens = tokenizer.encode(input_sentence, return_tensors="pt")

    # Pass tokens through the model
    outputs = model(tokens)

    # Retrieve predicted label ids for each token
    # "id2label": { "0": "O", "1": "B-Topo", "2": "I-Topo" }
    predicted_labels = torch.argmax(outputs.logits, dim=2)

    # Group the B-Topo/I-Topo positions into contiguous toponym spans
    slices = getSlice(predicted_labels)

    return slices

def cutSlices(tensor, slicesList):
    """Pool the hidden states of each toponym span into a single 768-d vector.

    tensor: (1, seq_len, 768) hidden states; slicesList: list of token-index spans.
    Returns a (1, len(slicesList), 768) tensor; single-token spans are copied,
    longer spans are mean-pooled.
    """

    locationTensor = torch.zeros(1, len(slicesList), 768)

    curr = 0
    for span in slicesList:

        if len(span) == 1:
            locationTensor[0][curr] = tensor[0][span[0]]
            curr = curr + 1

        elif len(span) > 1:
            # (span_len, 768) -> (1, span_len, 768)
            sliceTensor = tensor[0][span[0]:span[-1] + 1]
            sliceTensor = sliceTensor.unsqueeze(0)

            # mean-pool over the span dimension
            mean = torch.mean(sliceTensor, dim=1, keepdim=True)

            locationTensor[0][curr] = mean[0]

            curr = curr + 1

    return locationTensor






def MLearningFormInput(input):
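    """Embed the input sentence with GeoLM and return one 768-d vector per detected toponym.

    Returns a tensor of shape (1, num_toponyms, 768). Assumes the NER checkpoint (getModel1)
    and the GeoLM encoder (getModel2) share the same tokenizer, so token indices line up.
    """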


    tokenizer, model = getModel2()

    tokens = tokenizer.encode(input, return_tensors="pt")

     # ['[CLS]', 'Minneapolis','[SEP]','Saint','Paul','[SEP]','Du','##lut','##h','[SEP]']
    # print(tokens)


    outputs = model(tokens, spatial_position_list_x=torch.zeros(tokens.shape), spatial_position_list_y=torch.zeros(tokens.shape))


    # print(outputs.last_hidden_state)

    # print(outputs.last_hidden_state.shape)


    # index spans of the detected toponyms in the token sequence
    slicesIndex = getIndex(input)

    # (1, seq_len, 768) hidden states -> (1, num_toponyms, 768)
    res = cutSlices(outputs.last_hidden_state, slicesIndex)

    return res







def getLocationName(input_sentence):
    # Model name from Hugging Face model hub
    tokenizer, model= getModel1()


    # Tokenize input sentence
    tokens = tokenizer.encode(input_sentence, return_tensors="pt")


    # Pass tokens through the model
    outputs = model(tokens) 


    # Retrieve predicted label ids for each token
    # "id2label": { "0": "O", "1": "B-Topo", "2": "I-Topo" }
    predicted_labels = torch.argmax(outputs.logits, dim=2)

    # Keep only the tokens predicted as part of a toponym (label != 0)
    query_tokens = tokens[0][torch.where(predicted_labels[0] != 0)[0]]

    query_labels = predicted_labels[0][torch.where(predicted_labels[0] != 0)[0]]


    human_readable = generate_human_readable(tokenizer.convert_ids_to_tokens(query_tokens), query_labels)

    return human_readable    



def search_geonames(toponym, df):
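    """Query the GeoNames search API for a toponym (up to 10 candidates).

    Fills df's 'lat'/'lon' columns with the candidates' coordinates and returns their
    neighbor-context embeddings as a numpy array of shape (1, num_candidates, 768).
    """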
    # GeoNames API endpoint
    api_endpoint = "http://api.geonames.org/searchJSON"

    username = "zekun"

    print(toponym)

    params = {
        'q': toponym,
        'username': username,
        'maxRows':10
    }

    response = requests.get(api_endpoint, params=params)
    data = response.json()

    result = []

    lat=[]
    lon=[]

    if 'geonames' in data:
        for place_info in data['geonames']:
            latitude = float(place_info.get('lat', 0.0))
            longitude = float(place_info.get('lng', 0.0))

            lat.append(latitude)
            lon.append(longitude)

            print(latitude)
            print(longitude)

            # getNeighborsDistance

            geoname_id = place_info.get('geonameId', '')

            print(geoname_id)

            global dataset
            res = get50Neighbors(geoname_id, dataset, k=50)
            result.append(res)
            print(res)


    df['lat'] = lat
    df['lon'] = lon
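    # assumes GeoNames returned at least one candidate; torch.cat raises on an empty list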
    result = torch.cat(result, dim=1).detach().numpy()
    return result



def get50Neighbors(locationID, dataset, k=50):
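    """Embed a gazetteer location by encoding it together with its k nearest neighbors.

    Neighbors are drawn from rows sharing the location's 7-character geohash prefix and
    ranked by haversine distance, so the location itself (distance 0) comes first and its
    tokens sit right after [CLS]. Returns a (1, 1, 768) tensor.
    """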

    print("neighbor part----------------------------------------------------------------")

    input_row = dataset.loc[dataset['GeonameID'] == locationID].iloc[0]

    lat, lon, geohash, name = input_row['Latitude'], input_row['Longitude'], input_row['Geohash'], input_row['Name']

    # restrict to places sharing the same 7-character geohash prefix (coarse spatial filter)
    filtered_dataset = dataset.loc[dataset['Geohash'].str.startswith(geohash[:7])].copy()

    filtered_dataset['distance'] = filtered_dataset.apply(
        lambda row: haversine((lat, lon), (row['Latitude'], row['Longitude']), Unit.KILOMETERS),
        axis=1
    )


    print("neighbor end----------------------------------------------------------------")



    filtered_dataset = filtered_dataset.sort_values(by='distance')



    nearest_neighbors = filtered_dataset.head(k)[['Name']]


    neighbors=nearest_neighbors.values.tolist()


    tokenizer, model= getModel1_0()


    sep_token_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    cls_token_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)


    neighbor_token_list = []
    neighbor_token_list.append(cls_token_id)

    target_token=tokenizer.convert_tokens_to_ids(tokenizer.tokenize(name))



    for neighbor in neighbors:


        neighbor_token = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(neighbor[0]))
        neighbor_token_list.extend(neighbor_token)
        neighbor_token_list.append(sep_token_id)


    # print(tokenizer.convert_ids_to_tokens(neighbor_token_list))

    #--------------------------------------------


    tokens = torch.tensor(neighbor_token_list, dtype=torch.long).unsqueeze(0)


    # input "new neighbor sentence"-> model -> output
    outputs = model(tokens, spatial_position_list_x=torch.zeros(tokens.shape), spatial_position_list_y=torch.zeros(tokens.shape))



    # print(outputs.last_hidden_state)

    # print(outputs.last_hidden_state.shape)


    targetIndex=list(range(1, len(target_token)+1))

    # #tensor -> tensor
    # get (1, len(target_token), 768) -> (1, 1, 768)
    res=cutSlices(outputs.last_hidden_state, [targetIndex])





    return res



def cosine_similarity(target_feature, candidate_feature):
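    """Cosine similarity between two feature vectors, returned as a Python float."""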

    target_feature = target_feature.squeeze()
    candidate_feature = candidate_feature.squeeze()

    dot_product = torch.dot(target_feature, candidate_feature)
    
    target = torch.norm(target_feature)
    candidate = torch.norm(candidate_feature)
    
    similarity = dot_product / (target * candidate)
    
    return similarity.item() 
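
# Note: torch.nn.functional.cosine_similarity(target_feature, candidate_feature, dim=0)
# would compute the same value.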


@st.cache_data
def getCSV():
    dataset = pd.read_csv('geohash.csv')
    return dataset

# st.cache_resource keeps one shared copy of each model/tokenizer across Streamlit reruns
@st.cache_resource
def getModel1():
    # Model name from Hugging Face model hub
    model_name = "zekun-li/geolm-base-toponym-recognition"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)

    return tokenizer,model

@st.cache_resource
def getModel1_0():
    # Model name from Hugging Face model hub
    model_name = "zekun-li/geolm-base-toponym-recognition"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = GeoLMModel.from_pretrained(model_name)
    return tokenizer,model



@st.cache_resource
def getModel2():

    model_name = "zekun-li/geolm-base-cased"

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = GeoLMModel.from_pretrained(model_name)

    return tokenizer,model
 

def showing(df):
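    """Render the candidate locations on a folium map; marker size and colour scale with df['prob']."""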

    m = folium.Map(location=[df['lat'].mean(), df['lon'].mean()], zoom_start=5)

    size_scale = 100  
    color_scale = 255  
    for i in range(len(df)):
        lat, lon, prob = df.iloc[i]['lat'], df.iloc[i]['lon'], df.iloc[i]['prob']
        
        size = int(prob**2 * size_scale )
        color = int(prob**2 * color_scale)
        
        folium.CircleMarker(
            location=[lat, lon],
            radius=size,
            color=f'#{color:02X}0000',
            fill=True,
            fill_color=f'#{color:02X}0000'
        ).add_to(m)

    m.save("map.html")

    with open("map.html", "r", encoding="utf-8") as f:
        map_html = f.read()

    html(map_html, height=600)


def mapping(selected_place, locations, sentence_info):
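    """Disambiguate the selected toponym: embed its GeoNames candidates, score them
    against the in-sentence embedding, and plot the resulting probabilities on a map."""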
    location_index = locations.index(selected_place)
    print(location_index)

    df = pd.DataFrame()

    # fetch an embedding for every GeoNames candidate sharing this name (e.g. every "Beijing")
    same_name_embedding = search_geonames(selected_place, df)


    sim_matrix = []


    same_name_embedding = torch.tensor(same_name_embedding)
    # score each candidate against the selected toponym's in-sentence embedding
    for i in range(same_name_embedding.size(1)):
        print((sentence_info[:, location_index, :]).shape)
        print((same_name_embedding[:, i, :]).shape)

        similarities = cosine_similarity(sentence_info[:, location_index, :], same_name_embedding[:, i, :])
        sim_matrix.append(similarities)

    # print("Cosine Similarity Matrix:")
    # print(sim_matrix)
    
    # squash cosine similarities into (0, 1) so they can drive marker size and colour
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    prob_matrix = sigmoid(np.array(sim_matrix))

    df['prob'] = prob_matrix


    print(df)

    showing(df)



def show_on_map():
    """Streamlit page: read a sentence, detect its toponyms, and map the GeoNames candidates."""

    input_text = st.text_area("Enter a sentence:", height=200)

    # the button only triggers a rerun; the pipeline runs on any non-empty input
    st.button("Submit")

    # nothing to do until the user has entered a sentence
    if not input_text.strip():
        return

    # 1. input: a sentence -> output: tensor (1, num_toponyms, 768)
    sentence_info = MLearningFormInput(input_text)

    print("sentence info: ")
    print(sentence_info)
    print(sentence_info.shape)

    # 2. input: a sentence -> output: detected toponym strings
    locations = getLocationName(input_text)

    selected_place = st.selectbox("Select a location:", locations)

    if selected_place is not None:

        mapping(selected_place, locations, sentence_info)




if __name__ == "__main__":


    dataset = getCSV()

    show_on_map()
        
     
    # Example input for offline testing (roughly 80 tokens):
    # input= 'Minneapolis, officially the City of Minneapolis, is a city in the state of Minnesota and the county seat of Hennepin County. making it the largest city in Minnesota and the 46th-most-populous in the  United States. Nicknamed the "City of Lakes", Minneapolis is abundant in water,  with thirteen lakes, wetlands, the Mississippi River, creeks, and waterfalls.'

