chore: remove langchain in tools (#3247)

Yeuoly 2024-04-09 19:28:22 +08:00 committed by GitHub
parent 2a6b7d57cb
commit e635f3dc1d
7 changed files with 645 additions and 29 deletions


@@ -1,12 +1,32 @@
import os
from typing import Any, Optional, Union
from typing import Any, Optional, TextIO, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from pydantic import BaseModel
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush()  # ensure all printed content is written to the file
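# Illustrative usage of the inlined helpers above (not part of this commit's diff;
# the messages are made-up examples, and colors must be keys of _TEXT_COLOR_MAPPING):
# >>> print_text("Agent loop started\n", color="green")
# >>> print_text(get_colored_text("tool output", "blue") + "\n")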
class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''
current_loop = 1


@@ -1,11 +1,92 @@
from typing import Any
import logging
from typing import Any, Optional
from langchain.utilities import ArxivAPIWrapper
import arxiv
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
logger = logging.getLogger(__name__)
class ArxivAPIWrapper(BaseModel):
"""Wrapper around ArxivAPI.
To use, you should have the ``arxiv`` python package installed.
https://lukasschwab.me/arxiv.py/index.html
This wrapper will use the Arxiv API to conduct searches and
fetch document summaries. By default, it will return the document summaries
of the top-k results.
It limits the Document content by doc_content_chars_max.
Set doc_content_chars_max=None if you don't want to limit the content size.
Args:
top_k_results: number of top-scored documents used by the arxiv tool
ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
load_max_docs: a limit to the number of loaded documents
load_all_available_meta:
if True: the `metadata` of the loaded Documents contains all available
meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
if False: the `metadata` contains only the published date, title,
authors and summary.
doc_content_chars_max: an optional cut limit for the length of a document's
content
Example:
.. code-block:: python
arxiv = ArxivAPIWrapper(
top_k_results = 3,
ARXIV_MAX_QUERY_LENGTH = 300,
load_max_docs = 3,
load_all_available_meta = False,
doc_content_chars_max = 40000
)
arxiv.run("tree of thought llm")
"""
arxiv_search = arxiv.Search #: :meta private:
arxiv_exceptions = (
arxiv.ArxivError,
arxiv.UnexpectedEmptyPageError,
arxiv.HTTPError,
) # :meta private:
top_k_results: int = 3
ARXIV_MAX_QUERY_LENGTH = 300
load_max_docs: int = 100
load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000
def run(self, query: str) -> str:
"""
Performs an arxiv search and returns a single string
with the publish date, title, authors, and summary
for each article, separated by two newlines.
If an error occurs or no documents are found, error text
is returned instead. Wrapper for
https://lukasschwab.me/arxiv.py/index.html#Search
Args:
query: a plaintext search query
""" # noqa: E501
try:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
).results()
except self.arxiv_exceptions as ex:
return f"Arxiv exception: {ex}"
docs = [
f"Published: {result.updated.date()}\n"
f"Title: {result.title}\n"
f"Authors: {', '.join(a.name for a in result.authors)}\n"
f"Summary: {result.summary}"
for result in results
]
if docs:
return "\n\n".join(docs)[: self.doc_content_chars_max]
else:
return "No good Arxiv Result was found"
class ArxivSearchInput(BaseModel):
query: str = Field(..., description="Search query.")
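# Illustrative usage of the wrapper and input schema above (not part of this commit's diff;
# the query string is an arbitrary example):
# >>> search_input = ArxivSearchInput(query="tree of thought llm")
# >>> arxiv_wrapper = ArxivAPIWrapper(top_k_results=3, doc_content_chars_max=4000)
# >>> print(arxiv_wrapper.run(search_input.query))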


@@ -1,11 +1,95 @@
from typing import Any
import json
from typing import Any, Optional
from langchain.tools import BraveSearch
import requests
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class BraveSearchWrapper(BaseModel):
"""Wrapper around the Brave search engine."""
api_key: str
"""The API key to use for the Brave search engine."""
search_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the search request."""
base_url = "https://api.search.brave.com/res/v1/web/search"
"""The base URL for the Brave search engine."""
def run(self, query: str) -> str:
"""Query the Brave search engine and return the results as a JSON string.
Args:
query: The query to search for.
Returns: The results as a JSON string.
"""
web_search_results = self._search_request(query=query)
final_results = [
{
"title": item.get("title"),
"link": item.get("url"),
"snippet": item.get("description"),
}
for item in web_search_results
]
return json.dumps(final_results)
def _search_request(self, query: str) -> list[dict]:
headers = {
"X-Subscription-Token": self.api_key,
"Accept": "application/json",
}
req = requests.PreparedRequest()
params = {**self.search_kwargs, **{"q": query}}
req.prepare_url(self.base_url, params)
if req.url is None:
raise ValueError("prepared url is None, this should not happen")
response = requests.get(req.url, headers=headers)
if not response.ok:
raise Exception(f"HTTP error {response.status_code}")
return response.json().get("web", {}).get("results", [])
class BraveSearch(BaseModel):
"""Tool that queries the BraveSearch."""
name = "brave_search"
description = (
"a search engine. "
"useful for when you need to answer questions about current events."
" input should be a search query."
)
search_wrapper: BraveSearchWrapper
@classmethod
def from_api_key(
cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
) -> "BraveSearch":
"""Create a tool from an api key.
Args:
api_key: The api key to use.
search_kwargs: Any additional kwargs to pass to the search wrapper.
**kwargs: Any additional kwargs to pass to the tool.
Returns:
A tool.
"""
wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
return cls(search_wrapper=wrapper, **kwargs)
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
return self.search_wrapper.run(query)
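# Illustrative usage of the BraveSearch helper above (not part of this commit's diff;
# the API key is a placeholder, and "count" matches the parameter BraveSearchTool passes below):
# >>> brave = BraveSearch.from_api_key(api_key="YOUR_BRAVE_API_KEY", search_kwargs={"count": 3})
# >>> print(brave._run("large language model agents"))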
class BraveSearchTool(BuiltinTool):
"""
Tool for performing a search using Brave search engine.
@@ -31,7 +115,7 @@ class BraveSearchTool(BuiltinTool):
tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
results = tool.run(query)
results = tool._run(query)
if not results:
return self.create_text_message(f"No results found for '{query}' in Tavily")


@@ -1,16 +1,147 @@
from typing import Any
from typing import Any, Optional
from langchain.tools import DuckDuckGoSearchRun
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class DuckDuckGoSearchAPIWrapper(BaseModel):
"""Wrapper for DuckDuckGo Search API.
Free and does not require any setup.
"""
region: Optional[str] = "wt-wt"
safesearch: str = "moderate"
time: Optional[str] = "y"
max_results: int = 5
def get_snippets(self, query: str) -> list[str]:
"""Run query through DuckDuckGo and return concatenated results."""
from duckduckgo_search import DDGS
with DDGS() as ddgs:
results = ddgs.text(
query,
region=self.region,
safesearch=self.safesearch,
timelimit=self.time,
)
if results is None:
return ["No good DuckDuckGo Search Result was found"]
snippets = []
for i, res in enumerate(results, 1):
if res is not None:
snippets.append(res["body"])
if len(snippets) == self.max_results:
break
return snippets
def run(self, query: str) -> str:
snippets = self.get_snippets(query)
return " ".join(snippets)
def results(
self, query: str, num_results: int, backend: str = "api"
) -> list[dict[str, str]]:
"""Run query through DuckDuckGo and return metadata.
Args:
query: The query to search for.
num_results: The number of results to return.
Returns:
A list of dictionaries with the following keys:
snippet - The description of the result.
title - The title of the result.
link - The link to the result.
"""
from duckduckgo_search import DDGS
with DDGS() as ddgs:
results = ddgs.text(
query,
region=self.region,
safesearch=self.safesearch,
timelimit=self.time,
backend=backend,
)
if results is None:
return [{"Result": "No good DuckDuckGo Search Result was found"}]
def to_metadata(result: dict) -> dict[str, str]:
if backend == "news":
return {
"date": result["date"],
"title": result["title"],
"snippet": result["body"],
"source": result["source"],
"link": result["url"],
}
return {
"snippet": result["body"],
"title": result["title"],
"link": result["href"],
}
formatted_results = []
for i, res in enumerate(results, 1):
if res is not None:
formatted_results.append(to_metadata(res))
if len(formatted_results) == num_results:
break
return formatted_results
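# Illustrative usage of the wrapper above (not part of this commit's diff; requires the
# duckduckgo_search package and network access; the query is an arbitrary example):
# >>> ddg = DuckDuckGoSearchAPIWrapper(max_results=3)
# >>> print(ddg.run("retrieval augmented generation"))
# >>> print(ddg.results("retrieval augmented generation", num_results=2))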
class DuckDuckGoSearchRun(BaseModel):
"""Tool that queries the DuckDuckGo search API."""
name = "duckduckgo_search"
description = (
"A wrapper around DuckDuckGo Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
)
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
default_factory=DuckDuckGoSearchAPIWrapper
)
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
return self.api_wrapper.run(query)
class DuckDuckGoSearchResults(BaseModel):
"""Tool that queries the DuckDuckGo search API and gets back json."""
name = "DuckDuckGo Results JSON"
description = (
"A wrapper around Duck Duck Go Search. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query. Output is a JSON array of the query results"
)
num_results: int = 4
api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
default_factory=DuckDuckGoSearchAPIWrapper
)
backend: str = "api"
def _run(
self,
query: str,
) -> str:
"""Use the tool."""
res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
return ", ".join([f"[{rs}]" for rs in res_strs])
class DuckDuckGoInput(BaseModel):
query: str = Field(..., description="Search query.")
class DuckDuckGoSearchTool(BuiltinTool):
"""
Tool for performing a search using DuckDuckGo search engine.
@@ -34,7 +165,7 @@ class DuckDuckGoSearchTool(BuiltinTool):
tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)
result = tool.run(query)
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))


@@ -1,16 +1,187 @@
import json
import time
import urllib.error
import urllib.parse
import urllib.request
from typing import Any
from langchain.tools import PubmedQueryRun
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class PubMedAPIWrapper(BaseModel):
"""
Wrapper around PubMed API.
This wrapper will use the PubMed API to conduct searches and fetch
document summaries. By default, it will return the document summaries
of the top-k results of an input search.
Parameters:
top_k_results: number of top-scored documents used by the PubMed tool
load_max_docs: a limit to the number of loaded documents
load_all_available_meta:
if True: the `metadata` of the loaded Documents gets all available meta info
(see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch)
if False: the `metadata` gets only the most informative fields.
"""
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry = 5
sleep_time = 0.2
# Default values for the parameters
top_k_results: int = 3
load_max_docs: int = 25
ARXIV_MAX_QUERY_LENGTH = 300
doc_content_chars_max: int = 2000
load_all_available_meta: bool = False
email: str = "your_email@example.com"
def run(self, query: str) -> str:
"""
Run PubMed search and get the article meta information.
See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
It uses only the most informative fields of article meta information.
"""
try:
# Retrieve the top-k results for the query
docs = [
f"Published: {result['pub_date']}\nTitle: {result['title']}\n"
f"Summary: {result['summary']}"
for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH])
]
# Join the results and limit the character count
return (
"\n\n".join(docs)[:self.doc_content_chars_max]
if docs
else "No good PubMed Result was found"
)
except Exception as ex:
return f"PubMed exception: {ex}"
def load(self, query: str) -> list[dict]:
"""
Search PubMed for documents matching the query.
Return a list of dictionaries containing the document metadata.
"""
url = (
self.base_url_esearch
+ "db=pubmed&term="
+ urllib.parse.quote(query)
+ f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
)
result = urllib.request.urlopen(url)
text = result.read().decode("utf-8")
json_text = json.loads(text)
articles = []
webenv = json_text["esearchresult"]["webenv"]
for uid in json_text["esearchresult"]["idlist"]:
article = self.retrieve_article(uid, webenv)
articles.append(article)
# Convert the list of articles to a JSON string
return articles
def retrieve_article(self, uid: str, webenv: str) -> dict:
url = (
self.base_url_efetch
+ "db=pubmed&retmode=xml&id="
+ uid
+ "&webenv="
+ webenv
)
retry = 0
while True:
try:
result = urllib.request.urlopen(url)
break
except urllib.error.HTTPError as e:
if e.code == 429 and retry < self.max_retry:
# Too Many Requests error
# wait for an exponentially increasing amount of time
print(
f"Too Many Requests, "
f"waiting for {self.sleep_time:.2f} seconds..."
)
time.sleep(self.sleep_time)
self.sleep_time *= 2
retry += 1
else:
raise e
xml_text = result.read().decode("utf-8")
# Get title
title = ""
if "<ArticleTitle>" in xml_text and "</ArticleTitle>" in xml_text:
start_tag = "<ArticleTitle>"
end_tag = "</ArticleTitle>"
title = xml_text[
xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
]
# Get abstract
abstract = ""
if "<AbstractText>" in xml_text and "</AbstractText>" in xml_text:
start_tag = "<AbstractText>"
end_tag = "</AbstractText>"
abstract = xml_text[
xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
]
# Get publication date
pub_date = ""
if "<PubDate>" in xml_text and "</PubDate>" in xml_text:
start_tag = "<PubDate>"
end_tag = "</PubDate>"
pub_date = xml_text[
xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
]
# Return article as dictionary
article = {
"uid": uid,
"title": title,
"summary": abstract,
"pub_date": pub_date,
}
return article
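# Illustrative usage of the wrapper above (not part of this commit's diff; performs live
# requests against the NCBI E-utilities endpoints; the query is an arbitrary example):
# >>> pubmed = PubMedAPIWrapper(top_k_results=2)
# >>> print(pubmed.run("crispr base editing"))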
class PubmedQueryRun(BaseModel):
"""Tool that searches the PubMed API."""
name = "PubMed"
description = (
"A wrapper around PubMed.org "
"Useful for when you need to answer questions about Physics, Mathematics, "
"Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
"Electrical Engineering, and Economics "
"from scientific articles on PubMed.org. "
"Input should be a search query."
)
api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
def _run(
self,
query: str,
) -> str:
"""Use the Arxiv tool."""
return self.api_wrapper.run(query)
class PubMedInput(BaseModel):
query: str = Field(..., description="Search query.")
class PubMedSearchTool(BuiltinTool):
"""
Tool for performing a search using PubMed search engine.
@@ -34,7 +205,7 @@ class PubMedSearchTool(BuiltinTool):
tool = PubmedQueryRun(args_schema=PubMedInput)
result = tool.run(query)
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))


@@ -1,11 +1,81 @@
from typing import Any, Union
from typing import Any, Optional, Union
from langchain.utilities import TwilioAPIWrapper
from pydantic import BaseModel, root_validator
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class TwilioAPIWrapper(BaseModel):
"""Messaging Client using Twilio.
To use, you should have the ``twilio`` python package installed,
and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
named parameters to the constructor.
Example:
.. code-block:: python
from langchain.utilities.twilio import TwilioAPIWrapper
twilio = TwilioAPIWrapper(
account_sid="ACxxx",
auth_token="xxx",
from_number="+10123456789"
)
twilio.run('test', '+12484345508')
"""
client: Any #: :meta private:
account_sid: Optional[str] = None
"""Twilio account string identifier."""
auth_token: Optional[str] = None
"""Twilio auth token."""
from_number: Optional[str] = None
"""A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164)
format, an
[alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id),
or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses)
that is enabled for the type of message you want to send. Phone numbers or
[short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from
Twilio also work here. You cannot, for example, spoof messages from a private
cell phone number. If you are using `messaging_service_sid`, this parameter
must be empty.
""" # noqa: E501
@validator("client", pre=True, always=True)
def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
try:
from twilio.rest import Client
except ImportError:
raise ImportError(
"Could not import twilio python package. "
"Please install it with `pip install twilio`."
)
account_sid = values.get("account_sid")
auth_token = values.get("auth_token")
values["from_number"] = values.get("from_number")
values["client"] = Client(account_sid, auth_token)
return values
def run(self, body: str, to: str) -> str:
"""Run body through Twilio and respond with message sid.
Args:
body: The text of the message you want to send. Can be up to 1,600
characters in length.
to: The destination phone number in
[E.164](https://www.twilio.com/docs/glossary/what-e164) format for
SMS/MMS or
[Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
for other 3rd-party channels.
""" # noqa: E501
message = self.client.messages.create(to, from_=self.from_number, body=body)
return message.sid
class SendMessageTool(BuiltinTool):
"""
A tool for sending messages using Twilio API.


@@ -1,16 +1,79 @@
from typing import Any, Union
from typing import Any, Optional, Union
from langchain import WikipediaAPIWrapper
from langchain.tools import WikipediaQueryRun
from pydantic import BaseModel, Field
import wikipedia
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
WIKIPEDIA_MAX_QUERY_LENGTH = 300
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
class WikipediaAPIWrapper:
"""Wrapper around WikipediaAPI.
To use, you should have the ``wikipedia`` python package installed.
This wrapper will use the Wikipedia API to conduct searches and
fetch page summaries. By default, it will return the page summaries
of the top-k results.
It limits the Document content by doc_content_chars_max.
"""
top_k_results: int = 3
lang: str = "en"
load_all_available_meta: bool = False
doc_content_chars_max: int = 4000
def __init__(self, doc_content_chars_max: int = 4000):
self.doc_content_chars_max = doc_content_chars_max
def run(self, query: str) -> str:
"""Run Wikipedia search and get page summaries."""
wikipedia.set_lang(self.lang)
wiki_client = wikipedia
page_titles = wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
summaries = []
for page_title in page_titles[: self.top_k_results]:
if wiki_page := self._fetch_page(page_title):
if summary := self._formatted_page_summary(page_title, wiki_page):
summaries.append(summary)
if not summaries:
return "No good Wikipedia Search Result was found"
return "\n\n".join(summaries)[: self.doc_content_chars_max]
@staticmethod
def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
return f"Page: {page_title}\nSummary: {wiki_page.summary}"
def _fetch_page(self, page: str) -> Optional[str]:
try:
return wikipedia.page(title=page, auto_suggest=False)
except (
wikipedia.exceptions.PageError,
wikipedia.exceptions.DisambiguationError,
):
return None
class WikipediaQueryRun:
"""Tool that searches the Wikipedia API."""
name = "Wikipedia"
description = (
"A wrapper around Wikipedia. "
"Useful for when you need to answer general questions about "
"people, places, companies, facts, historical events, or other subjects. "
"Input should be a search query."
)
api_wrapper: WikipediaAPIWrapper
def __init__(self, api_wrapper: WikipediaAPIWrapper):
self.api_wrapper = api_wrapper
def _run(
self,
query: str,
) -> str:
"""Use the Wikipedia tool."""
return self.api_wrapper.run(query)
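# Illustrative usage of the wrapper and query runner above (not part of this commit's diff;
# requires the wikipedia package; the query is an arbitrary example):
# >>> wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=1000))
# >>> print(wiki_tool._run("Alan Turing"))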
class WikiPediaSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
@@ -24,14 +87,10 @@ class WikiPediaSearchTool(BuiltinTool):
return self.create_text_message('Please input query')
tool = WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput
)
result = tool.run(tool_input={
'query': query
})
result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id,content=result))