Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,7 +23,7 @@ provider = TransformersProvider(model2, tokenizer, device)
|
|
| 23 |
strategy = CreativeWritingStrategy(provider,
|
| 24 |
top_p_flat = 0.65,
|
| 25 |
top_k_threshold_flat = 9,
|
| 26 |
-
eos_penalty = 0.
|
| 27 |
creative_sampler = BacktrackSampler(provider, strategy)
|
| 28 |
|
| 29 |
def create_chat_template_messages(history, prompt):
|
|
@@ -44,13 +44,13 @@ def generate_responses(prompt, history):
|
|
| 44 |
|
| 45 |
async def custom_sampler_task():
|
| 46 |
generated_list = []
|
| 47 |
-
generator = creative_sampler.generate(wrapped_prompt,
|
| 48 |
for token in generator:
|
| 49 |
generated_list.append(token)
|
| 50 |
return tokenizer.decode(generated_list, skip_special_tokens=True)
|
| 51 |
|
| 52 |
custom_output = asyncio.run(custom_sampler_task())
|
| 53 |
-
standard_output = model1.generate(inputs,
|
| 54 |
standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
|
| 55 |
|
| 56 |
return standard_response.strip(), custom_output.strip()
|
|
|
|
| 23 |
strategy = CreativeWritingStrategy(provider,
|
| 24 |
top_p_flat = 0.65,
|
| 25 |
top_k_threshold_flat = 9,
|
| 26 |
+
eos_penalty = 0.75)
|
| 27 |
creative_sampler = BacktrackSampler(provider, strategy)
|
| 28 |
|
| 29 |
def create_chat_template_messages(history, prompt):
|
|
|
|
| 44 |
|
| 45 |
async def custom_sampler_task():
|
| 46 |
generated_list = []
|
| 47 |
+
generator = creative_sampler.generate(wrapped_prompt, max_new_tokens=1024, temperature=1)
|
| 48 |
for token in generator:
|
| 49 |
generated_list.append(token)
|
| 50 |
return tokenizer.decode(generated_list, skip_special_tokens=True)
|
| 51 |
|
| 52 |
custom_output = asyncio.run(custom_sampler_task())
|
| 53 |
+
standard_output = model1.generate(inputs, max_new_tokens=1024, temperature=1)
|
| 54 |
standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
|
| 55 |
|
| 56 |
return standard_response.strip(), custom_output.strip()
|