import requests
from requests.auth import HTTPDigestAuth, HTTPBasicAuth
import ssl
import json

from joblib import Memory

# joblib on-disk memoization store. Functions decorated with `@mem.cache` in
# this module persist their return values under this directory. It is a
# relative path, so it is resolved against the current working directory.
cachedir = 'cached'
# verbose=False keeps joblib from logging every cache hit/miss.
mem = Memory(cachedir, verbose=False)



@mem.cache(ignore=['auth'])
def _cached_sparql_post(endpoint, query, auth_scheme, auth_username, auth):
    """POST a SPARQL query and memoize the successful response body on disk.

    Private helper behind :func:`execute_query`.

    Only ``endpoint``, ``query``, ``auth_scheme`` and ``auth_username`` form the
    joblib cache key. The ``auth`` object itself is excluded, for two reasons:

    - joblib hashes arguments by pickling them, and requests'
      ``HTTPDigestAuth`` holds a ``threading.local`` that cannot be pickled.
      Hashing it would raise before any request was sent.
    - Excluding it also keeps the password out of the cache key.

    Consequence: for the same endpoint/query/scheme/username, a cached answer
    is returned even if the password has changed since. joblib does not cache
    raised exceptions, so failed attempts are never memoized.
    """
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/sparql-results+json'
    }

    # SPARQL 1.1 Protocol, "query via POST with URL-encoded parameters": the
    # query goes in the form-encoded body announced by Content-Type, not in
    # the URL, where long queries may exceed server URL length limits.
    response = requests.post(endpoint, headers=headers, data={'query': query}, auth=auth)

    if response.status_code != 200:
        raise Exception(f"Failed to execute query: {response.text}")

    return response.text


def execute_query(endpoint, query, auth):
    """Execute a SPARQL query, caching successful response bodies on disk.

    Args:
    - endpoint (str): URL of the SPARQL endpoint.
    - query (str): The SPARQL query text.
    - auth: A requests auth object (e.g. ``HTTPDigestAuth``/``HTTPBasicAuth``)
      or anything else ``requests.post(auth=...)`` accepts.

    Returns:
    - str: The raw response body (requested as SPARQL JSON results).

    Raises:
    - Exception: If the endpoint answers with a status other than 200.
    - requests.RequestException: On transport-level failures.
    """
    # Key the cache on a picklable description of the credentials (scheme +
    # username) rather than on the auth object itself; see _cached_sparql_post.
    return _cached_sparql_post(
        endpoint,
        query,
        type(auth).__name__,
        getattr(auth, 'username', None),
        auth,
    )


def execute_query_no_cache(endpoint, query, auth):
    """Execute a SPARQL query against ``endpoint`` without any caching.

    Args:
    - endpoint (str): URL of the SPARQL endpoint.
    - query (str): The SPARQL query text.
    - auth: A requests auth object (e.g. ``HTTPDigestAuth``/``HTTPBasicAuth``)
      or anything else ``requests.post(auth=...)`` accepts.

    Returns:
    - str: The raw response body (requested as SPARQL JSON results).

    Raises:
    - Exception: If the endpoint answers with a status other than 200.
    - requests.RequestException: On transport-level failures.
    """
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/sparql-results+json'
    }

    # SPARQL 1.1 Protocol, "query via POST with URL-encoded parameters": the
    # query goes in the form-encoded body announced by Content-Type, not in
    # the URL, where long queries may exceed server URL length limits.
    response = requests.post(endpoint, headers=headers, data={'query': query}, auth=auth)

    if response.status_code != 200:
        raise Exception(f"Failed to execute query: {response.text}")

    return response.text


def sparqlQuery(endpoint, query, usernameVirt, passwordVirt, USE_CACHE=True):
    """
    Make a SPARQL query to a Virtuoso SPARQL endpoint.

    HTTP Digest authentication is tried first. If that attempt raises for any
    reason, the query is retried once with HTTP Basic authentication.

    Args:
    - endpoint (str): The URL of the Virtuoso SPARQL endpoint.
    - query (str): The SPARQL query to execute.
    - usernameVirt (str): The username for authentication.
    - passwordVirt (str): The password for authentication.
    - USE_CACHE (bool): If True, go through the on-disk cache
      (``execute_query``); otherwise always hit the endpoint
      (``execute_query_no_cache``).

    Returns:
    - str | None: The raw response body on success, or None if both the
      Digest and the Basic attempt failed. Failure reasons are printed.
    """
    # Certificate verification for https endpoints is handled by requests
    # itself; no separate SSL context is needed (nor was one ever passed on).
    run_query = execute_query if USE_CACHE else execute_query_no_cache

    # Try HTTP Digest authentication
    try:
        return run_query(endpoint, query, HTTPDigestAuth(usernameVirt, passwordVirt))
    except Exception as e:
        # Broad on purpose: any failure (auth, HTTP status, network) triggers
        # the Basic-auth retry, matching the original best-effort behavior.
        print(f"HTTP Digest authentication failed: {str(e)}")

    # Fallback to Basic Auth
    try:
        return run_query(endpoint, query, HTTPBasicAuth(usernameVirt, passwordVirt))
    except Exception as e:
        print(f"Basic Auth failed: {str(e)}")
        return None


