Spaces:
Runtime error
Runtime error
Wonderplex
commited on
added decorator to generate action (#69)
Browse files* added decorator to generate action
* changed decoration to prepare_model
- sotopia_generate.py +2 -1
sotopia_generate.py
CHANGED
|
@@ -3,7 +3,6 @@ import os
|
|
| 3 |
from typing import TypeVar
|
| 4 |
from functools import cache
|
| 5 |
import logging
|
| 6 |
-
import json
|
| 7 |
|
| 8 |
import torch
|
| 9 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
@@ -38,6 +37,7 @@ OutputType = TypeVar("OutputType", bound=object)
|
|
| 38 |
log = logging.getLogger("generate")
|
| 39 |
logging_handler = LoggingCallbackHandler("langchain")
|
| 40 |
|
|
|
|
| 41 |
def generate_action(
|
| 42 |
model_name: str,
|
| 43 |
history: str,
|
|
@@ -82,6 +82,7 @@ def generate_action(
|
|
| 82 |
# print(e)
|
| 83 |
# return AgentAction(action_type="none", argument="")
|
| 84 |
|
|
|
|
| 85 |
@cache
|
| 86 |
def prepare_model(model_name):
|
| 87 |
compute_type = torch.float16
|
|
|
|
| 3 |
from typing import TypeVar
|
| 4 |
from functools import cache
|
| 5 |
import logging
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
| 37 |
log = logging.getLogger("generate")
|
| 38 |
logging_handler = LoggingCallbackHandler("langchain")
|
| 39 |
|
| 40 |
+
# @spaces.GPU
|
| 41 |
def generate_action(
|
| 42 |
model_name: str,
|
| 43 |
history: str,
|
|
|
|
| 82 |
# print(e)
|
| 83 |
# return AgentAction(action_type="none", argument="")
|
| 84 |
|
| 85 |
+
@spaces.GPU(duration=1200)
|
| 86 |
@cache
|
| 87 |
def prepare_model(model_name):
|
| 88 |
compute_type = torch.float16
|