import json
import os

import firebase_admin
import gradio as gr
import torch
from firebase_admin import credentials, db
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration


# --- Firebase initialisation ------------------------------------------------
# The service-account key is expected as a JSON document stored in the
# FIREBASE_CREDENTIALS environment variable.
firebase_credential = os.getenv("FIREBASE_CREDENTIALS")
if not firebase_credential:
    raise RuntimeError("FIREBASE_CREDENTIALS environment variable is not set.")

# Parse the key in memory instead of writing it to serviceAccountKey.json:
# persisting the secret in the working directory leaves a plaintext copy of
# the credentials on disk after the process exits.
try:
    _service_account_info = json.loads(firebase_credential)
except json.JSONDecodeError as err:
    raise RuntimeError(
        "FIREBASE_CREDENTIALS environment variable does not contain valid JSON."
    ) from err

# credentials.Certificate accepts either a file path or the parsed key dict.
cred = credentials.Certificate(_service_account_info)

# The database URL can be overridden through FIREBASE_DATABASE_URL; when the
# variable is absent the original (placeholder) URL is used unchanged.
firebase_admin.initialize_app(
    cred,
    {
        "databaseURL": os.getenv(
            "FIREBASE_DATABASE_URL",
            "https://your-database-name.firebaseio.com/",
        )
    },
)


# --- RAG components ----------------------------------------------------------
# All three components are loaded from the same Hugging Face checkpoint.
# NOTE(review): a "rag-token" checkpoint is loaded into the sequence-level
# model class below -- confirm this pairing is intended.
_RAG_CHECKPOINT = "facebook/rag-token-base"

tokenizer = RagTokenizer.from_pretrained(_RAG_CHECKPOINT)

# NOTE(review): use_dummy_dataset=True presumably selects a small stand-in
# retrieval index rather than the full one -- confirm before production use.
retriever = RagRetriever.from_pretrained(_RAG_CHECKPOINT, use_dummy_dataset=True)

model = RagSequenceForGeneration.from_pretrained(_RAG_CHECKPOINT)


def generate_answer(question, context=""):
    """Answer *question* using the module-level RAG tokenizer, retriever and model.

    The question is tokenized and embedded with the model's question encoder,
    the retriever is queried with those embeddings, and the generator then
    produces an answer conditioned on the retrieved passages.

    Args:
        question: Natural-language question text.
        context: Accepted for backward compatibility; currently unused.

    Returns:
        The decoded answer string, or ``""`` when *question* is empty or
        contains only whitespace.
    """
    # Nothing to retrieve or generate for a blank question.
    if not question or not question.strip():
        return ""

    inputs = tokenizer(question, return_tensors="pt")
    input_ids = inputs["input_ids"]

    with torch.no_grad():  # inference only; no gradients needed
        # 1. Embed the question with the model's question encoder.
        question_hidden_states = model.question_encoder(input_ids)[0]

        # 2. Query the retriever. It is called with the question token ids and
        #    the question embeddings (as numpy arrays), not with the raw text.
        retrieved_docs = retriever(
            input_ids.numpy(),
            question_hidden_states.detach().numpy(),
            return_tensors="pt",
        )

        # 3. Score each retrieved document against the question embedding.
        #    The model has no retriever attached, so it cannot compute these
        #    scores itself and they must be supplied to generate().
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            retrieved_docs["retrieved_doc_embeds"].float().transpose(1, 2),
        ).squeeze(1)

        # 4. Generate from the retrieved contexts.
        outputs = model.generate(
            context_input_ids=retrieved_docs["context_input_ids"],
            context_attention_mask=retrieved_docs["context_attention_mask"],
            doc_scores=doc_scores,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def dashboard(question):
    """Gradio callback: pass *question* to ``generate_answer`` and return its result."""
    return generate_answer(question)


# Gradio UI: one text input routed through `dashboard`, one text output.
interface = gr.Interface(
    fn=dashboard,
    inputs="text",
    outputs="text",
)


# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()