DanielIglesias97 commited on
Commit
28032f0
·
1 Parent(s): 04178ea

We have included a new model that copes with sentences in Spanish.

Browse files
Files changed (2) hide show
  1. main_service.py +6 -3
  2. utils_model.py +10 -18
main_service.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  from utils_model import ModelFactory
3
 
4
- def retrieve_embeddings(input_text_query):
5
- model_type = 'all-MiniLM-L6-v2'
6
  model_factory_obj = ModelFactory()
7
  model = model_factory_obj.create_model(model_type)
8
 
@@ -11,7 +10,11 @@ def retrieve_embeddings(input_text_query):
11
  return query_embeddings
12
 
13
  def build():
14
- app = gr.Interface(fn=retrieve_embeddings, inputs="text", outputs="dataframe")
 
 
 
 
15
 
16
  return app
17
 
 
1
  import gradio as gr
2
  from utils_model import ModelFactory
3
 
4
+ def retrieve_embeddings(input_text_query, model_type):
 
5
  model_factory_obj = ModelFactory()
6
  model = model_factory_obj.create_model(model_type)
7
 
 
10
  return query_embeddings
11
 
12
  def build():
13
+ models_list = ['all-MiniLM-L6-v2', 'sentence_similarity_spanish']
14
+
15
+ app = gr.Interface(fn=retrieve_embeddings,
16
+ inputs=["text", gr.Dropdown(models_list, label='Model type')],
17
+ outputs="dataframe")
18
 
19
  return app
20
 
utils_model.py CHANGED
@@ -10,12 +10,12 @@ class ModelFactory():
10
  def create_model(self, model_type):
11
  model = None
12
 
13
- if (model_type=='mock'):
14
- model = MockModel()
15
-
16
  if (model_type=='all-MiniLM-L6-v2'):
17
  model = MiniLM_L6_v2_Model()
18
 
 
 
 
19
  return model
20
 
21
  class BaseModel():
@@ -24,26 +24,18 @@ class BaseModel():
24
  pass
25
 
26
  def retrieve_embeddings(self, input_text):
27
- pass
28
-
29
- class MockModel(BaseModel):
30
-
31
- def __init__(self):
32
- pass
33
-
34
- def retrieve_embeddings(self, input_text):
35
- random_embeddings = np.random.randint(256, size=(370))/256
36
 
37
- return pd.DataFrame(random_embeddings)
38
 
39
  class MiniLM_L6_v2_Model(BaseModel):
40
 
41
  def __init__(self):
42
  self.model = SentenceTransformer('all-MiniLM-L6-v2')
43
 
44
- def retrieve_embeddings(self, input_text):
45
- embeddings = self.model.encode(input_text, batch_size=32)
46
- embeddings *= 255
47
- embeddings = embeddings.astype(np.uint8).tolist()
48
 
49
- return embeddings
 
 
10
  def create_model(self, model_type):
11
  model = None
12
 
 
 
 
13
  if (model_type=='all-MiniLM-L6-v2'):
14
  model = MiniLM_L6_v2_Model()
15
 
16
+ if (model_type=='sentence_similarity_spanish'):
17
+ model = SentenceSimilaritySpanishModel()
18
+
19
  return model
20
 
21
  class BaseModel():
 
24
  pass
25
 
26
  def retrieve_embeddings(self, input_text):
27
+ embeddings = self.model.encode(input_text, batch_size=32)
28
+ embeddings *= 255
29
+ embeddings = embeddings.astype(np.uint8).tolist()
 
 
 
 
 
 
30
 
31
+ return embeddings
32
 
33
  class MiniLM_L6_v2_Model(BaseModel):
34
 
35
  def __init__(self):
36
  self.model = SentenceTransformer('all-MiniLM-L6-v2')
37
 
38
+ class SentenceSimilaritySpanishModel(BaseModel):
 
 
 
39
 
40
+ def __init__(self):
41
+ self.model = SentenceTransformer('hiiamsid/sentence_similarity_spanish_es')