stefan-it commited on
Commit
63907b4
·
verified ·
1 Parent(s): c0f82f2

feat: add new additional inputs, including changing of hopefully better default generation parameters

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -19,8 +19,7 @@ tokenizer, model = load_model()
19
 
20
 
21
  @spaces.GPU
22
- def generate(prompt, history):
23
-
24
  if len(history) > 0:
25
  messages = history + [
26
  {"role": "user", "content": prompt},
@@ -42,7 +41,11 @@ def generate(prompt, history):
42
  with torch.no_grad():
43
  outputs = model.generate(
44
  **inputs,
45
- max_new_tokens=512,
 
 
 
 
46
  )
47
 
48
  generated_tokens = outputs[0, inputs.input_ids.shape[1]:]
@@ -51,5 +54,14 @@ def generate(prompt, history):
51
  return output
52
 
53
 
54
- demo = gr.ChatInterface(fn=generate, type="messages", examples=["Hallo", "Servus", "Hi"], title="German nanochat v1")
 
 
 
 
 
 
 
 
 
55
  demo.launch()
 
19
 
20
 
21
  @spaces.GPU
22
+ def generate(prompt, history, max_new_tokens, temperature, top_p, repetition_penalty, no_repeat_ngram_size):
 
23
  if len(history) > 0:
24
  messages = history + [
25
  {"role": "user", "content": prompt},
 
41
  with torch.no_grad():
42
  outputs = model.generate(
43
  **inputs,
44
+ max_new_tokens=max_new_tokens,
45
+ temperature=temperature,
46
+ top_p=top_p,
47
+ repetition_penalty=repetition_penalty,
48
+ no_repeat_ngram_size=no_repeat_ngram_size,
49
  )
50
 
51
  generated_tokens = outputs[0, inputs.input_ids.shape[1]:]
 
54
  return output
55
 
56
 
57
+ demo = gr.ChatInterface(fn=generate,
58
+ type="messages",
59
+ title="German nanochat v1",
60
+ additional_inputs=[
61
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
62
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.8, step=0.1, label="Temperature"),
63
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
64
+ gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
65
+ gr.Slider(minimum=0, maximum=5, value=3, step=1, label="No repeat of ngrams"),
66
+ ])
67
  demo.launch()