Spaces:
Paused
Paused
| import argparse | |
| import os | |
| import sys | |
| import uvicorn | |
| from fastapi import FastAPI, Depends | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel, Field | |
| from typing import Union | |
| from sse_starlette.sse import EventSourceResponse, ServerSentEvent | |
| from utils.logger import logger | |
| from networks.google_searcher import GoogleSearcher | |
| from networks.webpage_fetcher import WebpageFetcher | |
| from documents.query_results_extractor import QueryResultsExtractor | |
| from documents.webpage_content_extractor import WebpageContentExtractor | |
| from utils.logger import logger | |
| class SearchAPIApp: | |
| def __init__(self): | |
| self.app = FastAPI( | |
| docs_url="/", | |
| title="Web Search API", | |
| swagger_ui_parameters={"defaultModelsExpandDepth": -1}, | |
| version="1.0", | |
| ) | |
| self.setup_routes() | |
| class QueriesToSearchResultsPostItem(BaseModel): | |
| queries: list = Field( | |
| default=[""], | |
| description="(list[str]) Queries to search", | |
| ) | |
| result_num: int = Field( | |
| default=10, | |
| description="(int) Number of search results", | |
| ) | |
| safe: bool = Field( | |
| default=False, | |
| description="(bool) Enable SafeSearch", | |
| ) | |
| types: list = Field( | |
| default=["web"], | |
| description="(list[str]) Types of search results: `web`, `image`, `videos`, `news`", | |
| ) | |
| extract_content: bool = Field( | |
| default=False, | |
| description="(bool) Enable extracting main text contents from webpage, will add `text` filed in each `query_result` dict", | |
| ) | |
| overwrite_query_html: bool = Field( | |
| default=False, | |
| description="(bool) Overwrite HTML file of query results", | |
| ) | |
| overwrite_webpage_html: bool = Field( | |
| default=False, | |
| description="(bool) Overwrite HTML files of webpages from query results", | |
| ) | |
| def queries_to_search_results(self, item: QueriesToSearchResultsPostItem): | |
| google_searcher = GoogleSearcher() | |
| query_results_extractor = QueryResultsExtractor() | |
| queries_search_results = [] | |
| for query in item.queries: | |
| if not query.strip(): | |
| continue | |
| query_html_path = google_searcher.search( | |
| query=query, | |
| result_num=item.result_num, | |
| safe=item.safe, | |
| overwrite=item.overwrite_query_html, | |
| ) | |
| query_search_results = query_results_extractor.extract(query_html_path) | |
| queries_search_results.append(query_search_results) | |
| logger.note(queries_search_results) | |
| if item.extract_content: | |
| webpage_fetcher = WebpageFetcher() | |
| webpage_content_extractor = WebpageContentExtractor() | |
| for query_idx, query_search_result in enumerate(queries_search_results): | |
| for query_result_idx, query_result in enumerate( | |
| query_search_result["query_results"] | |
| ): | |
| webpage_html_path = webpage_fetcher.fetch( | |
| query_result["url"], | |
| overwrite=item.overwrite_webpage_html, | |
| output_parent=query_search_result["query"], | |
| ) | |
| extracted_content = webpage_content_extractor.extract( | |
| webpage_html_path | |
| ) | |
| queries_search_results[query_idx]["query_results"][ | |
| query_result_idx | |
| ]["text"] = extracted_content | |
| return queries_search_results | |
| def setup_routes(self): | |
| self.app.post( | |
| "/queries_to_search_results", | |
| summary="Search queries, and extract contents from results", | |
| )(self.queries_to_search_results) | |
| class ArgParser(argparse.ArgumentParser): | |
| def __init__(self, *args, **kwargs): | |
| super(ArgParser, self).__init__(*args, **kwargs) | |
| self.add_argument( | |
| "-s", | |
| "--server", | |
| type=str, | |
| default="0.0.0.0", | |
| help="Server IP for Web Search API", | |
| ) | |
| self.add_argument( | |
| "-p", | |
| "--port", | |
| type=int, | |
| default=21111, | |
| help="Server Port for Web Search API", | |
| ) | |
| self.add_argument( | |
| "-d", | |
| "--dev", | |
| default=False, | |
| action="store_true", | |
| help="Run in dev mode", | |
| ) | |
| self.args = self.parse_args(sys.argv[1:]) | |
| app = SearchAPIApp().app | |
| if __name__ == "__main__": | |
| args = ArgParser().args | |
| if args.dev: | |
| uvicorn.run("__main__:app", host=args.server, port=args.port, reload=True) | |
| else: | |
| uvicorn.run("__main__:app", host=args.server, port=args.port, reload=False) | |
| # python -m apis.search_api # [Docker] in product mode | |
| # python -m apis.search_api -d # [Dev] in develop mode | |