import os

from datasets import load_dataset
from langchain_chroma import Chroma
# HuggingFaceEmbeddings moved out of the deprecated `langchain.embeddings` path.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from smolagents import Tool


class LawRAGQuery(Tool):
    name = "law_rag_query"
    description = """Returns law content related to the input question: it searches the law corpus and returns the most relevant law articles together with their scores."""
    inputs = {
        "question": {
            "type": "string",
            "description": "the question to find related law articles for",
        }
    }
    output_type = "array"
    vectorstore = None

    def __init__(self):
        # Load the law corpus once and build (or reload) the vector store.
        dataset = load_dataset("robin0307/law", split='train')
        law = dataset.to_pandas()
        self.vectorstore = self.get_vectorstore("thenlper/gte-large-zh", list(law['content']))
        super().__init__()

    def get_vectorstore(self, model_path, data_list, path="chroma_db"):
        embeddings = HuggingFaceEmbeddings(model_name=model_path)

        if os.path.isdir(path):
            # A persisted Chroma collection already exists: reload it instead of re-embedding.
            vectorstore = Chroma(embedding_function=embeddings, persist_directory=path)
        else:
            # Split each law document into overlapping chunks, flatten, embed, and persist.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=50)
            chunks = [text_splitter.split_text(text) for text in data_list]
            splits = [chunk for sublist in chunks for chunk in sublist]
            vectorstore = Chroma.from_texts(texts=splits, embedding=embeddings, persist_directory=path)

        print("count:", vectorstore._collection.count())
        return vectorstore

    def get_docs(self, query, k=10):
        # Retrieve a wide candidate set, then keep only the top-k (content, score) pairs.
        retrieved_documents = self.vectorstore.similarity_search_with_score(query, k=50)
        return [(doc.page_content, score) for doc, score in retrieved_documents[:k]]

    def forward(self, question: str):
        return self.get_docs(question)
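

# A minimal usage sketch, not part of the original listing: smolagents'
# Tool.__call__ dispatches to forward(), so the tool can be invoked directly
# once instantiated. The agent wiring assumes CodeAgent and
# InferenceClientModel from a recent smolagents release; adjust to your
# installed version.
if __name__ == "__main__":
    law_tool = LawRAGQuery()

    # Direct call: returns a list of (law text, score) pairs.
    # Query: "What is the sentencing standard for theft?"
    print(law_tool(question="盗窃罪的量刑标准是什么？"))

    # Or hand the tool to an agent so it decides when to query the law corpus.
    from smolagents import CodeAgent, InferenceClientModel

    agent = CodeAgent(tools=[law_tool], model=InferenceClientModel())
    agent.run("请引用相关法律条文回答：盗窃罪的量刑标准是什么？")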