| """Text-to-SQL running.""" | |
| import asyncio | |
| import json | |
| import re | |
| import time | |
| from typing import cast | |
| import duckdb | |
| import structlog | |
| from manifest import Manifest | |
| from manifest.response import Response, Usage | |
| from prompt_formatters import RajkumarFormatter, MotherDuckFormatter | |
| from schema import DEFAULT_TABLE_NAME, TextToSQLModelResponse, TextToSQLParams | |
| from tqdm.auto import tqdm | |
| logger = structlog.get_logger() | |


def clean_whitespace(sql: str) -> str:
    """Collapse tabs, newlines, and repeated spaces into single spaces."""
    return re.sub(r"[\t\n\s]+", " ", sql)
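

# Example (illustrative): clean_whitespace("SELECT id\n\tFROM users") returns
# "SELECT id FROM users"; runs of tabs, newlines, and spaces collapse to one space.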


def instruction_to_sql(
    params: TextToSQLParams,
    extra_context: list[str],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
) -> TextToSQLModelResponse:
    """Parse a single instruction into a SQL command."""
    return instruction_to_sql_list(
        params=[params],
        extra_context=[extra_context],
        manifest=manifest,
        prompt_formatter=prompt_formatter,
        overwrite_manifest=overwrite_manifest,
        max_tokens=max_tokens,
        temperature=temperature,
        stop_sequences=stop_sequences,
        num_beams=num_beams,
    )[0]


def run_motherduck_prompt_sql(
    params: list[TextToSQLParams],
) -> list[TextToSQLModelResponse]:
    """Generate SQL via MotherDuck's prompt_sql call instead of a local prompt."""
    results = []
    for param in params:
        con = duckdb.connect("md:")
        try:
            sql_query = con.execute(
                "CALL prompt_sql(?);", [param.instruction]
            ).fetchall()[0][0]
        except Exception as e:
            logger.warning(f"prompt_sql failed, falling back to a default query: {e}")
            sql_query = "SELECT * FROM hn.hacker_news LIMIT 1"
        finally:
            con.close()
        usage = Usage(completion_tokens=0, prompt_tokens=0, total_tokens=0)
        model_response = TextToSQLModelResponse(
            output=sql_query,
            raw_output=sql_query,
            final_prompt=param.instruction,
            usage=usage,
        )
        results.append(model_response)
    return results
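

# Note: the MotherDuck path above assumes `CALL prompt_sql(?)` returns a single
# row whose first column is the generated SQL string, which is why the result
# is unpacked with `.fetchall()[0][0]`.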


def instruction_to_sql_list(
    params: list[TextToSQLParams],
    extra_context: list[list[str]],
    manifest: Manifest,
    prompt_formatter: RajkumarFormatter | None = None,
    overwrite_manifest: bool = False,
    max_tokens: int = 300,
    temperature: float = 0.1,
    stop_sequences: list[str] | None = None,
    num_beams: int = 1,
    verbose: bool = False,
) -> list[TextToSQLModelResponse]:
    """Parse a list of instructions into SQL commands.

    Connector is used for default retry handlers only.
    """
    if isinstance(prompt_formatter, MotherDuckFormatter):
        return run_motherduck_prompt_sql(params)
    if prompt_formatter is None:
        raise ValueError("Prompt formatter is required.")

    def construct_params(
        params: TextToSQLParams,
        context: list[str],
    ) -> str | list[dict]:
        """Turn params into a prompt."""
        if prompt_formatter.clean_whitespace:
            instruction = clean_whitespace(params.instruction)
        else:
            instruction = params.instruction
        table_texts = prompt_formatter.format_all_tables(
            params.tables, instruction=instruction
        )
        # table_texts can be a list of chat messages; only join lists of str.
        if table_texts:
            if isinstance(table_texts[0], str):
                table_text = prompt_formatter.table_sep.join(table_texts)
            else:
                table_text = table_texts
        else:
            table_text = ""
        if context:
            context_text = prompt_formatter.format_retrieved_context(context)
        else:
            context_text = "" if isinstance(table_text, str) else []
        prompt = prompt_formatter.format_prompt(
            instruction,
            table_text,
            context_text,
        )
        return prompt

    # If there are no inputs, return nothing.
    if not params:
        return []

    # Stitch together demonstrations and params into prompts.
    prompts: list[str | list[dict]] = []
    for i, param in tqdm(
        enumerate(params),
        total=len(params),
        desc="Constructing prompts",
        disable=not verbose,
    ):
        predict_str = construct_params(param, extra_context[i] if extra_context else [])
        if isinstance(predict_str, str):
            prompt = predict_str.lstrip()
        else:
            prompt = predict_str
        prompts.append(prompt)

    manifest_params = dict(
        max_tokens=max_tokens,
        overwrite_cache=overwrite_manifest,
        num_beams=num_beams,
        logprobs=5,
        temperature=temperature,
        do_sample=False,
        stop_sequences=stop_sequences or prompt_formatter.stop_sequences,
    )

    ret: list[TextToSQLModelResponse] = []
    if len(params) == 1:
        prompt = prompts[0]
        model_response = None
        retries = 0
        # Retry the single-prompt call up to 5 times before giving up.
        while model_response is None and retries < 5:
            try:
                model_response = _run_manifest(
                    prompt,
                    manifest_params,
                    prompt_formatter,
                    manifest,
                    stop_sequences=stop_sequences,
                )
            except Exception as e:
                retries += 1
                logger.warning(f"Manifest call failed (attempt {retries}): {e}")
        if model_response is None:
            raise RuntimeError("Manifest call failed after 5 attempts.")
        ret.append(model_response)
    else:
        # We do not handle retry logic on parallel requests right now.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        response = cast(
            Response,
            loop.run_until_complete(
                manifest.arun_batch(
                    prompts,
                    **manifest_params,  # type: ignore
                ),
            ),
        )
        loop.close()
        response_usage = response.get_usage()
        response_text = response.get_parsed_response()
        for prompt, resp in zip(prompts, response_text):
            # This restitches the query in case we force it to start with SELECT.
            sql_query = prompt_formatter.format_model_output(cast(str, resp), prompt)
            for token in stop_sequences or []:
                sql_query = sql_query.split(token)[0]
            logger.info(f"FINAL OUTPUT: {sql_query}")
            ret.append(
                TextToSQLModelResponse(
                    output=sql_query,
                    raw_output=cast(str, resp),
                    final_prompt=prompt,
                    usage=response_usage,
                )
            )
    return ret


def _run_manifest(
    prompt: str | list[str],
    manifest_params: dict,
    prompt_formatter: RajkumarFormatter,
    manifest: Manifest,
    stop_sequences: list[str] | None = None,
) -> TextToSQLModelResponse:
    """Run manifest on a single formatted prompt."""
    logger.info(f"PARAMS: {manifest_params}")
    start_time = time.time()
    # Run the model.
    response = cast(
        Response,
        manifest.run(
            prompt,
            return_response=True,
            client_timeout=1800,
            **manifest_params,  # type: ignore
        ),
    )
    logger.info(f"TIME: {time.time() - start_time: .2f}")
    # Sum token usage across all sub-responses.
    response_usage = response.get_usage_obj()
    summed_usage = Usage()
    for usage in response_usage.usages:
        summed_usage.completion_tokens += usage.completion_tokens
        summed_usage.prompt_tokens += usage.prompt_tokens
        summed_usage.total_tokens += usage.total_tokens
    # This restitches the query in case we force it to start with SELECT.
    sql_query = prompt_formatter.format_model_output(
        cast(str, response.get_response()), prompt
    )
    for token in stop_sequences or []:
        sql_query = sql_query.split(token)[0]
    logger.info(f"OUTPUT: {sql_query}")
    model_response = TextToSQLModelResponse(
        output=sql_query,
        raw_output=cast(str, response.get_response()),
        final_prompt=prompt,
        usage=summed_usage,
    )
    return model_response
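

# Minimal usage sketch (illustrative only). The constructor arguments for
# Manifest, RajkumarFormatter, and TextToSQLParams below are assumptions for
# demonstration; the real signatures live in the manifest package and the local
# schema.py / prompt_formatters.py modules.
if __name__ == "__main__":
    manifest_client = Manifest(client_name="openai")  # assumed client configuration
    formatter = RajkumarFormatter()  # assumed default construction
    request = TextToSQLParams(
        instruction="How many stories did each author post?",
        tables=[],  # table schema objects would normally go here
    )
    result = instruction_to_sql(
        request,
        extra_context=[],
        manifest=manifest_client,
        prompt_formatter=formatter,
    )
    print(result.output)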