Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Trent
		
	commited on
		
		
					Commit 
							
							·
						
						a41bdbc
	
1
								Parent(s):
							
							49438d6
								
Multi model select and local model loading
Browse files- __init__.py +0 -0
- app.py +12 -30
- backend/__init__.py +0 -0
- backend/config.py +1 -0
- backend/inference.py +9 -20
- backend/main.py +0 -19
- backend/utils.py +11 -0
- requirements.txt +1 -1
    	
        __init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        app.py
    CHANGED
    
    | @@ -1,7 +1,8 @@ | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
             
            import pandas as pd
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            import  | 
|  | |
| 5 |  | 
| 6 | 
             
            st.title('Demo using Flax-Sentence-Tranformers')
         | 
| 7 |  | 
| @@ -20,12 +21,12 @@ For more cool information on sentence embeddings, see the [sBert project](https: | |
| 20 | 
             
            Please enjoy!!
         | 
| 21 | 
             
            ''')
         | 
| 22 |  | 
| 23 | 
            -
             | 
| 24 | 
             
            anchor = st.text_input(
         | 
| 25 | 
             
                'Please enter here the main text you want to compare:'
         | 
| 26 | 
             
            )
         | 
| 27 |  | 
| 28 | 
             
            if anchor:
         | 
|  | |
| 29 | 
             
                n_texts = st.sidebar.number_input(
         | 
| 30 | 
             
                    f'''How many texts you want to compare with: '{anchor}'?''',
         | 
| 31 | 
             
                    value=2,
         | 
| @@ -34,40 +35,21 @@ if anchor: | |
| 34 | 
             
                inputs = []
         | 
| 35 |  | 
| 36 | 
             
                for i in range(n_texts):
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                    input = st.sidebar.text_input(f'Text {i+1}:')
         | 
| 39 |  | 
| 40 | 
             
                    inputs.append(input)
         | 
| 41 |  | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
            api_base_url = 'http://127.0.0.1:8000/similarity'
         | 
| 45 | 
            -
             | 
| 46 | 
             
            if anchor:
         | 
| 47 | 
             
                if st.sidebar.button('Tell me the similarity.'):
         | 
| 48 | 
            -
                     | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
                     | 
| 52 | 
            -
                                                                               inputs = inputs,
         | 
| 53 | 
            -
                                                                               model = 'mpnet'))
         | 
| 54 | 
            -
                    res_minilm_l6 = requests.get(url = api_base_url, params = dict(anchor = anchor,
         | 
| 55 | 
            -
                                                                                   inputs = inputs,
         | 
| 56 | 
            -
                                                                                   model = 'minilm_l6'))
         | 
| 57 | 
            -
             | 
| 58 | 
            -
                    d_distilroberta = res_distilroberta.json()['dataframe']
         | 
| 59 | 
            -
                    d_mpnet = res_mpnet.json()['dataframe']
         | 
| 60 | 
            -
                    d_minilm_l6 = res_minilm_l6.json()['dataframe']
         | 
| 61 | 
            -
             | 
| 62 | 
            -
                    index = list(d_distilroberta['inputs'].values())
         | 
| 63 | 
             
                    df_total = pd.DataFrame(index=index)
         | 
| 64 | 
            -
                     | 
| 65 | 
            -
             | 
| 66 | 
            -
                    df_total['minilm_l6'] = list(d_minilm_l6['score'].values())
         | 
| 67 |  | 
| 68 | 
            -
                    st.write('Here are the results for  | 
| 69 | 
             
                    st.write(df_total)
         | 
| 70 | 
             
                    st.write('Visualize the results of each model:')
         | 
| 71 | 
             
                    st.area_chart(df_total)
         | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
|  | |
| 1 | 
             
            import streamlit as st
         | 
| 2 | 
             
            import pandas as pd
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from backend import inference
         | 
| 5 | 
            +
            from backend.config import MODELS_ID
         | 
| 6 |  | 
| 7 | 
             
            st.title('Demo using Flax-Sentence-Tranformers')
         | 
| 8 |  | 
|  | |
| 21 | 
             
            Please enjoy!!
         | 
| 22 | 
             
            ''')
         | 
| 23 |  | 
|  | |
| 24 | 
             
            anchor = st.text_input(
         | 
| 25 | 
             
                'Please enter here the main text you want to compare:'
         | 
| 26 | 
             
            )
         | 
| 27 |  | 
| 28 | 
             
            if anchor:
         | 
