add csv ingestion and prompt templates
Browse files- .gitignore +3 -0
- app.py +33 -33
- employees +0 -0
- llm.py +93 -0
- sql.py +102 -0
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            __pycache__
         | 
| 2 | 
            +
            ./__pycache__
         | 
| 3 | 
            +
            */__pycache
         | 
    	
        app.py
    CHANGED
    
    | @@ -1,35 +1,33 @@ | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
            -
            from  | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
            repo_ir = "Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF"
         | 
| 6 | 
            -
            llm = Llama.from_pretrained(
         | 
| 7 | 
            -
                repo_id=repo_ir,
         | 
| 8 | 
            -
                filename="qwen2.5-coder-1.5b-instruct-q8_0.gguf",
         | 
| 9 | 
            -
                verbose=True,
         | 
| 10 | 
            -
                use_mmap=True,
         | 
| 11 | 
            -
                use_mlock=True,
         | 
| 12 | 
            -
                n_threads=4,
         | 
| 13 | 
            -
                n_threads_batch=4,
         | 
| 14 | 
            -
                n_ctx=8000,
         | 
| 15 | 
            -
            )
         | 
| 16 | 
            -
            print(f"{repo_ir} loaded successfully. ✅")
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
            # Streamed response emulator
         | 
| 20 | 
            -
            def response_generator(messages):
         | 
| 21 | 
            -
                completion = llm.create_chat_completion(
         | 
| 22 | 
            -
                    messages, max_tokens=2048, stream=True, temperature=0.7, top_p=0.95
         | 
| 23 | 
            -
                )
         | 
| 24 |  | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 |  | 
|  | |
| 30 |  | 
| 31 | 
             
            st.title("CSV TO SQL")
         | 
| 32 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 33 | 
             
            # Initialize chat history
         | 
| 34 | 
             
            if "messages" not in st.session_state:
         | 
| 35 | 
             
                st.session_state.messages = []
         | 
| @@ -47,14 +45,16 @@ if prompt := st.chat_input("What is up?"): | |
| 47 | 
             
                with st.chat_message("user"):
         | 
| 48 | 
             
                    st.markdown(prompt)
         | 
| 49 |  | 
| 50 | 
            -
                messages = [{"role": "system", "content": "You are a helpful assistant"}]
         | 
| 51 | 
            -
             | 
| 52 | 
            -
                for val in st.session_state.messages:
         | 
| 53 | 
            -
                    messages.append(val)
         | 
| 54 | 
            -
             | 
| 55 | 
            -
                messages.append({"role": "user", "content": prompt})
         | 
| 56 | 
             
                # Display assistant response in chat message container
         | 
| 57 | 
             
                with st.chat_message("assistant"):
         | 
| 58 | 
            -
                    response = st. | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 59 | 
             
                # Add assistant response to chat history
         | 
| 60 | 
             
                st.session_state.messages.append({"role": "assistant", "content": response})
         | 
|  | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
            +
            from llm import load_llm, response_generator
         | 
| 3 | 
            +
            from sql import csv_to_sqlite
         | 
| 4 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
            +
            # repo_id = "Qwen/Qwen2.5-Coder-1.5B-Instruct-GGUF"
         | 
| 7 | 
            +
            repo_id = "Qwen/Qwen2.5-0.5B-Instruct-GGUF"
         | 
| 8 | 
            +
            # filename="qwen2.5-coder-1.5b-instruct-q8_0.gguf"
         | 
| 9 | 
            +
            filename = "qwen2.5-0.5b-instruct-q8_0.gguf"
         | 
| 10 |  | 
| 11 | 
            +
            llm = load_llm(repo_id, filename)
         | 
| 12 |  | 
| 13 | 
             
            st.title("CSV TO SQL")
         | 
| 14 |  | 
| 15 | 
            +
            with st.expander("Upload CSV"):
         | 
| 16 | 
            +
                csv_file = st.file_uploader(
         | 
| 17 | 
            +
                    "CSV",
         | 
| 18 | 
            +
                )
         | 
| 19 | 
            +
                db_name = st.text_input("DB Name")
         | 
| 20 | 
            +
                table_name = st.text_input("Table Name")
         | 
| 21 | 
            +
                if st.button("Save"):
         | 
| 22 | 
            +
                    if csv_file and db_name and table_name:
         | 
| 23 | 
            +
                        st.session_state.db_name = db_name
         | 
| 24 | 
            +
                        st.session_state.table_name = table_name
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                        csv_to_sqlite(csv_file, db_name, table_name)
         | 
| 27 | 
            +
                        st.write("Saved ✅")
         | 
| 28 | 
            +
                    else:
         | 
| 29 | 
            +
                        st.write("Please enter all values")
         | 
| 30 | 
            +
             | 
| 31 | 
             
            # Initialize chat history
         | 
| 32 | 
             
            if "messages" not in st.session_state:
         | 
| 33 | 
             
                st.session_state.messages = []
         | 
|  | |
| 45 | 
             
                with st.chat_message("user"):
         | 
| 46 | 
             
                    st.markdown(prompt)
         | 
| 47 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 48 | 
             
                # Display assistant response in chat message container
         | 
