from smolagents import Tool
from langchain_chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from datasets import load_dataset
import os


class LawRAGQuery(Tool):
    """smolagents Tool that answers legal questions via retrieval.

    On construction it loads the ``robin0307/law`` dataset and builds (or
    reopens) a persistent Chroma vector store embedded with
    ``thenlper/gte-large-zh``. ``forward`` returns the top-k most similar
    law passages for a question as ``(content, score)`` pairs.
    """

    name = "law_rag_query"
    description = """
    This is a tool that returns law content by input a question.
    It will find the related law and return."""
    inputs = {
        "question": {
            "type": "string",
            "description": "the question",
        }
    }
    output_type = "array"
    # Populated in __init__; class-level default so smolagents' attribute
    # validation sees the name before construction completes.
    vectorstore = None

    def __init__(self):
        # Load the full law corpus once at startup; get_vectorstore reuses a
        # persisted index when one exists, so the embedding cost is paid only
        # on the first run.
        dataset = load_dataset("robin0307/law", split='train')
        law = dataset.to_pandas()
        self.vectorstore = self.get_vectorstore(
            "thenlper/gte-large-zh", list(law['content'])
        )
        super().__init__()

    def get_vectorstore(self, model_path, data_list, path="chroma_db"):
        """Build or reopen a persistent Chroma vector store.

        Args:
            model_path: HuggingFace embedding model name/path.
            data_list: list of document strings to index on first build.
            path: Chroma persistence directory.

        Returns:
            A ``Chroma`` instance backed by ``path``.

        NOTE(review): when ``path`` already exists the persisted index is
        reopened as-is — ``data_list`` is NOT re-embedded, so a changed
        corpus requires deleting the directory first.
        """
        embeddings = HuggingFaceEmbeddings(model_name=model_path)
        if os.path.isdir(path):
            # Reuse the persisted index; skip splitting/embedding entirely.
            # (The original split the whole corpus even on this path and
            # threw the chunks away.)
            vectorstore = Chroma(embedding_function=embeddings, persist_directory=path)
        else:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=2048, chunk_overlap=50
            )
            # Split each document, flattening into one list of chunks.
            splits = [
                chunk
                for text in data_list
                for chunk in text_splitter.split_text(text)
            ]
            vectorstore = Chroma.from_texts(
                texts=splits, embedding=embeddings, persist_directory=path
            )
        print("count:", vectorstore._collection.count())
        return vectorstore

    def get_docs(self, input, k=10):
        """Return the top-``k`` passages most similar to the query.

        Args:
            input: query string (name shadows the builtin; kept for
                backward compatibility with keyword callers).
            k: maximum number of results.

        Returns:
            List of ``(page_content, score)`` tuples, at most ``k`` long.
            (Fixes the original off-by-one that returned ``k + 1`` items.)
        """
        retrieved_documents = self.vectorstore.similarity_search_with_score(input, k=k)
        # Defensive slice in case the backend returns more than requested.
        return [(doc.page_content, score) for doc, score in retrieved_documents[:k]]

    def forward(self, question: str):
        """Tool entry point: retrieve law passages relevant to *question*."""
        return self.get_docs(question)