Spaces:
Runtime error
Runtime error
change device
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ def parse_args():
|
|
| 18 |
return parser.parse_args()
|
| 19 |
|
| 20 |
def predict(message, history, system_prompt, temperature, max_tokens):
|
| 21 |
-
global model, tokenizer
|
| 22 |
instruction = "<|im_start|>system\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|im_end|>\n"
|
| 23 |
for human, assistant in history:
|
| 24 |
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
|
|
@@ -33,8 +33,8 @@ def predict(message, history, system_prompt, temperature, max_tokens):
|
|
| 33 |
if input_ids.shape[1] > MAX_MAX_NEW_TOKENS:
|
| 34 |
input_ids = input_ids[:, -MAX_MAX_NEW_TOKENS:]
|
| 35 |
|
| 36 |
-
input_ids = input_ids.
|
| 37 |
-
attention_mask = attention_mask.
|
| 38 |
generate_kwargs = dict(
|
| 39 |
{"input_ids": input_ids, "attention_mask": attention_mask},
|
| 40 |
streamer=streamer,
|
|
@@ -59,7 +59,8 @@ if __name__ == "__main__":
|
|
| 59 |
args = parse_args()
|
| 60 |
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-instruct-3b")
|
| 61 |
model = AutoModelForCausalLM.from_pretrained("stabilityai/stable-code-instruct-3b")
|
| 62 |
-
|
|
|
|
| 63 |
gr.ChatInterface(
|
| 64 |
predict,
|
| 65 |
title="Stable Code Instruct Chat - Demo",
|
|
|
|
| 18 |
return parser.parse_args()
|
| 19 |
|
| 20 |
def predict(message, history, system_prompt, temperature, max_tokens):
|
| 21 |
+
global model, tokenizer, device
|
| 22 |
instruction = "<|im_start|>system\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|im_end|>\n"
|
| 23 |
for human, assistant in history:
|
| 24 |
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
|
|
|
|
| 33 |
if input_ids.shape[1] > MAX_MAX_NEW_TOKENS:
|
| 34 |
input_ids = input_ids[:, -MAX_MAX_NEW_TOKENS:]
|
| 35 |
|
| 36 |
+
input_ids = input_ids.to(device)
|
| 37 |
+
attention_mask = attention_mask.to(device)
|
| 38 |
generate_kwargs = dict(
|
| 39 |
{"input_ids": input_ids, "attention_mask": attention_mask},
|
| 40 |
streamer=streamer,
|
|
|
|
| 59 |
args = parse_args()
|
| 60 |
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stable-code-instruct-3b")
|
| 61 |
model = AutoModelForCausalLM.from_pretrained("stabilityai/stable-code-instruct-3b")
|
| 62 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 63 |
+
model = model.to(device)
|
| 64 |
gr.ChatInterface(
|
| 65 |
predict,
|
| 66 |
title="Stable Code Instruct Chat - Demo",
|