DanielIglesias97 commited on
Commit
2dfcad3
·
1 Parent(s): 28032f0

This new implementation allows to obtain the embeddings of

Browse files
Files changed (1) hide show
  1. main_service.py +24 -15
main_service.py CHANGED
@@ -1,25 +1,34 @@
1
  import gradio as gr
2
  from utils_model import ModelFactory
3
 
4
- def retrieve_embeddings(input_text_query, model_type):
5
- model_factory_obj = ModelFactory()
6
- model = model_factory_obj.create_model(model_type)
7
 
8
- query_embeddings = model.retrieve_embeddings(input_text_query)
 
 
9
 
10
- return query_embeddings
11
 
12
- def build():
13
- models_list = ['all-MiniLM-L6-v2', 'sentence_similarity_spanish']
 
 
 
14
 
15
- app = gr.Interface(fn=retrieve_embeddings,
16
- inputs=["text", gr.Dropdown(models_list, label='Model type')],
17
- outputs="dataframe")
18
 
19
- return app
 
20
 
21
- def run(app):
22
- app.launch(server_name='0.0.0.0')
 
23
 
24
- app = build()
25
- run(app)
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from utils_model import ModelFactory
3
 
4
+ class GradioAppManager():
 
 
5
 
6
+ def __init__(self, model_type):
7
+ model_factory_obj = ModelFactory()
8
+ self.model = model_factory_obj.create_model(model_type)
9
 
10
+ def __retrieve_embeddings__(self, input_queries_df):
11
 
12
+ queries_list = input_queries_df.values
13
+ queries_embeddings_list = []
14
+ for current_query_aux in queries_list:
15
+ current_query_embeddings = self.model.retrieve_embeddings(current_query_aux)
16
+ queries_embeddings_list+=current_query_embeddings
17
 
18
+ return queries_embeddings_list
 
 
19
 
20
+ def build(self):
21
+ gr_input_dataframe = gr.Dataframe(headers=['queries'], datatype=['str'], row_count=2, col_count=(1, 'fixed'))
22
 
23
+ app = gr.Interface(fn=self.__retrieve_embeddings__,
24
+ inputs=[gr_input_dataframe],
25
+ outputs="dataframe")
26
 
27
+ return app
28
+
29
+ def run(self, app):
30
+ app.launch(server_name='0.0.0.0')
31
+
32
+ gradio_app_manager_obj = GradioAppManager('sentence_similarity_spanish')
33
+ app = gradio_app_manager_obj.build()
34
+ gradio_app_manager_obj.run(app)