| 29 | 
            +
                select_models = st.sidebar.multiselect("Choose models", options=MODELS_ID.keys())
         | 
| 30 | 
             
                n_texts = st.sidebar.number_input(
         | 
| 31 | 
             
                    f'''How many texts you want to compare with: '{anchor}'?''',
         | 
| 32 | 
             
                    value=2,
         | 
|  | |
| 35 | 
             
                inputs = []
         | 
| 36 |  | 
| 37 | 
             
                for i in range(n_texts):
         | 
| 38 | 
            +
                    input = st.sidebar.text_input(f'Text {i + 1}:')
         | 
|  | |
| 39 |  | 
| 40 | 
             
                    inputs.append(input)
         | 
| 41 |  | 
|  | |
|  | |
|  | |
|  | |
| 42 | 
             
            if anchor:
         | 
| 43 | 
             
                if st.sidebar.button('Tell me the similarity.'):
         | 
| 44 | 
            +
                    results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
         | 
| 45 | 
            +
                    df_results = {model: results[model] for model in results}
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    index = inputs
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 48 | 
             
                    df_total = pd.DataFrame(index=index)
         | 
| 49 | 
            +
                    for key, value in df_results.items():
         | 
| 50 | 
            +
                        df_total[key] = list(value['score'].values)
         | 
|  | |
| 51 |  | 
| 52 | 
            +
                    st.write('Here are the results for selected models:')
         | 
| 53 | 
             
                    st.write(df_total)
         | 
| 54 | 
             
                    st.write('Visualize the results of each model:')
         | 
| 55 | 
             
                    st.area_chart(df_total)
         | 
|  | |
|  | 
    	
        backend/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        backend/config.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
| 1 | 
             
            MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
         | 
| 2 | 
             
                             mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
         | 
|  | |
| 3 | 
             
                             minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
         | 
|  | |
| 1 | 
             
            MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
         | 
| 2 | 
             
                             mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
         | 
| 3 | 
            +
                             mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
         | 
| 4 | 
             
                             minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
         | 
    	
        backend/inference.py
    CHANGED
    
    | @@ -1,41 +1,30 @@ | |
| 1 | 
            -
            from sentence_transformers import SentenceTransformer
         | 
| 2 | 
             
            import pandas as pd
         | 
| 3 | 
             
            import jax.numpy as jnp
         | 
| 4 |  | 
| 5 | 
             
            from typing import List
         | 
| 6 | 
            -
            import config
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            # We download the models we will be using.
         | 
| 9 | 
            -
            # If you do not want to use all, you can comment the unused ones.
         | 
| 10 | 
            -
            distilroberta_model = SentenceTransformer(config.MODELS_ID['distilroberta'])
         | 
| 11 | 
            -
            mpnet_model = SentenceTransformer(config.MODELS_ID['mpnet'])
         | 
| 12 | 
            -
            minilm_l6_model = SentenceTransformer(config.MODELS_ID['minilm_l6'])
         | 
| 13 |  | 
| 14 | 
             
            # Defining cosine similarity using flax.
         | 
|  | |
|  | |
|  | |
| 15 | 
             
            def cos_sim(a, b):
         | 
| 16 | 
            -
                return jnp.matmul(a, jnp.transpose(b))/(jnp.linalg.norm(a)*jnp.linalg.norm(b))
         | 
| 17 |  | 
| 18 |  | 
| 19 | 
             
            # We get similarity between embeddings.
         | 
| 20 | 
            -
            def text_similarity(anchor: str, inputs: List[str],  | 
|  | |
| 21 |  | 
| 22 | 
             
                # Creating embeddings
         | 
| 23 | 
            -
                 | 
| 24 | 
            -
             | 
| 25 | 
            -
                    inputs_emb = distilroberta_model.encode([input for input in inputs])
         | 
| 26 | 
            -
                elif model == 'mpnet':
         | 
| 27 | 
            -
                    anchor_emb = mpnet_model.encode(anchor)[None, :]
         | 
| 28 | 
            -
                    inputs_emb = mpnet_model.encode([input for input in inputs])
         | 
| 29 | 
            -
                elif model == 'minilm_l6':
         | 
| 30 | 
            -
                    anchor_emb = minilm_l6_model.encode(anchor)[None, :]
         | 
| 31 | 
            -
                    inputs_emb = minilm_l6_model.encode([input for input in inputs])
         | 
| 32 |  | 
| 33 | 
             
                # Obtaining similarity
         | 
| 34 | 
             
                similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
         | 
| 35 |  | 
| 36 | 
             
                # Returning a Pandas' dataframe
         | 
| 37 | 
             
                d = {'inputs': [input for input in inputs],
         | 
| 38 | 
            -
                     'score': [round(similarity[i],3) for i in range(len(similarity))]}
         | 
| 39 | 
             
                df = pd.DataFrame(d, columns=['inputs', 'score'])
         | 
| 40 |  | 
| 41 | 
             
                return df.sort_values('score', ascending=False)
         | 
|  | |
|  | |
| 1 | 
             
            import pandas as pd
         | 
| 2 | 
             
            import jax.numpy as jnp
         | 
| 3 |  | 
| 4 | 
             
            from typing import List
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
             
            # Defining cosine similarity using flax.
         | 
| 7 | 
            +
            from backend.utils import load_model
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
             
            def cos_sim(a, b):
         | 
| 11 | 
            +
                return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
         | 
| 12 |  | 
| 13 |  | 
| 14 | 
             
            # We get similarity between embeddings.
         | 
| 15 | 
            +
            def text_similarity(anchor: str, inputs: List[str], model_name: str):
         | 
| 16 | 
            +
                model = load_model(model_name)
         | 
| 17 |  | 
| 18 | 
             
                # Creating embeddings
         | 
| 19 | 
            +
                anchor_emb = model.encode(anchor)[None, :]
         | 
| 20 | 
            +
                inputs_emb = model.encode([input for input in inputs])
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 21 |  | 
| 22 | 
             
                # Obtaining similarity
         | 
| 23 | 
             
                similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
         | 
| 24 |  | 
| 25 | 
             
                # Returning a Pandas' dataframe
         | 
| 26 | 
             
                d = {'inputs': [input for input in inputs],
         | 
| 27 | 
            +
                     'score': [round(similarity[i], 3) for i in range(len(similarity))]}
         | 
| 28 | 
             
                df = pd.DataFrame(d, columns=['inputs', 'score'])
         | 
| 29 |  | 
| 30 | 
             
                return df.sort_values('score', ascending=False)
         | 
    	
        backend/main.py
    DELETED
    
    | @@ -1,19 +0,0 @@ | |
| 1 | 
            -
            from fastapi import Query, FastAPI
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            import config
         | 
| 4 | 
            -
            import inference
         | 
| 5 | 
            -
            from typing import List
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            app = FastAPI()
         | 
| 8 | 
            -
             | 
| 9 | 
            -
            @app.get("/")
         | 
| 10 | 
            -
            def read_root():
         | 
| 11 | 
            -
                return {"message": "Welcome to the API of flax-sentence-embeddings."}
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            @app.get('/similarity')
         | 
| 14 | 
            -
            def get_similarity(anchor: str, inputs: List[str] = Query([]), model: str = 'distilroberta'):
         | 
| 15 | 
            -
                return {'dataframe': inference.text_similarity(anchor, inputs, model)}
         | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
            #if __name__ == "__main__":
         | 
| 19 | 
            -
            #    uvicorn.run("main:app", host="0.0.0.0", port=8080)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        backend/utils.py
    ADDED
    
    | @@ -0,0 +1,11 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import streamlit as st
         | 
| 2 | 
            +
            from sentence_transformers import SentenceTransformer
         | 
| 3 | 
            +
            from .config import MODELS_ID
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            @st.cache(allow_output_mutation=True)
         | 
| 7 | 
            +
            def load_model(model_name):
         | 
| 8 | 
            +
                assert model_name in MODELS_ID.keys()
         | 
| 9 | 
            +
                # Lazy downloading
         | 
| 10 | 
            +
                model = SentenceTransformer(MODELS_ID[model_name])
         | 
| 11 | 
            +
                return model
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,5 +1,5 @@ | |
| 1 | 
            -
            fastapi
         | 
| 2 | 
             
            sentence_transformers
         | 
| 3 | 
             
            pandas
         | 
| 4 | 
             
            jax
         | 
|  | |
| 5 | 
             
            streamlit
         | 
|  | |
|  | |
| 1 | 
             
            sentence_transformers
         | 
| 2 | 
             
            pandas
         | 
| 3 | 
             
            jax
         | 
| 4 | 
            +
            jaxlib
         | 
| 5 | 
             
            streamlit
         | 
