import argparse
import sys

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field

from utils.logger import logger
from networks.google_searcher import GoogleSearcher
from networks.webpage_fetcher import BatchWebpageFetcher
from documents.query_results_extractor import QueryResultsExtractor
from documents.webpage_content_extractor import BatchWebpageContentExtractor

class SearchAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="Web Search API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()
    class QueriesToSearchResultsPostItem(BaseModel):
        queries: list = Field(
            default=[""],
            description="(list[str]) Queries to search",
        )
        result_num: int = Field(
            default=10,
            description="(int) Number of search results",
        )
        safe: bool = Field(
            default=False,
            description="(bool) Enable SafeSearch",
        )
        types: list = Field(
            default=["web"],
            description="(list[str]) Types of search results: `web`, `image`, `videos`, `news`",
        )
        extract_webpage: bool = Field(
            default=False,
            description="(bool) Enable extracting main text contents from webpage; adds a `text` field to each `query_result` dict",
        )
        overwrite_query_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML file of query results",
        )
        overwrite_webpage_html: bool = Field(
            default=False,
            description="(bool) Overwrite HTML files of webpages from query results",
        )
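
    # A hedged example of a JSON request body for this model
    # (values are illustrative, not from the source):
    #   {
    #       "queries": ["python tutorials"],
    #       "result_num": 5,
    #       "safe": false,
    #       "types": ["web"],
    #       "extract_webpage": true,
    #       "overwrite_query_html": false,
    #       "overwrite_webpage_html": false
    #   }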

    def queries_to_search_results(self, item: QueriesToSearchResultsPostItem):
        google_searcher = GoogleSearcher()
        queries_search_results = []
        for query in item.queries:
            # Skip empty queries before doing any work
            if not query.strip():
                continue
            query_results_extractor = QueryResultsExtractor()
            query_html_path = google_searcher.search(
                query=query,
                result_num=item.result_num,
                safe=item.safe,
                overwrite=item.overwrite_query_html,
            )
            query_search_results = query_results_extractor.extract(query_html_path)
            queries_search_results.append(query_search_results)
        logger.note(queries_search_results)
        if item.extract_webpage:
            queries_search_results = self.extract_webpages(
                queries_search_results,
                overwrite_webpage_html=item.overwrite_webpage_html,
            )
        return queries_search_results
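
    # Note: each query_search_results dict is produced by
    # QueryResultsExtractor.extract(); judging from its usage in
    # extract_webpages() below, it is assumed to look like:
    #   {
    #       "query": "...",
    #       "query_results": [{"url": "...", ...}, ...],
    #   }
    # with a "text" key added to each query_result when extract_webpage is enabled.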

    def extract_webpages(self, queries_search_results, overwrite_webpage_html=False):
        for query_idx, query_search_results in enumerate(queries_search_results):
            # Fetch webpages from the result urls
            batch_webpage_fetcher = BatchWebpageFetcher()
            urls = [
                query_result["url"]
                for query_result in query_search_results["query_results"]
            ]
            url_and_html_path_list = batch_webpage_fetcher.fetch(
                urls,
                overwrite=overwrite_webpage_html,
                output_parent=query_search_results["query"],
            )
            html_paths = [
                str(url_and_html_path["html_path"])
                for url_and_html_path in url_and_html_path_list
            ]
            # Extract webpage contents from the fetched HTML files
            batch_webpage_content_extractor = BatchWebpageContentExtractor()
            html_path_and_extracted_content_list = (
                batch_webpage_content_extractor.extract(html_paths)
            )
            # Map each url to its extracted content via the shared html_path
            url_and_extracted_content_dict = {}
            for url_and_html_path in url_and_html_path_list:
                url = url_and_html_path["url"]
                html_path = str(url_and_html_path["html_path"])
                extracted_content = html_path_and_extracted_content_list[
                    html_paths.index(html_path)
                ]["extracted_content"]
                url_and_extracted_content_dict[url] = extracted_content
            # Write extracted contents (as "text" field) back into query_search_results
            for query_result_idx, query_result in enumerate(
                query_search_results["query_results"]
            ):
                url = query_result["url"]
                extracted_content = url_and_extracted_content_dict[url]
                queries_search_results[query_idx]["query_results"][query_result_idx][
                    "text"
                ] = extracted_content
        return queries_search_results

    def setup_routes(self):
        self.app.post(
            "/queries_to_search_results",
            summary="Search queries, and extract contents from results",
        )(self.queries_to_search_results)


class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for Web Search API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=21111,
            help="Server Port for Web Search API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )
        self.args = self.parse_args(sys.argv[1:])


app = SearchAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # Auto-reload only in dev mode
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

    # python -m apis.search_api     # [Docker] in production mode
    # python -m apis.search_api -d  # [Dev] in development mode
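
# A minimal client sketch (not part of the source), assuming the server runs
# on the default host/port above; the payload mirrors
# QueriesToSearchResultsPostItem:
#
#   import requests
#
#   response = requests.post(
#       "http://127.0.0.1:21111/queries_to_search_results",
#       json={"queries": ["python tutorials"], "extract_webpage": True},
#   )
#   for query_search_results in response.json():
#       print(query_search_results["query"])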