# LawAgent/tools/law_rag_query.py
import os

from datasets import load_dataset
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from smolagents import Tool


class LawRAGQuery(Tool):
    name = "law_rag_query"
    description = """
    Returns law content related to an input question. It searches a vector store of
    law articles and returns the most similar passages together with their scores."""
    inputs = {
        "question": {
            "type": "string",
            "description": "the question to search the law corpus with",
        }
    }
    output_type = "array"
    vectorstore = None
    def __init__(self):
        # Load the law corpus and build (or reload) the Chroma vector store.
        dataset = load_dataset("robin0307/law", split="train")
        law = dataset.to_pandas()
        self.vectorstore = self.get_vectorstore("thenlper/gte-large-zh", list(law["content"]))
        super().__init__()
    def get_vectorstore(self, model_path, data_list, path="chroma_db"):
        embeddings = HuggingFaceEmbeddings(model_name=model_path)
        if os.path.isdir(path):
            # Reuse the persisted Chroma database if it already exists.
            vectorstore = Chroma(embedding_function=embeddings, persist_directory=path)
        else:
            # Split each law article into chunks, flatten the nested list,
            # then embed and persist the chunks to a new Chroma database.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=50)
            chunks = [text_splitter.split_text(text) for text in data_list]
            splits = [chunk for sublist in chunks for chunk in sublist]
            vectorstore = Chroma.from_texts(texts=splits, embedding=embeddings, persist_directory=path)
        print("count:", vectorstore._collection.count())
        return vectorstore
    def get_docs(self, query, k=10):
        # Retrieve the k most similar chunks together with their similarity scores.
        retrieved_documents = self.vectorstore.similarity_search_with_score(query, k=k)
        return [(doc.page_content, score) for doc, score in retrieved_documents]
    def forward(self, question: str):
        docs = self.get_docs(question)
        return docs
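

# Example usage (a minimal sketch, not part of the original tool): instantiating the
# tool downloads the "robin0307/law" dataset and the "thenlper/gte-large-zh" embedding
# model on first run, so this assumes network access and enough memory for the model.
# The sample question below is hypothetical.
if __name__ == "__main__":
    tool = LawRAGQuery()
    for content, score in tool.forward("竊盜罪的刑責是什麼?"):
        print(f"{score:.4f}  {content[:80]}")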