diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py
index 2e524466a..af9db3989 100644
--- a/api/core/callback_handler/agent_tool_callback_handler.py
+++ b/api/core/callback_handler/agent_tool_callback_handler.py
@@ -1,12 +1,32 @@
import os
-from typing import Any, Optional, Union
+from typing import Any, Optional, TextIO, Union
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.input import print_text
from pydantic import BaseModel
+_TEXT_COLOR_MAPPING = {
+ "blue": "36;1",
+ "yellow": "33;1",
+ "pink": "38;5;200",
+ "green": "32;1",
+ "red": "31;1",
+}
-class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
+def get_colored_text(text: str, color: str) -> str:
+ """Get colored text."""
+ color_str = _TEXT_COLOR_MAPPING[color]
+ return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
+
+
+def print_text(
+ text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
+) -> None:
+ """Print text with highlighting and no end characters."""
+ text_to_print = get_colored_text(text, color) if color else text
+ print(text_to_print, end=end, file=file)
+ if file:
+ file.flush() # ensure all printed content are written to file
+
+class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''
current_loop = 1
diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
index 033d942f4..fb64e07a8 100644
--- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
+++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py
@@ -1,11 +1,92 @@
-from typing import Any
+import logging
+from typing import Any, Optional
-from langchain.utilities import ArxivAPIWrapper
+import arxiv
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
+logger = logging.getLogger(__name__)
+class ArxivAPIWrapper(BaseModel):
+ """Wrapper around ArxivAPI.
+
+ To use, you should have the ``arxiv`` python package installed.
+ https://lukasschwab.me/arxiv.py/index.html
+ This wrapper will use the Arxiv API to conduct searches and
+ fetch document summaries. By default, it will return the document summaries
+ of the top-k results.
+ It limits the Document content by doc_content_chars_max.
+ Set doc_content_chars_max=None if you don't want to limit the content size.
+
+ Args:
+ top_k_results: number of the top-scored document used for the arxiv tool
+ ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
+ load_max_docs: a limit to the number of loaded documents
+ load_all_available_meta:
+ if True: the `metadata` of the loaded Documents contains all available
+ meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
+ if False: the `metadata` contains only the published date, title,
+ authors and summary.
+ doc_content_chars_max: an optional cut limit for the length of a document's
+ content
+
+ Example:
+ .. code-block:: python
+
+ arxiv = ArxivAPIWrapper(
+ top_k_results = 3,
+ ARXIV_MAX_QUERY_LENGTH = 300,
+ load_max_docs = 3,
+ load_all_available_meta = False,
+ doc_content_chars_max = 40000
+ )
+ arxiv.run("tree of thought llm)
+ """
+
+ arxiv_search = arxiv.Search #: :meta private:
+ arxiv_exceptions = (
+ arxiv.ArxivError,
+ arxiv.UnexpectedEmptyPageError,
+ arxiv.HTTPError,
+ ) # :meta private:
+ top_k_results: int = 3
+ ARXIV_MAX_QUERY_LENGTH = 300
+ load_max_docs: int = 100
+ load_all_available_meta: bool = False
+ doc_content_chars_max: Optional[int] = 4000
+
+ def run(self, query: str) -> str:
+ """
+ Performs an arxiv search and A single string
+ with the publish date, title, authors, and summary
+ for each article separated by two newlines.
+
+ If an error occurs or no documents found, error text
+ is returned instead. Wrapper for
+ https://lukasschwab.me/arxiv.py/index.html#Search
+
+ Args:
+ query: a plaintext search query
+ """ # noqa: E501
+ try:
+ results = self.arxiv_search( # type: ignore
+ query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
+ ).results()
+ except self.arxiv_exceptions as ex:
+ return f"Arxiv exception: {ex}"
+ docs = [
+ f"Published: {result.updated.date()}\n"
+ f"Title: {result.title}\n"
+ f"Authors: {', '.join(a.name for a in result.authors)}\n"
+ f"Summary: {result.summary}"
+ for result in results
+ ]
+ if docs:
+ return "\n\n".join(docs)[: self.doc_content_chars_max]
+ else:
+ return "No good Arxiv Result was found"
+
class ArxivSearchInput(BaseModel):
query: str = Field(..., description="Search query.")
diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py
index cb91d9499..f121cb0e3 100644
--- a/api/core/tools/provider/builtin/brave/tools/brave_search.py
+++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py
@@ -1,11 +1,95 @@
-from typing import Any
+import json
+from typing import Any, Optional
-from langchain.tools import BraveSearch
+import requests
+from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
+class BraveSearchWrapper(BaseModel):
+ """Wrapper around the Brave search engine."""
+
+ api_key: str
+ """The API key to use for the Brave search engine."""
+ search_kwargs: dict = Field(default_factory=dict)
+ """Additional keyword arguments to pass to the search request."""
+ base_url = "https://api.search.brave.com/res/v1/web/search"
+ """The base URL for the Brave search engine."""
+
+ def run(self, query: str) -> str:
+ """Query the Brave search engine and return the results as a JSON string.
+
+ Args:
+ query: The query to search for.
+
+ Returns: The results as a JSON string.
+
+ """
+ web_search_results = self._search_request(query=query)
+ final_results = [
+ {
+ "title": item.get("title"),
+ "link": item.get("url"),
+ "snippet": item.get("description"),
+ }
+ for item in web_search_results
+ ]
+ return json.dumps(final_results)
+
+ def _search_request(self, query: str) -> list[dict]:
+ headers = {
+ "X-Subscription-Token": self.api_key,
+ "Accept": "application/json",
+ }
+ req = requests.PreparedRequest()
+ params = {**self.search_kwargs, **{"q": query}}
+ req.prepare_url(self.base_url, params)
+ if req.url is None:
+ raise ValueError("prepared url is None, this should not happen")
+
+ response = requests.get(req.url, headers=headers)
+ if not response.ok:
+ raise Exception(f"HTTP error {response.status_code}")
+
+ return response.json().get("web", {}).get("results", [])
+
+class BraveSearch(BaseModel):
+ """Tool that queries the BraveSearch."""
+
+ name = "brave_search"
+ description = (
+ "a search engine. "
+ "useful for when you need to answer questions about current events."
+ " input should be a search query."
+ )
+ search_wrapper: BraveSearchWrapper
+
+ @classmethod
+ def from_api_key(
+ cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
+ ) -> "BraveSearch":
+ """Create a tool from an api key.
+
+ Args:
+ api_key: The api key to use.
+ search_kwargs: Any additional kwargs to pass to the search wrapper.
+ **kwargs: Any additional kwargs to pass to the tool.
+
+ Returns:
+ A tool.
+ """
+ wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {})
+ return cls(search_wrapper=wrapper, **kwargs)
+
+ def _run(
+ self,
+ query: str,
+ ) -> str:
+ """Use the tool."""
+ return self.search_wrapper.run(query)
+
class BraveSearchTool(BuiltinTool):
"""
Tool for performing a search using Brave search engine.
@@ -31,7 +115,7 @@ class BraveSearchTool(BuiltinTool):
tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
- results = tool.run(query)
+ results = tool._run(query)
if not results:
return self.create_text_message(f"No results found for '{query}' in Tavily")
diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py
index 6046a1893..80722a4d6 100644
--- a/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py
+++ b/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py
@@ -1,16 +1,147 @@
-from typing import Any
+from typing import Any, Optional
-from langchain.tools import DuckDuckGoSearchRun
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
+class DuckDuckGoSearchAPIWrapper(BaseModel):
+ """Wrapper for DuckDuckGo Search API.
+
+ Free and does not require any setup.
+ """
+
+ region: Optional[str] = "wt-wt"
+ safesearch: str = "moderate"
+ time: Optional[str] = "y"
+ max_results: int = 5
+
+ def get_snippets(self, query: str) -> list[str]:
+ """Run query through DuckDuckGo and return concatenated results."""
+ from duckduckgo_search import DDGS
+
+ with DDGS() as ddgs:
+ results = ddgs.text(
+ query,
+ region=self.region,
+ safesearch=self.safesearch,
+ timelimit=self.time,
+ )
+ if results is None:
+ return ["No good DuckDuckGo Search Result was found"]
+ snippets = []
+ for i, res in enumerate(results, 1):
+ if res is not None:
+ snippets.append(res["body"])
+ if len(snippets) == self.max_results:
+ break
+ return snippets
+
+ def run(self, query: str) -> str:
+ snippets = self.get_snippets(query)
+ return " ".join(snippets)
+
+ def results(
+ self, query: str, num_results: int, backend: str = "api"
+ ) -> list[dict[str, str]]:
+ """Run query through DuckDuckGo and return metadata.
+
+ Args:
+ query: The query to search for.
+ num_results: The number of results to return.
+
+ Returns:
+ A list of dictionaries with the following keys:
+ snippet - The description of the result.
+ title - The title of the result.
+ link - The link to the result.
+ """
+ from duckduckgo_search import DDGS
+
+ with DDGS() as ddgs:
+ results = ddgs.text(
+ query,
+ region=self.region,
+ safesearch=self.safesearch,
+ timelimit=self.time,
+ backend=backend,
+ )
+ if results is None:
+ return [{"Result": "No good DuckDuckGo Search Result was found"}]
+
+ def to_metadata(result: dict) -> dict[str, str]:
+ if backend == "news":
+ return {
+ "date": result["date"],
+ "title": result["title"],
+ "snippet": result["body"],
+ "source": result["source"],
+ "link": result["url"],
+ }
+ return {
+ "snippet": result["body"],
+ "title": result["title"],
+ "link": result["href"],
+ }
+
+ formatted_results = []
+ for i, res in enumerate(results, 1):
+ if res is not None:
+ formatted_results.append(to_metadata(res))
+ if len(formatted_results) == num_results:
+ break
+ return formatted_results
+
+
+class DuckDuckGoSearchRun(BaseModel):
+ """Tool that queries the DuckDuckGo search API."""
+
+ name = "duckduckgo_search"
+ description = (
+ "A wrapper around DuckDuckGo Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query."
+ )
+ api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+ default_factory=DuckDuckGoSearchAPIWrapper
+ )
+
+ def _run(
+ self,
+ query: str,
+ ) -> str:
+ """Use the tool."""
+ return self.api_wrapper.run(query)
+
+
+class DuckDuckGoSearchResults(BaseModel):
+ """Tool that queries the DuckDuckGo search API and gets back json."""
+
+ name = "DuckDuckGo Results JSON"
+ description = (
+ "A wrapper around Duck Duck Go Search. "
+ "Useful for when you need to answer questions about current events. "
+ "Input should be a search query. Output is a JSON array of the query results"
+ )
+ num_results: int = 4
+ api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
+ default_factory=DuckDuckGoSearchAPIWrapper
+ )
+ backend: str = "api"
+
+ def _run(
+ self,
+ query: str,
+ ) -> str:
+ """Use the tool."""
+ res = self.api_wrapper.results(query, self.num_results, backend=self.backend)
+ res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res]
+ return ", ".join([f"[{rs}]" for rs in res_strs])
+
class DuckDuckGoInput(BaseModel):
query: str = Field(..., description="Search query.")
-
class DuckDuckGoSearchTool(BuiltinTool):
"""
Tool for performing a search using DuckDuckGo search engine.
@@ -34,7 +165,7 @@ class DuckDuckGoSearchTool(BuiltinTool):
tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput)
- result = tool.run(query)
+ result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py
index 1bed1fa77..ee465d9bc 100644
--- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py
+++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py
@@ -1,16 +1,187 @@
+import json
+import time
+import urllib.error
+import urllib.parse
+import urllib.request
from typing import Any
-from langchain.tools import PubmedQueryRun
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
+class PubMedAPIWrapper(BaseModel):
+ """
+ Wrapper around PubMed API.
+
+ This wrapper will use the PubMed API to conduct searches and fetch
+ document summaries. By default, it will return the document summaries
+ of the top-k results of an input search.
+
+ Parameters:
+ top_k_results: number of the top-scored document used for the PubMed tool
+ load_max_docs: a limit to the number of loaded documents
+ load_all_available_meta:
+ if True: the `metadata` of the loaded Documents gets all available meta info
+ (see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch)
+ if False: the `metadata` gets only the most informative fields.
+ """
+
+ base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
+ base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
+ max_retry = 5
+ sleep_time = 0.2
+
+ # Default values for the parameters
+ top_k_results: int = 3
+ load_max_docs: int = 25
+ ARXIV_MAX_QUERY_LENGTH = 300
+ doc_content_chars_max: int = 2000
+ load_all_available_meta: bool = False
+ email: str = "your_email@example.com"
+
+ def run(self, query: str) -> str:
+ """
+ Run PubMed search and get the article meta information.
+ See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
+ It uses only the most informative fields of article meta information.
+ """
+
+ try:
+ # Retrieve the top-k results for the query
+ docs = [
+ f"Published: {result['pub_date']}\nTitle: {result['title']}\n"
+ f"Summary: {result['summary']}"
+ for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH])
+ ]
+
+ # Join the results and limit the character count
+ return (
+ "\n\n".join(docs)[:self.doc_content_chars_max]
+ if docs
+ else "No good PubMed Result was found"
+ )
+ except Exception as ex:
+ return f"PubMed exception: {ex}"
+
+ def load(self, query: str) -> list[dict]:
+ """
+ Search PubMed for documents matching the query.
+ Return a list of dictionaries containing the document metadata.
+ """
+
+ url = (
+ self.base_url_esearch
+ + "db=pubmed&term="
+ + str({urllib.parse.quote(query)})
+ + f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
+ )
+ result = urllib.request.urlopen(url)
+ text = result.read().decode("utf-8")
+ json_text = json.loads(text)
+
+ articles = []
+ webenv = json_text["esearchresult"]["webenv"]
+ for uid in json_text["esearchresult"]["idlist"]:
+ article = self.retrieve_article(uid, webenv)
+ articles.append(article)
+
+ # Convert the list of articles to a JSON string
+ return articles
+
+ def retrieve_article(self, uid: str, webenv: str) -> dict:
+ url = (
+ self.base_url_efetch
+ + "db=pubmed&retmode=xml&id="
+ + uid
+ + "&webenv="
+ + webenv
+ )
+
+ retry = 0
+ while True:
+ try:
+ result = urllib.request.urlopen(url)
+ break
+ except urllib.error.HTTPError as e:
+ if e.code == 429 and retry < self.max_retry:
+ # Too Many Requests error
+ # wait for an exponentially increasing amount of time
+ print(
+ f"Too Many Requests, "
+ f"waiting for {self.sleep_time:.2f} seconds..."
+ )
+ time.sleep(self.sleep_time)
+ self.sleep_time *= 2
+ retry += 1
+ else:
+ raise e
+
+ xml_text = result.read().decode("utf-8")
+
+ # Get title
+ title = ""
+ if "" in xml_text and "" in xml_text:
+ start_tag = ""
+ end_tag = ""
+ title = xml_text[
+ xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+ ]
+
+ # Get abstract
+ abstract = ""
+ if "" in xml_text and "" in xml_text:
+ start_tag = ""
+ end_tag = ""
+ abstract = xml_text[
+ xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+ ]
+
+ # Get publication date
+ pub_date = ""
+ if "" in xml_text and "" in xml_text:
+ start_tag = ""
+ end_tag = ""
+ pub_date = xml_text[
+ xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
+ ]
+
+ # Return article as dictionary
+ article = {
+ "uid": uid,
+ "title": title,
+ "summary": abstract,
+ "pub_date": pub_date,
+ }
+ return article
+
+
+class PubmedQueryRun(BaseModel):
+ """Tool that searches the PubMed API."""
+
+ name = "PubMed"
+ description = (
+ "A wrapper around PubMed.org "
+ "Useful for when you need to answer questions about Physics, Mathematics, "
+ "Computer Science, Quantitative Biology, Quantitative Finance, Statistics, "
+ "Electrical Engineering, and Economics "
+ "from scientific articles on PubMed.org. "
+ "Input should be a search query."
+ )
+ api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)
+
+ def _run(
+ self,
+ query: str,
+ ) -> str:
+ """Use the Arxiv tool."""
+ return self.api_wrapper.run(query)
+
+
class PubMedInput(BaseModel):
query: str = Field(..., description="Search query.")
-
class PubMedSearchTool(BuiltinTool):
"""
Tool for performing a search using PubMed search engine.
@@ -34,7 +205,7 @@ class PubMedSearchTool(BuiltinTool):
tool = PubmedQueryRun(args_schema=PubMedInput)
- result = tool.run(query)
+ result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id, content=result))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py
index 984ac3e90..33c5d62e4 100644
--- a/api/core/tools/provider/builtin/twilio/tools/send_message.py
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py
@@ -1,11 +1,81 @@
-from typing import Any, Union
+from typing import Any, Optional, Union
-from langchain.utilities import TwilioAPIWrapper
+from pydantic import BaseModel, validator
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
+class TwilioAPIWrapper(BaseModel):
+ """Messaging Client using Twilio.
+
+ To use, you should have the ``twilio`` python package installed,
+ and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
+ ``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
+ named parameters to the constructor.
+
+ Example:
+ .. code-block:: python
+
+ from langchain.utilities.twilio import TwilioAPIWrapper
+ twilio = TwilioAPIWrapper(
+ account_sid="ACxxx",
+ auth_token="xxx",
+ from_number="+10123456789"
+ )
+ twilio.run('test', '+12484345508')
+ """
+
+ client: Any #: :meta private:
+ account_sid: Optional[str] = None
+ """Twilio account string identifier."""
+ auth_token: Optional[str] = None
+ """Twilio auth token."""
+ from_number: Optional[str] = None
+ """A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164)
+ format, an
+ [alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id),
+ or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+ that is enabled for the type of message you want to send. Phone numbers or
+ [short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from
+ Twilio also work here. You cannot, for example, spoof messages from a private
+ cell phone number. If you are using `messaging_service_sid`, this parameter
+ must be empty.
+ """ # noqa: E501
+
+ @validator("client", pre=True, always=True)
+ def set_validator(cls, values: dict) -> dict:
+ """Validate that api key and python package exists in environment."""
+ try:
+ from twilio.rest import Client
+ except ImportError:
+ raise ImportError(
+ "Could not import twilio python package. "
+ "Please install it with `pip install twilio`."
+ )
+ account_sid = values.get("account_sid")
+ auth_token = values.get("auth_token")
+ values["from_number"] = values.get("from_number")
+ values["client"] = Client(account_sid, auth_token)
+
+ return values
+
+ def run(self, body: str, to: str) -> str:
+ """Run body through Twilio and respond with message sid.
+
+ Args:
+ body: The text of the message you want to send. Can be up to 1,600
+ characters in length.
+ to: The destination phone number in
+ [E.164](https://www.twilio.com/docs/glossary/what-e164) format for
+ SMS/MMS or
+ [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
+ for other 3rd-party channels.
+ """ # noqa: E501
+ message = self.client.messages.create(to, from_=self.from_number, body=body)
+ return message.sid
+
+
class SendMessageTool(BuiltinTool):
"""
A tool for sending messages using Twilio API.
diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
index 38b495ad6..ef2990bfe 100644
--- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
+++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py
@@ -1,16 +1,79 @@
-from typing import Any, Union
+from typing import Any, Optional, Union
-from langchain import WikipediaAPIWrapper
-from langchain.tools import WikipediaQueryRun
-from pydantic import BaseModel, Field
+import wikipedia
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
+WIKIPEDIA_MAX_QUERY_LENGTH = 300
-class WikipediaInput(BaseModel):
- query: str = Field(..., description="search query.")
+class WikipediaAPIWrapper:
+ """Wrapper around WikipediaAPI.
+ To use, you should have the ``wikipedia`` python package installed.
+ This wrapper will use the Wikipedia API to conduct searches and
+ fetch page summaries. By default, it will return the page summaries
+ of the top-k results.
+ It limits the Document content by doc_content_chars_max.
+ """
+
+ top_k_results: int = 3
+ lang: str = "en"
+ load_all_available_meta: bool = False
+ doc_content_chars_max: int = 4000
+
+ def __init__(self, doc_content_chars_max: int = 4000):
+ self.doc_content_chars_max = doc_content_chars_max
+
+ def run(self, query: str) -> str:
+ wikipedia.set_lang(self.lang)
+ wiki_client = wikipedia
+
+ """Run Wikipedia search and get page summaries."""
+ page_titles = wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
+ summaries = []
+ for page_title in page_titles[: self.top_k_results]:
+ if wiki_page := self._fetch_page(page_title):
+ if summary := self._formatted_page_summary(page_title, wiki_page):
+ summaries.append(summary)
+ if not summaries:
+ return "No good Wikipedia Search Result was found"
+ return "\n\n".join(summaries)[: self.doc_content_chars_max]
+
+ @staticmethod
+ def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
+ return f"Page: {page_title}\nSummary: {wiki_page.summary}"
+
+ def _fetch_page(self, page: str) -> Optional[str]:
+ try:
+ return wikipedia.page(title=page, auto_suggest=False)
+ except (
+ wikipedia.exceptions.PageError,
+ wikipedia.exceptions.DisambiguationError,
+ ):
+ return None
+
+class WikipediaQueryRun:
+ """Tool that searches the Wikipedia API."""
+
+ name = "Wikipedia"
+ description = (
+ "A wrapper around Wikipedia. "
+ "Useful for when you need to answer general questions about "
+ "people, places, companies, facts, historical events, or other subjects. "
+ "Input should be a search query."
+ )
+ api_wrapper: WikipediaAPIWrapper
+
+ def __init__(self, api_wrapper: WikipediaAPIWrapper):
+ self.api_wrapper = api_wrapper
+
+ def _run(
+ self,
+ query: str,
+ ) -> str:
+ """Use the Wikipedia tool."""
+ return self.api_wrapper.run(query)
class WikiPediaSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
@@ -24,14 +87,10 @@ class WikiPediaSearchTool(BuiltinTool):
return self.create_text_message('Please input query')
tool = WikipediaQueryRun(
- name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
- args_schema=WikipediaInput
)
- result = tool.run(tool_input={
- 'query': query
- })
+ result = tool._run(query)
return self.create_text_message(self.summary(user_id=user_id,content=result))
\ No newline at end of file