from smolagents import Tool
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from datasets import load_dataset
import os

class LawRAGQuery(Tool):
    name = "law_rag_query"
    description = """Retrieves law content relevant to a question: searches a vector store of law passages and returns the best matches together with their similarity scores."""
    inputs = {
        "question": {
            "type": "string",
            "description": "the question",
        }
    }
    output_type = "array"
    vectorstore = None

    def __init__(self):
        super().__init__()
        # Build (or reload) the vector store over the law corpus at startup.
        dataset = load_dataset("robin0307/law", split='train')
        law = dataset.to_pandas()
        self.vectorstore = self.get_vectorstore("thenlper/gte-large-zh", list(law['content']))

    def get_vectorstore(self, model_path, data_list, path="chroma_db"):
        embeddings = HuggingFaceEmbeddings(model_name=model_path)
        if os.path.isdir(path):
            # Reuse the persisted Chroma index instead of re-embedding the corpus.
            vectorstore = Chroma(embedding_function=embeddings, persist_directory=path)
        else:
            # Split each document into overlapping chunks, flatten the nested
            # list, then embed and persist the chunks.
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=50)
            chunks = [text_splitter.split_text(text) for text in data_list]
            splits = [chunk for sublist in chunks for chunk in sublist]
            vectorstore = Chroma.from_texts(texts=splits, embedding=embeddings, persist_directory=path)
        print("count:", vectorstore._collection.count())
        return vectorstore

    def get_docs(self, query, k=10):
        # Over-fetch candidates, then keep only the top-k (content, score) pairs.
        retrieved_documents = self.vectorstore.similarity_search_with_score(query, k=50)
        return [(doc.page_content, score) for doc, score in retrieved_documents[:k]]

    def forward(self, question: str):
        # Return the top matching law passages with their similarity scores.
        return self.get_docs(question)
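
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of exercising the tool directly and wiring it into a
# smolagents agent. The model choice (InferenceClientModel) and the sample
# questions are assumptions made for illustration; since the dataset and the
# gte-large-zh embedding model are Chinese-language, queries are shown in
# Chinese.
if __name__ == "__main__":
    tool = LawRAGQuery()

    # Direct call: returns a list of (passage, similarity_score) tuples.
    # Question: "What are the conditions for a contract to be formed?"
    for passage, score in tool.forward("合同成立的条件是什么？")[:3]:
        print(f"{score:.4f}  {passage[:80]}")

    # Hand the tool to an agent so the LLM decides when to query it.
    from smolagents import CodeAgent, InferenceClientModel

    agent = CodeAgent(tools=[tool], model=InferenceClientModel())
    agent.run("合同成立需要满足哪些法律条件？")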