| 49 | 
             
                with st.chat_message("assistant"):
         | 
| 50 | 
            +
                    response = st.write(
         | 
| 51 | 
            +
                        response_generator(
         | 
| 52 | 
            +
                            db_name=st.session_state.db_name,
         | 
| 53 | 
            +
                            table_name=st.session_state.table_name,
         | 
| 54 | 
            +
                            llm=llm,
         | 
| 55 | 
            +
                            messages=st.session_state.messages,
         | 
| 56 | 
            +
                            question=prompt,
         | 
| 57 | 
            +
                        )
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
             
                # Add assistant response to chat history
         | 
| 60 | 
             
                st.session_state.messages.append({"role": "assistant", "content": response})
         | 
    	
        employees
    ADDED
    
    | Binary file (8.19 kB). View file | 
|  | 
    	
        llm.py
    ADDED
    
    | @@ -0,0 +1,93 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import streamlit as st
         | 
| 2 | 
            +
            from llama_cpp import Llama
         | 
| 3 | 
            +
            from sql import get_table_schema
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            @st.cache_resource()
         | 
| 7 | 
            +
            def load_llm(repo_id, filename):
         | 
| 8 | 
            +
                llm = Llama.from_pretrained(
         | 
| 9 | 
            +
                    repo_id=repo_id,
         | 
| 10 | 
            +
                    filename=filename,
         | 
| 11 | 
            +
                    verbose=True,
         | 
| 12 | 
            +
                    use_mmap=True,
         | 
| 13 | 
            +
                    use_mlock=True,
         | 
| 14 | 
            +
                    n_threads=4,
         | 
| 15 | 
            +
                    n_threads_batch=4,
         | 
| 16 | 
            +
                    n_ctx=8000,
         | 
| 17 | 
            +
                )
         | 
| 18 | 
            +
                print(f"{repo_id} loaded successfully. ✅")
         | 
| 19 | 
            +
                return llm
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def generate_llm_prompt(table_name, table_schema):
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
                Generates a prompt to provide context about a table's schema for LLM to convert natural language to SQL.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                Args:
         | 
| 27 | 
            +
                    table_name (str): The name of the table.
         | 
| 28 | 
            +
                    table_schema (list): A list of tuples where each tuple contains information about the columns in the table.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                Returns:
         | 
| 31 | 
            +
                    str: The generated prompt to be used by the LLM.
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                prompt = f"""You are an expert in writing SQL queries for relational databases. 
         | 
| 34 | 
            +
                You will be provided with a database schema and a natural 
         | 
| 35 | 
            +
                language question, and your task is to generate an accurate SQL query.
         | 
| 36 | 
            +
                
         | 
| 37 | 
            +
                The database has a table named '{table_name}' with the following schema:\n\n"""
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                prompt += "Columns:\n"
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                for col in table_schema:
         | 
| 42 | 
            +
                    column_name = col[1]
         | 
| 43 | 
            +
                    column_type = col[2]
         | 
| 44 | 
            +
                    prompt += f"- {column_name} ({column_type})\n"
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                prompt += "\nPlease generate a SQL query based on the following natural language question. ONLY return the SQL query."
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                return prompt
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def generate_sql_query(question, table_name, db_name):
         | 
| 52 | 
            +
                pass
         | 
| 53 | 
            +
                # table_name = 'movies'
         | 
| 54 | 
            +
                # db_name = 'movies_db.db'
         | 
| 55 | 
            +
                # table_schema = get_table_schema(db_name, table_name)
         | 
| 56 | 
            +
                # llm_prompt = generate_llm_prompt(table_name, table_schema)
         | 
| 57 | 
            +
                # user_prompt = """Question: {question}"""
         | 
| 58 | 
            +
                # response = completion(
         | 
| 59 | 
            +
                #     api_key=OPENAI_API_KEY,
         | 
| 60 | 
            +
                #     model="gpt-4o-mini",
         | 
| 61 | 
            +
                #     messages=[
         | 
| 62 | 
            +
                #         ,
         | 
| 63 | 
            +
                #         {"content": user_prompt.format(question=question),"role": "user"}],
         | 
| 64 | 
            +
                #     max_tokens=1000
         | 
| 65 | 
            +
                # )
         | 
| 66 | 
            +
                # answer = response.choices[0].message.content
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                # query = answer.replace("```sql", "").replace("```", "")
         | 
| 69 | 
            +
                # query = query.strip()
         | 
| 70 | 
            +
                # return query
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            # Streamed response emulator
         | 
| 74 | 
            +
            def response_generator(llm, messages, question, table_name, db_name):
         | 
| 75 | 
            +
                table_schema = get_table_schema(db_name, table_name)
         | 
| 76 | 
            +
                llm_prompt = generate_llm_prompt(table_name, table_schema)
         | 
| 77 | 
            +
                user_prompt = """Question: {question}"""
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                messages = [{"content": llm_prompt.format(table_name=table_name), "role": "system"}]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                for val in st.session_state.messages:
         | 
| 82 | 
            +
                    messages.append(val)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                messages.append({"role": "user", "content": user_prompt})
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                response = llm.create_chat_completion(
         | 
| 87 | 
            +
                    messages, max_tokens=2048, temperature=0.7, top_p=0.95
         | 
| 88 | 
            +
                )
         | 
| 89 | 
            +
                answer = response["choices"][0].message.content
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                query = answer.replace("```sql", "").replace("```", "")
         | 
| 92 | 
            +
                query = query.strip()
         | 
| 93 | 
            +
                return query
         | 
    	
        sql.py
    ADDED
    
    | @@ -0,0 +1,102 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import pandas as pd
         | 
| 2 | 
            +
            import sqlite3
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def csv_to_sqlite(csv_file, db_name, table_name):
         | 
| 6 | 
            +
                # Read the CSV file into a pandas DataFrame
         | 
| 7 | 
            +
                df = pd.read_csv(csv_file)
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                # Connect to the SQLite database (it will create the database file if it doesn't exist)
         | 
| 10 | 
            +
                conn = sqlite3.connect(db_name)
         | 
| 11 | 
            +
                cursor = conn.cursor()
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                # Infer the schema based on the DataFrame columns and data types
         | 
| 14 | 
            +
                def create_table_from_df(df, table_name):
         | 
| 15 | 
            +
                    # Get column names and types
         | 
| 16 | 
            +
                    col_types = []
         | 
| 17 | 
            +
                    for col in df.columns:
         | 
| 18 | 
            +
                        dtype = df[col].dtype
         | 
| 19 | 
            +
                        if dtype == "int64":
         | 
| 20 | 
            +
                            col_type = "INTEGER"
         | 
| 21 | 
            +
                        elif dtype == "float64":
         | 
| 22 | 
            +
                            col_type = "REAL"
         | 
| 23 | 
            +
                        else:
         | 
| 24 | 
            +
                            col_type = "TEXT"
         | 
| 25 | 
            +
                        col_types.append(f'"{col}" {col_type}')
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    # Create the table schema
         | 
| 28 | 
            +
                    col_definitions = ", ".join(col_types)
         | 
| 29 | 
            +
                    create_table_query = (
         | 
| 30 | 
            +
                        f"CREATE TABLE IF NOT EXISTS {table_name} ({col_definitions});"
         | 
| 31 | 
            +
                    )
         | 
| 32 | 
            +
                    # print(create_table_query)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    # Execute the table creation query
         | 
| 35 | 
            +
                    cursor.execute(create_table_query)
         | 
| 36 | 
            +
                    print(f"Table '{table_name}' created with schema: {col_definitions}")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                # Create table schema
         | 
| 39 | 
            +
                create_table_from_df(df, table_name)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                # Insert CSV data into the SQLite table
         | 
| 42 | 
            +
                df.to_sql(table_name, conn, if_exists="replace", index=False)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                # Commit and close the connection
         | 
| 45 | 
            +
                conn.commit()
         | 
| 46 | 
            +
                conn.close()
         | 
| 47 | 
            +
                print(f"Data loaded into '{table_name}' table in '{db_name}' SQLite database.")
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def run_sql_query(db_name, query):
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                Executes a SQL query on a SQLite database and returns the results.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                Args:
         | 
| 55 | 
            +
                    db_name (str): The name of the SQLite database file.
         | 
| 56 | 
            +
                    query (str): The SQL query to run.
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                Returns:
         | 
| 59 | 
            +
                    list: Query result as a list of tuples, or an empty list if no results or error occurred.
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
                try:
         | 
| 62 | 
            +
                    # Connect to the SQLite database
         | 
| 63 | 
            +
                    conn = sqlite3.connect(db_name)
         | 
| 64 | 
            +
                    cursor = conn.cursor()
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    # Execute the SQL query
         | 
| 67 | 
            +
                    cursor.execute(query)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # Fetch all results
         | 
| 70 | 
            +
                    results = cursor.fetchall()
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    # Close the connection
         | 
| 73 | 
            +
                    conn.close()
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    # Return results or an empty list if no results were found
         | 
| 76 | 
            +
                    return results if results else []
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                except sqlite3.Error as e:
         | 
| 79 | 
            +
                    print(f"An error occurred while executing the query: {e}")
         | 
| 80 | 
            +
                    return []
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def get_table_schema(db_name, table_name):
         | 
| 84 | 
            +
                """
         | 
| 85 | 
            +
                Retrieves the schema (columns and data types) for a given table in the SQLite database.
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                Args:
         | 
| 88 | 
            +
                    db_name (str): The name of the SQLite database file.
         | 
| 89 | 
            +
                    table_name (str): The name of the table.
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                Returns:
         | 
| 92 | 
            +
                    list: A list of tuples with column name, data type, and other info.
         | 
| 93 | 
            +
                """
         | 
| 94 | 
            +
                conn = sqlite3.connect(db_name)
         | 
| 95 | 
            +
                cursor = conn.cursor()
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                # Use PRAGMA to get the table schema
         | 
| 98 | 
            +
                cursor.execute(f"PRAGMA table_info({table_name});")
         | 
| 99 | 
            +
                schema = cursor.fetchall()
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                conn.close()
         | 
| 102 | 
            +
                return schema
         |