Robin Chiu committed
Commit · 28536b2
Parent(s): 20acf5b

add the law tools
Files changed:
- agent.json +2 -2
- app.py +5 -3
- requirements.txt +6 -1
- tools/law_rag_query.py +54 -0
- tools/law_tool.py +31 -0
- tools/web_search.py +0 -27
agent.json CHANGED
@@ -1,6 +1,7 @@
 {
   "tools": [
-    "web_search",
+    "law_tool",
+    "law_rag_query",
     "final_answer"
   ],
   "model": {
@@ -39,7 +40,6 @@
   "name": null,
   "description": null,
   "requirements": [
-    "duckduckgo_search",
     "smolagents"
   ],
   "authorized_imports": [
app.py CHANGED
@@ -5,7 +5,8 @@ from smolagents import GradioUI, CodeAgent, HfApiModel
 # Get current directory path
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
-from tools.web_search import DuckDuckGoSearchTool
+from tools.law_tool import LawTool
+from tools.law_rag_query import LawRAGQuery
 from tools.final_answer import FinalAnswerTool as FinalAnswer
 
 
@@ -15,7 +16,8 @@ model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
 provider=None,
 )
 
-web_search = DuckDuckGoSearchTool()
+law_tool = LawTool()
+law_rag_query = LawRAGQuery()
 final_answer = FinalAnswer()
 
 
@@ -24,7 +26,7 @@ with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
 
 agent = CodeAgent(
     model=model,
-    tools=[web_search, final_answer],
+    tools=[law_tool, law_rag_query],
     managed_agents=[],
     max_steps=20,
     verbosity_level=1,
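With the law tools wired into the CodeAgent, the app can be smoke-tested without the UI. A minimal sketch, assuming the rest of app.py stays as before; the query string is only an illustration, not from the commit:

# Hypothetical check, not part of the commit: ask the agent directly.
answer = agent.run("民法第184條規定了什麼?")  # illustrative question
print(answer)

# Or serve the same agent behind Gradio, as the existing imports suggest:
# GradioUI(agent).launch()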
requirements.txt CHANGED
@@ -1,3 +1,8 @@
 duckduckgo_search
 smolagents
-gradio
+gradio
+datasets
+langchain
+langchain-chroma
+langchain-text-splitters
+datasets
tools/law_rag_query.py ADDED
@@ -0,0 +1,54 @@
+from smolagents import Tool
+from langchain_chroma import Chroma
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from datasets import load_dataset
+import os
+
+class LawRAGQuery(Tool):
+    name = "law_rag_query"
+    description = """
+    This is a tool that takes a question, retrieves the related laws, and returns their content."""
+    inputs = {
+        "question": {
+            "type": "string",
+            "description": "the question",
+        }
+    }
+    output_type = "array"
+    vectorstore = None
+
+    def __init__(self):
+        # Load the law corpus and build (or reload) the vector store.
+        dataset = load_dataset("robin0307/law", split='train')
+        law = dataset.to_pandas()
+        self.vectorstore = self.get_vectorstore("thenlper/gte-large-zh", list(law['content']))
+        super().__init__()
+
+    def get_vectorstore(self, model_path, data_list, path="chroma_db"):
+        embeddings = HuggingFaceEmbeddings(model_name=model_path)
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=50)
+        chunks = [text_splitter.split_text(text) for text in data_list]
+
+        if os.path.isdir(path):
+            # Reuse the persisted Chroma index.
+            vectorstore = Chroma(embedding_function=embeddings, persist_directory=path)
+        else:
+            # Flatten the per-document chunk lists and build a fresh index.
+            splits = [chunk for sublist in chunks for chunk in sublist]
+            vectorstore = Chroma.from_texts(texts=splits, embedding=embeddings, persist_directory=path)
+        print("count:", vectorstore._collection.count())
+        return vectorstore
+
+    def get_docs(self, query, k=10):
+        # Over-fetch, then keep only the top k (content, score) pairs.
+        retrieved_documents = self.vectorstore.similarity_search_with_score(query, k=50)
+
+        results = []
+        for doc, score in retrieved_documents[:k]:
+            results.append((doc.page_content, score))
+        return results
+
+    def forward(self, question: str):
+        docs = self.get_docs(question)
+        return docs
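A quick way to sanity-check the retriever outside the agent loop; the snippet below is hypothetical (not part of the commit), and the query string is only an example. Note that the first run downloads the dataset and builds the chroma_db index, so it is slow:

from tools.law_rag_query import LawRAGQuery

rag = LawRAGQuery()  # builds the Chroma index on first run, reloads chroma_db/ afterwards
for content, score in rag.forward("債務不履行的損害賠償"):
    print(f"{score:.4f}  {content[:60]}")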
tools/law_tool.py ADDED
@@ -0,0 +1,31 @@
+from smolagents import Tool
+from datasets import load_dataset
+
+class LawTool(Tool):
+    name = "law_tool"
+    description = """
+    This is a tool that returns the content of a law article given its category and number."""
+    inputs = {
+        "category": {
+            "type": "string",
+            "description": "the law category (such as 民法, 中華民國刑法, 民事訴訟法, 刑事訴訟法, 律師法, etc.)",
+        },
+        "number": {
+            "type": "integer",
+            "description": "the law article number (such as 23)"
+        }
+    }
+    output_type = "string"
+    law = None
+
+    def __init__(self):
+        dataset = load_dataset("robin0307/law", split='train')
+        self.law = dataset.to_pandas()
+        super().__init__()
+
+    def forward(self, category: str, number: int):
+        # Normalize the common short form of the criminal code.
+        if category == "刑法":
+            category = "中華民國刑法"
+
+        data = self.law.loc[(self.law["category"]==category) & (self.law["number"]==number), "content"].values[0]
+        return data
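For the direct-lookup tool, a similarly hypothetical call (not part of the commit; the category/number pair is only illustrative):

from tools.law_tool import LawTool

law_tool = LawTool()
print(law_tool.forward("民法", 23))  # content of that article; raises IndexError if no row matches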
tools/web_search.py DELETED
@@ -1,27 +0,0 @@
-from typing import Any, Optional
-from smolagents.tools import Tool
-import duckduckgo_search
-
-class DuckDuckGoSearchTool(Tool):
-    name = "web_search"
-    description = "Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."
-    inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}}
-    output_type = "string"
-
-    def __init__(self, max_results=10, **kwargs):
-        super().__init__()
-        self.max_results = max_results
-        try:
-            from duckduckgo_search import DDGS
-        except ImportError as e:
-            raise ImportError(
-                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
-            ) from e
-        self.ddgs = DDGS(**kwargs)
-
-    def forward(self, query: str) -> str:
-        results = self.ddgs.text(query, max_results=self.max_results)
-        if len(results) == 0:
-            raise Exception("No results found! Try a less restrictive/shorter query.")
-        postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
-        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)