import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList, StoppingCriteria
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Initialize tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
def read_file(file_path: str) -> str:
    """Read the contents of a file (helper, currently unused)."""
    with open(file_path, "r") as file:
        return file.read()
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)
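# Swedish Sentence-BERT encoder used to embed queries for retrieval
# (the Pinecone index is assumed to have been built with the same encoder)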
document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")

# Note: 'index1' has been pre-created in the Pinecone console
# Read the Pinecone API key from Streamlit secrets
pinecone_api_key = st.secrets["pinecone_api_key"]
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")
def query_pinecone_namespace(
    vector_database_index, q_embedding: np.ndarray, namespace: str
) -> str:
    """Return the paragraph of the best match for the query embedding in the given namespace."""
    result = vector_database_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    results = [match.metadata["paragraph"] for match in result.matches]
    return results[0]
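# GPT-SW3 instruct models expect chat turns wrapped as "<|endoftext|><s>\nUser:\n...\n<s>\nBot:\n"
# (per the model card); generation is stopped when the model emits the next "<s>"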
def generate_prompt(llmprompt: str) -> str:
    """Wrap the prompt in the GPT-SW3 instruct chat format."""
    start_token = "<|endoftext|><s>"
    end_token = "<s>"
    return f"{start_token}\nUser:\n{llmprompt}\n{end_token}\nBot:\n".strip()
def encode_query(query: str) -> np.ndarray:
    """Embed the query with the Sentence-BERT document encoder."""
    return document_encoder_model.encode(query)
class StopOnTokenCriteria(StoppingCriteria):
    """Stop generation as soon as the last generated token equals the stop token."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1] == self.stop_token_id
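# GPT-SW3 instruct separates turns with "<s>", which is the tokenizer's BOS token,
# so generation stops once the model emits it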
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)
st.title("Paralegal Assistant")
st.subheader("RAG: föräldrabalken")
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# React to user input
if prompt := st.chat_input("Skriv din fråga..."):
    # Display user message in chat message container
    st.chat_message("user").markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Retrieve the most relevant statute paragraph for the question
    retrieved_paragraph = query_pinecone_namespace(
        vector_database_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )
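    # Compose a RAG prompt: the retrieved statute text followed by the user's question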
    llmprompt = (
        "Följande stycke är en del av lagen: "
        + retrieved_paragraph
        + " Referera till lagen och besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)
    # Convert the prompt to token ids
    input_ids = tokenizer(llmprompt, return_tensors="pt")["input_ids"].to(device)

    # Generate tokens based on the prompt
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]
    # Decode the generated tokens, skipping the prompt and dropping the trailing stop token
    generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])
    response = generated_text
    # Display the retrieved paragraph and the assistant response in a chat message container
    with st.chat_message("assistant"):
        st.markdown(f"```\n{retrieved_paragraph}\n```\n" + response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})