Robin Chiu committed on
Commit
28536b2
·
1 Parent(s): 20acf5b

add the law tools

Browse files
Files changed (6) hide show
  1. agent.json +2 -2
  2. app.py +5 -3
  3. requirements.txt +6 -1
  4. tools/law_rag_query.py +54 -0
  5. tools/law_tool.py +31 -0
  6. tools/web_search.py +0 -27
agent.json CHANGED
@@ -1,6 +1,7 @@
1
  {
2
  "tools": [
3
- "web_search",
 
4
  "final_answer"
5
  ],
6
  "model": {
@@ -39,7 +40,6 @@
39
  "name": null,
40
  "description": null,
41
  "requirements": [
42
- "duckduckgo_search",
43
  "smolagents"
44
  ],
45
  "authorized_imports": [
 
1
  {
2
  "tools": [
3
+ "law_tool",
4
+ "law_rag_query",
5
  "final_answer"
6
  ],
7
  "model": {
 
40
  "name": null,
41
  "description": null,
42
  "requirements": [
 
43
  "smolagents"
44
  ],
45
  "authorized_imports": [
app.py CHANGED
@@ -5,7 +5,8 @@ from smolagents import GradioUI, CodeAgent, HfApiModel
5
  # Get current directory path
6
  CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
7
 
8
- from tools.web_search import DuckDuckGoSearchTool as WebSearch
 
9
  from tools.final_answer import FinalAnswerTool as FinalAnswer
10
 
11
 
@@ -15,7 +16,8 @@ model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
15
  provider=None,
16
  )
17
 
18
- web_search = WebSearch()
 
19
  final_answer = FinalAnswer()
20
 
21
 
@@ -24,7 +26,7 @@ with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
24
 
25
  agent = CodeAgent(
26
  model=model,
27
- tools=[web_search],
28
  managed_agents=[],
29
  max_steps=20,
30
  verbosity_level=1,
 
5
  # Get current directory path
6
  CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
7
 
8
+ from tools.law_tool import LawTool
9
+ from tools.law_rag_query import LawRAGQuery
10
  from tools.final_answer import FinalAnswerTool as FinalAnswer
11
 
12
 
 
16
  provider=None,
17
  )
18
 
19
+ law_tool = LawTool()
20
+ law_rag_query = LawRAGQuery()
21
  final_answer = FinalAnswer()
22
 
23
 
 
26
 
27
  agent = CodeAgent(
28
  model=model,
29
+ tools=[law_tool, law_rag_query],
30
  managed_agents=[],
31
  max_steps=20,
32
  verbosity_level=1,
requirements.txt CHANGED
@@ -1,3 +1,8 @@
1
  duckduckgo_search
2
  smolagents
3
- gradio[oauth]==5.23.2
 
 
 
 
 
 
1
  duckduckgo_search
2
  smolagents
3
+ gradio
4
+ datasets
5
+ langchain
6
+ langchain-chroma
7
+ langchain-text-splitters
8
+ datasets
tools/law_rag_query.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from langchain_chroma import Chroma
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
5
+ from datasets import load_dataset
6
+ import os
7
+
8
class LawRAGQuery(Tool):
    """Retrieval tool: semantic search over an embedded law corpus.

    Loads the ``robin0307/law`` dataset, embeds its ``content`` column into a
    persistent Chroma store, and answers free-form legal questions with the
    most similar passages.
    """

    name = "law_rag_query"
    description = """
    This is a tool that returns law content by input a question. It will find the related law and return."""
    inputs = {
        "question": {
            "type": "string",
            "description": "the question",
        }
    }
    output_type = "array"
    # Chroma vector store over the law chunks; built or reopened in __init__.
    vectorstore = None

    def __init__(self):
        # Pull the corpus from the Hugging Face hub and index it once.
        dataset = load_dataset("robin0307/law", split="train")
        law = dataset.to_pandas()
        self.vectorstore = self.get_vectorstore("thenlper/gte-large-zh", list(law["content"]))
        super().__init__()

    def get_vectorstore(self, model_path, data_list, path="chroma_db"):
        """Build (or reopen) a persistent Chroma store for *data_list*.

        model_path: HuggingFace embedding model name.
        data_list:  raw law texts to chunk and embed.
        path:       persist directory; when it already exists the stored index
                    is reopened instead of re-embedding everything.
        """
        embeddings = HuggingFaceEmbeddings(model_name=model_path)
        if os.path.isdir(path):
            # Reuse the persisted index — embedding is the expensive part.
            vectorstore = Chroma(embedding_function=embeddings, persist_directory=path)
        else:
            # FIX: only split the texts when we actually have to (the original
            # chunked every document even when reopening an existing store,
            # then threw the chunks away).
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=50)
            # Split each document, then flatten the per-document chunk lists.
            splits = [chunk for text in data_list for chunk in text_splitter.split_text(text)]
            vectorstore = Chroma.from_texts(texts=splits, embedding=embeddings, persist_directory=path)
        print("count:", vectorstore._collection.count())
        return vectorstore

    def get_docs(self, query, k=10):
        """Return at most *k* (page_content, score) pairs most similar to *query*.

        BUG FIX: the original checked ``if i >= k: break`` AFTER appending,
        so it returned k+1 results; slicing returns exactly k.
        (Parameter renamed from ``input``, which shadowed the builtin.)
        """
        # Over-fetch, then keep the top k; max() also guards a caller asking
        # for more than the original hard-coded 50.
        retrieved_documents = self.vectorstore.similarity_search_with_score(query, k=max(k, 50))
        return [(doc.page_content, score) for doc, score in retrieved_documents[:k]]

    def forward(self, question: str):
        """Tool entry point: retrieve law passages relevant to *question*."""
        return self.get_docs(question)
54
+
tools/law_tool.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from datasets import load_dataset
3
+
4
class LawTool(Tool):
    """Look up the text of a specific law article by category and article number.

    Backed by the ``robin0307/law`` dataset loaded into a pandas DataFrame
    with (at least) ``category``, ``number`` and ``content`` columns.
    """

    name = "law_tool"
    description = """
    This is a tool that returns law content by input the category and number."""
    inputs = {
        "category": {
            "type": "string",
            "description": "the law category (such as 民法, 中華民國刑法, 民事訴訟法, 刑事訴訟法, 律師法 etc)",
        },
        "number": {
            "type": "integer",
            "description": "the law number (such as 23)"
        }
    }
    output_type = "string"
    # pandas DataFrame holding the corpus; populated in __init__.
    law = None

    def __init__(self):
        dataset = load_dataset("robin0307/law", split="train")
        self.law = dataset.to_pandas()
        super().__init__()

    def forward(self, category: str, number: int):
        """Return the article text, or an explanatory message when not found."""
        # Normalize the common short form to the official statute name.
        if category == "刑法":
            category = "中華民國刑法"

        matches = self.law.loc[
            (self.law["category"] == category) & (self.law["number"] == number), "content"
        ].values
        # BUG FIX: the original indexed .values[0] unconditionally and crashed
        # with IndexError on an unknown category/number; report the miss so
        # the calling agent can recover and rephrase.
        if len(matches) == 0:
            return f"No law found for category '{category}' number {number}."
        return matches[0]
tools/web_search.py DELETED
@@ -1,27 +0,0 @@
1
- from typing import Any, Optional
2
- from smolagents.tools import Tool
3
- import duckduckgo_search
4
-
5
class DuckDuckGoSearchTool(Tool):
    """smolagents tool wrapping the `duckduckgo_search` DDGS client.

    NOTE(review): this file is deleted by the commit; documented as-is.
    """

    name = "web_search"
    description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        # max_results caps the number of hits returned; any extra kwargs are
        # forwarded to the DDGS client constructor.
        super().__init__()
        self.max_results = max_results
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            # Re-raise with install instructions so the failure is actionable.
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            ) from e
        self.ddgs = DDGS(**kwargs)

    def forward(self, query: str) -> str:
        """Run the search and return the hits as one markdown string."""
        results = self.ddgs.text(query, max_results=self.max_results)
        # Raise on an empty result set so the agent knows to rephrase.
        if len(results) == 0:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        # Each hit becomes a markdown link followed by its snippet.
        postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)