Update src/distilabel_dataset_generator/sft.py
Browse files
src/distilabel_dataset_generator/sft.py
CHANGED
|
@@ -223,13 +223,12 @@ def generate_dataset(
|
|
| 223 |
num_turns=1,
|
| 224 |
num_rows=5,
|
| 225 |
private=True,
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
token: OAuthToken = None,
|
| 229 |
progress=gr.Progress(),
|
| 230 |
):
|
| 231 |
-
if
|
| 232 |
-
if not
|
| 233 |
raise gr.Error("Please provide a dataset name to push the dataset to.")
|
| 234 |
if token is None:
|
| 235 |
raise gr.Error(
|
|
@@ -280,14 +279,13 @@ def generate_dataset(
|
|
| 280 |
|
| 281 |
distiset = result_queue.get()
|
| 282 |
|
| 283 |
-
if
|
| 284 |
progress(0.95, desc="Pushing dataset to Hugging Face Hub.")
|
| 285 |
-
repo_id = f"{orgs_selector}/{dataset_name}"
|
| 286 |
distiset.push_to_hub(
|
| 287 |
repo_id=repo_id,
|
| 288 |
private=private,
|
| 289 |
include_script=False,
|
| 290 |
-
token=token
|
| 291 |
)
|
| 292 |
gr.Info(
|
| 293 |
f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
|
|
@@ -339,7 +337,6 @@ with gr.Blocks(
|
|
| 339 |
)
|
| 340 |
gr.Column(scale=1)
|
| 341 |
|
| 342 |
-
#table = gr.HTML(_format_dataframe_as_html(DEFAULT_DATASET))
|
| 343 |
table = gr.DataFrame(
|
| 344 |
value=DEFAULT_DATASET,
|
| 345 |
interactive=False,
|
|
@@ -347,7 +344,7 @@ with gr.Blocks(
|
|
| 347 |
|
| 348 |
)
|
| 349 |
|
| 350 |
-
btn_generate_system_prompt.click(
|
| 351 |
fn=generate_system_prompt,
|
| 352 |
inputs=[dataset_description],
|
| 353 |
outputs=[system_prompt],
|
|
@@ -365,12 +362,10 @@ with gr.Blocks(
|
|
| 365 |
outputs=[table],
|
| 366 |
show_progress=True,
|
| 367 |
)
|
| 368 |
-
|
| 369 |
# Add a header for the full dataset generation section
|
| 370 |
-
gr.Markdown("## Generate full dataset
|
| 371 |
gr.Markdown("Once you're satisfied with the sample, generate a larger dataset and push it to the hub.")
|
| 372 |
-
|
| 373 |
-
btn_login: gr.LoginButton | None = get_login_button()
|
| 374 |
with gr.Column() as push_to_hub_ui:
|
| 375 |
with gr.Row(variant="panel"):
|
| 376 |
num_turns = gr.Number(
|
|
@@ -386,11 +381,12 @@ with gr.Blocks(
|
|
| 386 |
maximum=5000,
|
| 387 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
| 388 |
)
|
| 389 |
-
|
| 390 |
|
| 391 |
with gr.Row(variant="panel"):
|
| 392 |
-
|
| 393 |
-
|
|
|
|
| 394 |
|
| 395 |
btn_generate_full_dataset = gr.Button(
|
| 396 |
value="⚗️ Generate Full Dataset", variant="primary"
|
|
@@ -403,12 +399,8 @@ with gr.Blocks(
|
|
| 403 |
num_turns,
|
| 404 |
num_rows,
|
| 405 |
private,
|
| 406 |
-
|
| 407 |
-
dataset_name_push_to_hub,
|
| 408 |
],
|
| 409 |
outputs=[table],
|
| 410 |
show_progress=True,
|
| 411 |
)
|
| 412 |
-
|
| 413 |
-
app.load(get_org_dropdown, outputs=[orgs_selector])
|
| 414 |
-
app.load(fn=swap_visibilty, outputs=push_to_hub_ui)
|
|
|
|
| 223 |
num_turns=1,
|
| 224 |
num_rows=5,
|
| 225 |
private=True,
|
| 226 |
+
repo_id=None,
|
| 227 |
+
token=None,
|
|
|
|
| 228 |
progress=gr.Progress(),
|
| 229 |
):
|
| 230 |
+
if repo_id is not None:
|
| 231 |
+
if not repo_id:
|
| 232 |
raise gr.Error("Please provide a dataset name to push the dataset to.")
|
| 233 |
if token is None:
|
| 234 |
raise gr.Error(
|
|
|
|
| 279 |
|
| 280 |
distiset = result_queue.get()
|
| 281 |
|
| 282 |
+
if repo_id is not None:
|
| 283 |
progress(0.95, desc="Pushing dataset to Hugging Face Hub.")
|
|
|
|
| 284 |
distiset.push_to_hub(
|
| 285 |
repo_id=repo_id,
|
| 286 |
private=private,
|
| 287 |
include_script=False,
|
| 288 |
+
token=token,
|
| 289 |
)
|
| 290 |
gr.Info(
|
| 291 |
f'Dataset pushed to Hugging Face Hub: <a href="https://huggingface.co/datasets/{repo_id}">https://huggingface.co/datasets/{repo_id}</a>'
|
|
|
|
| 337 |
)
|
| 338 |
gr.Column(scale=1)
|
| 339 |
|
|
|
|
| 340 |
table = gr.DataFrame(
|
| 341 |
value=DEFAULT_DATASET,
|
| 342 |
interactive=False,
|
|
|
|
| 344 |
|
| 345 |
)
|
| 346 |
|
| 347 |
+
result = btn_generate_system_prompt.click(
|
| 348 |
fn=generate_system_prompt,
|
| 349 |
inputs=[dataset_description],
|
| 350 |
outputs=[system_prompt],
|
|
|
|
| 362 |
outputs=[table],
|
| 363 |
show_progress=True,
|
| 364 |
)
|
| 365 |
+
|
| 366 |
# Add a header for the full dataset generation section
|
| 367 |
+
gr.Markdown("## Generate full dataset")
|
| 368 |
gr.Markdown("Once you're satisfied with the sample, generate a larger dataset and push it to the hub.")
|
|
|
|
|
|
|
| 369 |
with gr.Column() as push_to_hub_ui:
|
| 370 |
with gr.Row(variant="panel"):
|
| 371 |
num_turns = gr.Number(
|
|
|
|
| 381 |
maximum=5000,
|
| 382 |
info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
|
| 383 |
)
|
| 384 |
+
|
| 385 |
|
| 386 |
with gr.Row(variant="panel"):
|
| 387 |
+
hf_token = gr.Textbox(label="HF token")
|
| 388 |
+
repo_id = gr.Textbox(label="HF repo ID", placeholder="owner/dataset_name")
|
| 389 |
+
private = gr.Checkbox(label="Private dataset", value=True, interactive=True)
|
| 390 |
|
| 391 |
btn_generate_full_dataset = gr.Button(
|
| 392 |
value="⚗️ Generate Full Dataset", variant="primary"
|
|
|
|
| 399 |
num_turns,
|
| 400 |
num_rows,
|
| 401 |
private,
|
| 402 |
+
repo_id,
|
|
|
|
| 403 |
],
|
| 404 |
outputs=[table],
|
| 405 |
show_progress=True,
|
| 406 |
)
|
|
|
|
|
|
|
|
|