Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,23 +34,23 @@ DSPY_PREFIX_URL = "luna-code/dspy-codegen-350M-mono-prefix"
|
|
| 34 |
CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"
|
| 35 |
|
| 36 |
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
|
| 37 |
-
basemodel = AutoModelForCausalLM.from_pretrained(CHECKPOINT_URL)
|
| 38 |
|
| 39 |
-
sql_prefix = PeftModel.from_pretrained(basemodel, SQLMODEL_PREFIX_URL)
|
| 40 |
-
sfepy_prefix = PeftModel.from_pretrained(basemodel, SFEPY_PREFIX_URL)
|
| 41 |
-
megengine_prefix = PeftModel.from_pretrained(basemodel, MEGENGINE_PREFIX_URL)
|
| 42 |
-
main_evo_prefix = PeftModel.from_pretrained(basemodel, MAIN_EVO_PREFIX_URL)
|
| 43 |
|
| 44 |
-
sqlmodel_fft = AutoModelForCausalLM.from_pretrained(SQLMODEL_FFT_URL)
|
| 45 |
-
sfepy_fft = AutoModelForCausalLM.from_pretrained(SFEPY_FFT_URL)
|
| 46 |
-
megengine_fft = AutoModelForCausalLM.from_pretrained(MEGENGINE_FFT_URL)
|
| 47 |
-
main_evo_fft = AutoModelForCausalLM.from_pretrained(MAIN_EVO_FFT_URL)
|
| 48 |
-
main_fd_fft = AutoModelForCausalLM.from_pretrained(MAIN_FD_FFT_URL)
|
| 49 |
|
| 50 |
-
langchain_prefix = PeftModel.from_pretrained(basemodel, LANGCHAIN_PREFIX_URL)
|
| 51 |
-
llamaindex_prefix = PeftModel.from_pretrained(basemodel, LLAMAINDEX_PREFIX_URL)
|
| 52 |
-
dspy_prefix = PeftModel.from_pretrained(basemodel, DSPY_PREFIX_URL)
|
| 53 |
-
cs_evo_prefix = PeftModel.from_pretrained(basemodel, CS_EVO_PREFIX_URL)
|
| 54 |
|
| 55 |
# basemodel = ""
|
| 56 |
# sql_prefix = ""
|
|
@@ -147,8 +147,7 @@ theme = gr.themes.Monochrome(
|
|
| 147 |
)
|
| 148 |
|
| 149 |
def stream(model, code, generate_kwargs):
|
| 150 |
-
|
| 151 |
-
input_ids = tokenizer(code, return_tensors="pt").to("cuda")
|
| 152 |
generated_ids = model.generate(**input_ids, **generate_kwargs)
|
| 153 |
return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
|
| 154 |
|
|
@@ -183,6 +182,8 @@ def generate(
|
|
| 183 |
output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
|
| 184 |
elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
|
| 185 |
output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
|
|
|
|
|
|
|
| 186 |
else:
|
| 187 |
output = ""
|
| 188 |
|
|
@@ -241,7 +242,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
|
|
| 241 |
gr.Markdown(description)
|
| 242 |
with gr.Row():
|
| 243 |
library = gr.Dropdown(
|
| 244 |
-
["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "
|
| 245 |
value="LangChain",
|
| 246 |
label="Library",
|
| 247 |
info="Choose a library from the list",
|
|
|
|
| 34 |
CS_EVO_PREFIX_URL = "luna-code/cs-codegen-350M-mono-evo-prefix"
|
| 35 |
|
| 36 |
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_URL)
|
| 37 |
+
basemodel = AutoModelForCausalLM.from_pretrained(CHECKPOINT_URL, device_map="auto")
|
| 38 |
|
| 39 |
+
sql_prefix = PeftModel.from_pretrained(basemodel, SQLMODEL_PREFIX_URL, device_map="auto")
|
| 40 |
+
sfepy_prefix = PeftModel.from_pretrained(basemodel, SFEPY_PREFIX_URL, device_map="auto")
|
| 41 |
+
megengine_prefix = PeftModel.from_pretrained(basemodel, MEGENGINE_PREFIX_URL, device_map="auto")
|
| 42 |
+
main_evo_prefix = PeftModel.from_pretrained(basemodel, MAIN_EVO_PREFIX_URL, device_map="auto")
|
| 43 |
|
| 44 |
+
sqlmodel_fft = AutoModelForCausalLM.from_pretrained(SQLMODEL_FFT_URL, device_map="auto")
|
| 45 |
+
sfepy_fft = AutoModelForCausalLM.from_pretrained(SFEPY_FFT_URL, device_map="auto")
|
| 46 |
+
megengine_fft = AutoModelForCausalLM.from_pretrained(MEGENGINE_FFT_URL, device_map="auto")
|
| 47 |
+
main_evo_fft = AutoModelForCausalLM.from_pretrained(MAIN_EVO_FFT_URL, device_map="auto")
|
| 48 |
+
main_fd_fft = AutoModelForCausalLM.from_pretrained(MAIN_FD_FFT_URL, device_map="auto")
|
| 49 |
|
| 50 |
+
langchain_prefix = PeftModel.from_pretrained(basemodel, LANGCHAIN_PREFIX_URL, device_map="auto")
|
| 51 |
+
llamaindex_prefix = PeftModel.from_pretrained(basemodel, LLAMAINDEX_PREFIX_URL, device_map="auto")
|
| 52 |
+
dspy_prefix = PeftModel.from_pretrained(basemodel, DSPY_PREFIX_URL, device_map="auto")
|
| 53 |
+
cs_evo_prefix = PeftModel.from_pretrained(basemodel, CS_EVO_PREFIX_URL, device_map="auto")
|
| 54 |
|
| 55 |
# basemodel = ""
|
| 56 |
# sql_prefix = ""
|
|
|
|
| 147 |
)
|
| 148 |
|
| 149 |
def stream(model, code, generate_kwargs):
|
| 150 |
+
input_ids = tokenizer(code, return_tensors="pt").to(device)
|
|
|
|
| 151 |
generated_ids = model.generate(**input_ids, **generate_kwargs)
|
| 152 |
return tokenizer.decode(generated_ids[0][input_ids["input_ids"].shape[1]:], skip_special_tokens=True).strip()
|
| 153 |
|
|
|
|
| 182 |
output = stream(model_map["Main Evo FFT"], prompt, generate_kwargs)
|
| 183 |
elif method == "Full Data FFT" and library in ["SQLModel", "SfePy", "MegEngine"]:
|
| 184 |
output = stream(model_map["Main FD FFT"], prompt, generate_kwargs)
|
| 185 |
+
elif method == "Evo Prefix" and library in ["LangChain", "LlamaIndex", "DSPy"]:
|
| 186 |
+
output = stream(model_map["CS Evo Prefix"], prompt, generate_kwargs)
|
| 187 |
else:
|
| 188 |
output = ""
|
| 189 |
|
|
|
|
| 242 |
gr.Markdown(description)
|
| 243 |
with gr.Row():
|
| 244 |
library = gr.Dropdown(
|
| 245 |
+
["SQLModel", "SfePy", "MegEngine", "LangChain", "LlamaIndex", "DSPy"],
|
| 246 |
value="LangChain",
|
| 247 |
label="Library",
|
| 248 |
info="Choose a library from the list",
|