DanielIglesias97 commited on
Commit
9019d6b
·
1 Parent(s): 2dfcad3

We have modified the implementation to return the dataframe with

Browse files

the features, associating the set of embeddings with the text it
refers to.

Files changed (2) hide show
  1. main_service.py +1 -5
  2. utils_model.py +15 -2
main_service.py CHANGED
@@ -8,12 +8,8 @@ class GradioAppManager():
8
  self.model = model_factory_obj.create_model(model_type)
9
 
10
  def __retrieve_embeddings__(self, input_queries_df):
11
-
12
  queries_list = input_queries_df.values
13
- queries_embeddings_list = []
14
- for current_query_aux in queries_list:
15
- current_query_embeddings = self.model.retrieve_embeddings(current_query_aux)
16
- queries_embeddings_list+=current_query_embeddings
17
 
18
  return queries_embeddings_list
19
 
 
8
  self.model = model_factory_obj.create_model(model_type)
9
 
10
  def __retrieve_embeddings__(self, input_queries_df):
 
11
  queries_list = input_queries_df.values
12
+ queries_embeddings_list = self.model.retrieve_embeddings_from_texts_list(queries_list)
 
 
 
13
 
14
  return queries_embeddings_list
15
 
utils_model.py CHANGED
@@ -23,12 +23,25 @@ class BaseModel():
23
  def __init__(self):
24
  pass
25
 
26
- def retrieve_embeddings(self, input_text):
27
  embeddings = self.model.encode(input_text, batch_size=32)
28
  embeddings *= 255
29
- embeddings = embeddings.astype(np.uint8).tolist()
30
 
31
  return embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  class MiniLM_L6_v2_Model(BaseModel):
34
 
 
23
  def __init__(self):
24
  pass
25
 
26
+ def retrieve_embeddings_from_single_input_text(self, input_text):
27
  embeddings = self.model.encode(input_text, batch_size=32)
28
  embeddings *= 255
29
+ embeddings = embeddings.astype(np.uint8).astype(str).tolist()
30
 
31
  return embeddings
32
+
33
+ def retrieve_embeddings_from_texts_list(self, input_texts_list):
34
+ all_embeddings_list = []
35
+ for current_input_text_aux in input_texts_list:
36
+ embeddings = self.retrieve_embeddings_from_single_input_text(current_input_text_aux)
37
+ nof_features = len(embeddings[0])
38
+ all_embeddings_list += [current_input_text_aux.tolist() + embeddings[0]]
39
+
40
+ queries_embeddings_df = pd.DataFrame(all_embeddings_list)
41
+ columns_list = ['text'] + [f'feature_{idx}' for idx in range(0, nof_features)]
42
+ queries_embeddings_df.columns = columns_list
43
+
44
+ return queries_embeddings_df
45
 
46
  class MiniLM_L6_v2_Model(BaseModel):
47