# Source: Hugging Face Space file uploaded by bhardwaj08sarthak
# ("Upload 5 files", commit 79418f8, verified; 981 bytes).
import spaces
def extract_top_level_json(s: str) -> str:
    """Return the JSON object that starts at the first ``{`` in *s*.

    The text beginning at the first ``{`` is parsed with the standard-library
    JSON decoder, and the exact substring spanning that object is returned
    unchanged (no re-serialization).

    Args:
        s: Arbitrary text (for example model output) that may contain a JSON
            object surrounded by other text.

    Returns:
        The substring of *s* covering the JSON object that begins at the first
        ``{``, or ``""`` if *s* contains no ``{`` or the text starting there is
        not a valid JSON object.
    """
    # Imported locally so the function works even though the module does not
    # import ``json`` at top level.
    import json

    start = s.find("{")
    if start == -1:
        return ""
    tail = s[start:]
    # raw_decode parses exactly one JSON value from the start of ``tail`` and
    # reports where it ended, so braces inside string literals
    # (e.g. '{"a": "}"}') and escape sequences are handled correctly, unlike
    # naive brace counting.
    try:
        _, end = json.JSONDecoder().raw_decode(tail)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError (malformed or truncated JSON);
        # RecursionError guards against pathologically deep nesting.
        return ""
    return tail[:end]
@spaces.GPU(duration=25)
def get_local_model_gpu(model_id: str):
    """Load a causal language model for ``model_id`` and return it in eval mode.

    Runs under ``spaces.GPU(duration=25)``. When ``torch.cuda.is_available()``
    the model is loaded in float16 and moved to ``"cuda"``; otherwise it is
    loaded in float32 and kept on ``"cpu"``.

    Only the model is loaded and returned. Callers that need a tokenizer must
    load it themselves.

    Args:
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model, moved to the selected device and switched to eval
        mode via ``model.eval()``.

    NOTE(review): on ZeroGPU Spaces the GPU is presumably attached only for the
    duration of the decorated call, so a CUDA model returned from here may not
    stay usable afterwards — confirm against how callers use the result.
    """
    # Imported inside the function, presumably to keep module import light.
    import torch
    from transformers import AutoModelForCausalLM

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # float16 on CUDA, float32 otherwise.
    dtype = torch.float16 if use_cuda else torch.float32

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
    model.to(device)
    model.eval()
    return model