un-index commited on
Commit
d69b53f
·
1 Parent(s): eabeaac
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -92,6 +92,7 @@ def get_generated_text(generated_text):
92
  # return generated_text
93
 
94
 
 
95
  def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
96
  try:
97
 
@@ -197,8 +198,10 @@ def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY
197
  elif model_idx == 2:
198
  url = "https://api-inference.huggingface.co/models/distilgpt2"
199
  generated_text = ""#context #""
 
 
200
  while len(generated_text) < max_length:
201
- payload = {"inputs": context, "parameters": {"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
202
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
203
  context = json.loads(response.content.decode("utf-8"))
204
  context = get_generated_text(context).strip()
 
92
  # return generated_text
93
 
94
 
95
+
96
  def f(context, temperature, top_p, max_length, model_idx, SPACE_VERIFICATION_KEY):
97
  try:
98
 
 
198
  elif model_idx == 2:
199
  url = "https://api-inference.huggingface.co/models/distilgpt2"
200
  generated_text = ""#context #""
201
+ # NOTE adding repetition penalty parameter
202
+ # NOTE maybe leave tha parameter and just write a function to remove repetitions
203
  while len(generated_text) < max_length:
204
+ payload = {"inputs": context, "parameters": {"repetition_penalty":20.0,"return_full_text":False, "max_new_tokens": 250, "temperature": temperature, "top_p": top_p}}
205
  response = requests.request("POST", url, data=json.dumps(payload), headers=headers)
206
  context = json.loads(response.content.decode("utf-8"))
207
  context = get_generated_text(context).strip()