Commit
·
b000e50
1
Parent(s):
0d28c87
feat: update notification flow generation
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
|
@@ -10,7 +10,6 @@ from src.distilabel_dataset_generator.pipelines.sft import (
|
|
| 10 |
DEFAULT_DATASET,
|
| 11 |
DEFAULT_DATASET_DESCRIPTION,
|
| 12 |
DEFAULT_SYSTEM_PROMPT,
|
| 13 |
-
MODEL,
|
| 14 |
PROMPT_CREATION_PROMPT,
|
| 15 |
get_pipeline,
|
| 16 |
get_prompt_generation_step,
|
|
@@ -104,10 +103,6 @@ def generate_dataset(
|
|
| 104 |
else:
|
| 105 |
duration = 1000
|
| 106 |
|
| 107 |
-
gr.Info(
|
| 108 |
-
"Dataset generation started. This might take a while. Don't close the page.",
|
| 109 |
-
duration=duration,
|
| 110 |
-
)
|
| 111 |
result_queue = multiprocessing.Queue()
|
| 112 |
p = multiprocessing.Process(
|
| 113 |
target=_run_pipeline,
|
|
@@ -122,7 +117,7 @@ def generate_dataset(
|
|
| 122 |
break
|
| 123 |
progress(
|
| 124 |
(step + 1) / total_steps,
|
| 125 |
-
desc=f"Generating dataset with {num_rows} rows",
|
| 126 |
)
|
| 127 |
time.sleep(duration / total_steps) # Adjust this value based on your needs
|
| 128 |
p.join()
|
|
@@ -151,52 +146,11 @@ def generate_dataset(
|
|
| 151 |
return pd.DataFrame(outputs)
|
| 152 |
|
| 153 |
|
| 154 |
-
def generate_pipeline_code(
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
from distilabel.steps import KeepColumns
|
| 158 |
-
from distilabel.steps.tasks import MagpieGenerator
|
| 159 |
-
from distilabel.llms import InferenceEndpointsLLM
|
| 160 |
-
|
| 161 |
-
MODEL = "{MODEL}"
|
| 162 |
-
SYSTEM_PROMPT = "{system_prompt}"
|
| 163 |
-
# increase this to generate multi-turn conversations
|
| 164 |
-
NUM_TURNS = 1
|
| 165 |
-
# increase this to generate a larger dataset
|
| 166 |
-
NUM_ROWS = 100
|
| 167 |
-
|
| 168 |
-
with Pipeline(name="sft") as pipeline:
|
| 169 |
-
magpie = MagpieGenerator(
|
| 170 |
-
llm=InferenceEndpointsLLM(
|
| 171 |
-
model_id=MODEL,
|
| 172 |
-
tokenizer_id=MODEL,
|
| 173 |
-
magpie_pre_query_template="llama3",
|
| 174 |
-
generation_kwargs={{
|
| 175 |
-
"temperature": 0.8,
|
| 176 |
-
"do_sample": True,
|
| 177 |
-
"max_new_tokens": 2048,
|
| 178 |
-
"stop_sequences": [
|
| 179 |
-
"<|eot_id|>",
|
| 180 |
-
"<|end_of_text|>",
|
| 181 |
-
"<|start_header_id|>",
|
| 182 |
-
"<|end_header_id|>",
|
| 183 |
-
"assistant",
|
| 184 |
-
],
|
| 185 |
-
}}
|
| 186 |
-
),
|
| 187 |
-
n_turns=NUM_TURNS,
|
| 188 |
-
num_rows=NUM_ROWS,
|
| 189 |
-
system_prompt=SYSTEM_PROMPT,
|
| 190 |
-
)
|
| 191 |
-
|
| 192 |
-
if __name__ == "__main__":
|
| 193 |
-
distiset = pipeline.run()
|
| 194 |
-
"""
|
| 195 |
-
return code
|
| 196 |
|
| 197 |
-
|
| 198 |
-
def update_pipeline_code(system_prompt):
|
| 199 |
-
return generate_pipeline_code(system_prompt)
|
| 200 |
|
| 201 |
|
| 202 |
with gr.Blocks(
|
|
@@ -267,7 +221,7 @@ with gr.Blocks(
|
|
| 267 |
minimum=1,
|
| 268 |
maximum=4,
|
| 269 |
step=1,
|
| 270 |
-
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a '
|
| 271 |
)
|
| 272 |
num_rows = gr.Number(
|
| 273 |
value=100,
|
|
@@ -297,6 +251,7 @@ with gr.Blocks(
|
|
| 297 |
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
| 298 |
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
| 299 |
<p style="margin-top: 0.5em;">
|
|
|
|
| 300 |
Your dataset is now available at:
|
| 301 |
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
| 302 |
https://huggingface.co/datasets/{org_name}/{repo_name}
|
|
@@ -307,7 +262,13 @@ with gr.Blocks(
|
|
| 307 |
visible=True,
|
| 308 |
)
|
| 309 |
|
|
|
|
|
|
|
|
|
|
| 310 |
btn_generate_full_dataset.click(
|
|
|
|
|
|
|
|
|
|
| 311 |
fn=generate_dataset,
|
| 312 |
inputs=[
|
| 313 |
system_prompt,
|
|
@@ -329,13 +290,11 @@ with gr.Blocks(
|
|
| 329 |
gr.Markdown("## Or run this pipeline locally with distilabel")
|
| 330 |
|
| 331 |
with gr.Accordion("Run this pipeline on Distilabel", open=False):
|
| 332 |
-
pipeline_code = gr.Code(
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
outputs=[pipeline_code],
|
| 338 |
-
)
|
| 339 |
|
| 340 |
app.load(get_token, outputs=[hf_token])
|
| 341 |
app.load(get_org_dropdown, outputs=[org_name])
|
|
|
|
| 10 |
DEFAULT_DATASET,
|
| 11 |
DEFAULT_DATASET_DESCRIPTION,
|
| 12 |
DEFAULT_SYSTEM_PROMPT,
|
|
|
|
| 13 |
PROMPT_CREATION_PROMPT,
|
| 14 |
get_pipeline,
|
| 15 |
get_prompt_generation_step,
|
|
|
|
| 103 |
else:
|
| 104 |
duration = 1000
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
result_queue = multiprocessing.Queue()
|
| 107 |
p = multiprocessing.Process(
|
| 108 |
target=_run_pipeline,
|
|
|
|
| 117 |
break
|
| 118 |
progress(
|
| 119 |
(step + 1) / total_steps,
|
| 120 |
+
desc=f"Generating dataset with {num_rows} rows. Don't close this window.",
|
| 121 |
)
|
| 122 |
time.sleep(duration / total_steps) # Adjust this value based on your needs
|
| 123 |
p.join()
|
|
|
|
| 146 |
return pd.DataFrame(outputs)
|
| 147 |
|
| 148 |
|
| 149 |
+
def generate_pipeline_code() -> str:
    """Return the source text of the SFT pipeline module.

    The file is located through the relative path
    ``src/distilabel_dataset_generator/pipelines/sft.py``, so it is resolved
    against the current working directory of the running process.

    Returns:
        The complete contents of the pipeline file as a single string.

    Raises:
        OSError: If the file cannot be opened, e.g. ``FileNotFoundError``
            when the process was not started from the directory that
            contains ``src/``.
    """
    # Decode explicitly as UTF-8 (the default encoding of Python source files)
    # so the result does not depend on the platform's locale encoding.
    with open(
        "src/distilabel_dataset_generator/pipelines/sft.py",
        "r",
        encoding="utf-8",
    ) as f:
        pipeline_code = f.read()

    return pipeline_code
|
|
|
|
|
|
|
| 154 |
|
| 155 |
|
| 156 |
with gr.Blocks(
|
|
|
|
| 221 |
minimum=1,
|
| 222 |
maximum=4,
|
| 223 |
step=1,
|
| 224 |
+
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
| 225 |
)
|
| 226 |
num_rows = gr.Number(
|
| 227 |
value=100,
|
|
|
|
| 251 |
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
| 252 |
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
| 253 |
<p style="margin-top: 0.5em;">
|
| 254 |
+
The generated dataset is in the right format for Fine-tuning with TRL, AutoTrain or other frameworks.
|
| 255 |
Your dataset is now available at:
|
| 256 |
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
| 257 |
https://huggingface.co/datasets/{org_name}/{repo_name}
|
|
|
|
| 262 |
visible=True,
|
| 263 |
)
|
| 264 |
|
| 265 |
+
def hide_success_message():
    """Build a hidden Markdown component.

    Wired as the first step of the generate-button click chain with
    ``success_message`` as its output, so the returned component replaces
    the banner with an invisible one.

    Returns:
        A ``gr.Markdown`` constructed with ``visible=False``.
    """
    hidden_banner = gr.Markdown(visible=False)
    return hidden_banner
|
| 267 |
+
|
| 268 |
btn_generate_full_dataset.click(
|
| 269 |
+
fn=hide_success_message,
|
| 270 |
+
outputs=[success_message],
|
| 271 |
+
).then(
|
| 272 |
fn=generate_dataset,
|
| 273 |
inputs=[
|
| 274 |
system_prompt,
|
|
|
|
| 290 |
gr.Markdown("## Or run this pipeline locally with distilabel")
|
| 291 |
|
| 292 |
with gr.Accordion("Run this pipeline on Distilabel", open=False):
|
| 293 |
+
pipeline_code = gr.Code(
|
| 294 |
+
value=generate_pipeline_code(),
|
| 295 |
+
language="python",
|
| 296 |
+
label="Distilabel Pipeline Code",
|
| 297 |
+
)
|
|
|
|
|
|
|
| 298 |
|
| 299 |
app.load(get_token, outputs=[hf_token])
|
| 300 |
app.load(get_org_dropdown, outputs=[org_name])
|
src/distilabel_dataset_generator/utils.py
CHANGED
|
@@ -39,7 +39,7 @@ def get_login_button():
|
|
| 39 |
or get_space() is None
|
| 40 |
):
|
| 41 |
return gr.LoginButton(
|
| 42 |
-
value="Sign in with Hugging Face
|
| 43 |
size="lg",
|
| 44 |
)
|
| 45 |
|
|
|
|
| 39 |
or get_space() is None
|
| 40 |
):
|
| 41 |
return gr.LoginButton(
|
| 42 |
+
value="Sign in with Hugging Face! (This resets the session)",
|
| 43 |
size="lg",
|
| 44 |
)
|
| 45 |
|