#
# def sparqlQuery(endpoint, query, usernameVirt, passwordVirt):
#     """
#     Make a SPARQL query to a Virtuoso SPARQL endpoint.
#
#     Args:
#     - endpoint (str): The URL of the Virtuoso SPARQL endpoint.
#     - query (str): The SPARQL query to execute.
#     - username (str): The username for authentication.
#     - password (str): The password for authentication.
#
#     Returns:
#     - response (requests.Response): The response from the Virtuoso SPARQL endpoint.
#     """
#     headers = {
#         'Content-Type': 'application/x-www-form-urlencoded',
#         'Accept': 'application/sparql-results+json'
#     }
#
#     params = {
#         'query': query
#     }
#
#     # Use SSL context to establish a secure connection
#     ssl_context = ssl.create_default_context()
#
#     # Try HTTP Digest authentication
#     try:
#         auth = HTTPDigestAuth(usernameVirt, passwordVirt)
#         response = requests.post(endpoint, headers=headers, params=params, auth=auth)
#     except Exception as e:
#         print(f"HTTP Digest authentication failed: {str(e)}")
#
#         # Fallback to Basic Auth
#         try:
#             auth = HTTPBasicAuth(usernameVirt, passwordVirt)
#             response = requests.post(endpoint, headers=headers, params=params, auth=auth)
#         except Exception as e:
#             print(f"Basic Auth failed: {str(e)}")
#             return None
#
#     if response.status_code != 200:
#         raise Exception(f"Failed to execute query: {response.text}")
#
#     return response


if __name__ == '__main__':
    # Example usage: find concepts whose label equals `word` (case-insensitive)
    # in the named graphs listed in `choices`, ranked by their number of edges.
    endpoint = 'https://api-vast.jrc.service.ec.europa.eu/sparql'

    VirtuosoUsername = 'dba'
    VirtuosoPassword = ''
    # If set, the password is read from this file and overrides VirtuosoPassword.
    Virtuosokey_filename = 'VIRTUOSO-dba.key'

    USE_CACHE = False  # True or False

    #############

    # Named graphs to search; each becomes one FROM <...> clause.
    # choices = ['SNOMED', 'LOINC', 'ICD10', 'MESH', 'NCIT']  # restricts the input to these values only
    choices = [
        "AI", "AIO", "AEO", "BFO", "BIM", "BCGO", "CL", "CHIRO", "CHEBI", "DCM", "FMA", "GO", "GENO",
        "GeoSPARQL", "HL7", "DOID", "HP", "HP_O", "IDO", "IAO", "ICD10", "LOINC", "MESH",
        "MONDO", "NCIT", "NCBITAXON", "NCBITaxon_", "NIFCELL", "NIFSTD", "GML", "OBCS", "OCHV", "OHPI",
        "OPB", "TRANS", "PLOSTHES", "RADLEX", "RO", "STY", "SO", "SNOMED", "STATO",
        "SYMP", "FoodOn", "UBERON", "ORDO", "HOOM", "VO", "OGMS", "EuroSciVoc",
    ]

    # Construct the FROM clauses
    from_clauses = ' '.join(f"FROM <{choice}>" for choice in choices)

    # word = "acute sinusitis"
    word = "pure mathematics"
    # Escape the characters that would otherwise terminate or corrupt the
    # double-quoted SPARQL string literal in the FILTER below.
    sparql_escapes = str.maketrans({'\\': '\\\\', '"': '\\"', '\n': '\\n', '\r': '\\r'})
    sparql_word = word.lower().translate(sparql_escapes)

    # Construct the full SPARQL query.
    # NOTE(review): the skos:, rdfs: and obo: prefixes are not declared here;
    # the query relies on the server predefining them. Confirm on the endpoint.
    query = f"""
            prefix skosxl: <http://www.w3.org/2008/05/skos-xl#> 
            SELECT ?concept ?label (COUNT(?edge) AS ?score)
            {from_clauses}
            WHERE {{
              ?concept skos:prefLabel|rdfs:label|skos:altLabel|skosxl:literalForm|obo:hasRelatedSynonym ?label .
              FILTER (LCASE(STR(?label)) = "{sparql_word}")
              ?concept ?edge ?o .
            }}
            GROUP BY ?concept ?label
            ORDER BY DESC(?score)
        """

    print(query)

    ###############

    if Virtuosokey_filename:
        # strip() drops the trailing newline key files usually end with;
        # otherwise it would silently become part of the password.
        with open(Virtuosokey_filename, encoding='utf-8') as f:
            VirtuosoPassword = f.read().strip()

    responseText = sparqlQuery(endpoint, query, VirtuosoUsername, VirtuosoPassword, USE_CACHE)

    if responseText is None:
        # sparqlQuery has already printed why both auth attempts failed;
        # stop here instead of crashing inside json.loads(None).
        raise SystemExit("!!! VIRTUOSO QUERY FAILED !!!")

    # Parse the response as JSON (requested via the sparql-results+json Accept header).
    results = json.loads(responseText)

    # Print the results; tolerate payloads without a results/bindings section.
    bindings = results.get('results', {}).get('bindings', [])
    if bindings:
        for result in bindings:
            print(result)
    else:
        print("!!! VIRTUOSO NO RESULTS !!!")