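"""Streamlit app for chatting with a GitHub repository via RAG.

gitingest flattens the repository into markdown, LlamaIndex builds a
vector index over it with HuggingFace embeddings, and a SambaNova-hosted
LLM answers questions with streaming responses.
"""
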
import gc
import logging
import os
import tempfile
import uuid

import streamlit as st
from dotenv import load_dotenv
from gitingest import ingest
from llama_index.core import (
    PromptTemplate,
    Settings,
    SimpleDirectoryReader,
    VectorStoreIndex,
)
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.sambanovasystems import SambaNovaCloud

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GitHubRAGError(Exception):
    """Custom exception for GitHub RAG application errors."""


SAMBANOVA_API_KEY = os.getenv("SAMBANOVA_API_KEY")
if not SAMBANOVA_API_KEY:
    raise ValueError("SAMBANOVA_API_KEY is not set in environment variables")


# Per-session state: a unique id namespaces the cached query engines so
# concurrent browser sessions never share an index.
if "id" not in st.session_state:
    st.session_state.id = uuid.uuid4()
    st.session_state.file_cache = {}
    st.session_state.messages = []

session_id = st.session_state.id
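

# st.cache_resource caches the returned object globally, across reruns and
# user sessions, so the LLM client is constructed only once per process.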
@st.cache_resource
def load_llm():
    """Load and cache the SambaNova LLM predictor."""
    return SambaNovaCloud(
        api_key=SAMBANOVA_API_KEY,
        model="DeepSeek-R1-Distill-Llama-70B",
        temperature=0.1,
        top_p=0.1,
    )


def reset_chat():
    """Clear chat history and free resources."""
    st.session_state.messages = []
    gc.collect()


def process_with_gitingest(github_url: str):
    """Use gitingest to fetch and summarize the GitHub repository.

    Returns the (summary, tree, content) triple produced by gitingest.
    """
    try:
        summary, tree, content = ingest(github_url)
    except Exception as e:
        raise GitHubRAGError(f"Failed to ingest repository: {e}") from e
    return summary, tree, content


with st.sidebar:
    st.header("Add your GitHub repository!")
    github_url = st.text_input(
        "GitHub repo URL", placeholder="https://github.com/user/repo"
    )
    load_btn = st.button("Load Repository")

    if github_url and load_btn:
        try:
            repo_name = github_url.rstrip("/").split("/")[-1]
            cache_key = f"{session_id}-{repo_name}"

            if cache_key not in st.session_state.file_cache:
                with st.spinner("Processing repository..."):
                    summary, tree, content = process_with_gitingest(github_url)

                    # gitingest returns one markdown blob; write it to a temp
                    # file so SimpleDirectoryReader can load it as a document.
                    with tempfile.TemporaryDirectory() as tmpdir:
                        md_path = os.path.join(tmpdir, f"{repo_name}.md")
                        with open(md_path, "w", encoding="utf-8") as f:
                            f.write(content)

                        loader = SimpleDirectoryReader(input_dir=tmpdir)
                        docs = loader.load_data()

                    # The nomic embedding model ships custom model code,
                    # hence trust_remote_code=True.
                    embed_model = HuggingFaceEmbedding(
                        model_name="nomic-ai/nomic-embed-text-v2-moe",
                        trust_remote_code=True,
                    )
                    Settings.embed_model = embed_model

                    Settings.llm = load_llm()

                    # Markdown-aware chunking keeps gitingest's per-file
                    # sections intact.
                    node_parser = MarkdownNodeParser()
                    index = VectorStoreIndex.from_documents(
                        documents=docs,
                        transformations=[node_parser],
                        show_progress=True,
                    )

                    qa_prompt = PromptTemplate(
                        "You are an AI assistant specialized in analyzing GitHub repositories.\n"
                        "Repository structure:\n{tree}\n---\n"
                        "Context:\n{context_str}\n---\n"
                        "Question: {query_str}\nAnswer:"
                    )
                    query_engine = index.as_query_engine(streaming=True)
                    # Pre-fill {tree} via partial_format: the response
                    # synthesizer only supplies {context_str} and {query_str},
                    # so an unfilled {tree} would raise a KeyError at query time.
                    query_engine.update_prompts(
                        {"response_synthesizer:text_qa_template": qa_prompt.partial_format(tree=tree)}
                    )

                    st.session_state.file_cache[cache_key] = (query_engine, tree)

                st.success("Repository loaded and indexed. Ready to chat!")
            else:
                st.info("Repository already loaded.")
        except Exception as e:
            st.error(f"Error loading repository: {e}")
            logger.error(f"Load error: {e}")
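

# Main chat panel; "Clear Chat" runs reset_chat as an on_click callback
# before the next rerun.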
col1, col2 = st.columns([6, 1])
with col1:
    st.header("Chat with GitHub RAG")
with col2:
    st.button("Clear Chat ↺", on_click=reset_chat)


# Replay the conversation history on each rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
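

# Handle a new question: echo the user's turn, then answer with the
# query engine cached for the loaded repository.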
if prompt := st.chat_input("Ask a question about the repository..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    repo_name = github_url.rstrip("/").split("/")[-1]
    cache_key = f"{session_id}-{repo_name}"

    if cache_key not in st.session_state.file_cache:
        st.error("Please load a repository first!")
    else:
        query_engine, tree = st.session_state.file_cache[cache_key]
        with st.chat_message("assistant"):
            placeholder = st.empty()
            response_text = ""
            try:
                response = query_engine.query(prompt)
                if hasattr(response, "response_gen"):
                    # Streaming response: render incrementally with a cursor
                    # glyph; the final render below drops it.
                    for chunk in response.response_gen:
                        response_text += chunk
                        placeholder.markdown(response_text + "▌")
                else:
                    response_text = str(response)
            except GitHubRAGError as e:
                st.error(str(e))
                logger.error(f"Error in chat processing: {e}")
                response_text = "Sorry, I couldn't process that request."
            except Exception as e:
                st.error("An unexpected error occurred while processing your query")
                logger.error(f"Unexpected error in chat: {e}")
                response_text = "Sorry, something went wrong."
            placeholder.markdown(response_text)
        st.session_state.messages.append({"role": "assistant", "content": response_text})
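
# A likely dependency set for this script (package names are assumptions
# inferred from the imports above; pin versions as needed):
#   pip install streamlit python-dotenv gitingest llama-index-core \
#       llama-index-llms-sambanovasystems llama-index-embeddings-huggingface
# Then launch with: streamlit run <this_file>.py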