Commit
·
e5f64a0
1
Parent(s):
489b632
feat: update run local code
Browse files
src/distilabel_dataset_generator/pipelines/sft.py
CHANGED
|
@@ -172,6 +172,7 @@ def generate_pipeline_code(system_prompt, num_turns, num_rows):
|
|
| 172 |
input_mappings = _get_output_mappings(num_turns)
|
| 173 |
code = f"""
|
| 174 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
|
|
|
| 175 |
from distilabel.pipeline import Pipeline
|
| 176 |
from distilabel.steps import KeepColumns
|
| 177 |
from distilabel.steps.tasks import MagpieGenerator
|
|
@@ -179,6 +180,7 @@ from distilabel.llms import InferenceEndpointsLLM
|
|
| 179 |
|
| 180 |
MODEL = "{MODEL}"
|
| 181 |
SYSTEM_PROMPT = "{system_prompt}"
|
|
|
|
| 182 |
|
| 183 |
with Pipeline(name="sft") as pipeline:
|
| 184 |
magpie = MagpieGenerator(
|
|
@@ -191,7 +193,8 @@ with Pipeline(name="sft") as pipeline:
|
|
| 191 |
"do_sample": True,
|
| 192 |
"max_new_tokens": 2048,
|
| 193 |
"stop_sequences": {_STOP_SEQUENCES}
|
| 194 |
-
}}
|
|
|
|
| 195 |
),
|
| 196 |
n_turns={num_turns},
|
| 197 |
num_rows={num_rows},
|
|
@@ -200,7 +203,7 @@ with Pipeline(name="sft") as pipeline:
|
|
| 200 |
output_mappings={input_mappings},
|
| 201 |
)
|
| 202 |
keep_columns = KeepColumns(
|
| 203 |
-
columns={list(input_mappings.values())} + ["model_name"
|
| 204 |
)
|
| 205 |
magpie.connect(keep_columns)
|
| 206 |
|
|
|
|
| 172 |
input_mappings = _get_output_mappings(num_turns)
|
| 173 |
code = f"""
|
| 174 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
| 175 |
+
import os
|
| 176 |
from distilabel.pipeline import Pipeline
|
| 177 |
from distilabel.steps import KeepColumns
|
| 178 |
from distilabel.steps.tasks import MagpieGenerator
|
|
|
|
| 180 |
|
| 181 |
MODEL = "{MODEL}"
|
| 182 |
SYSTEM_PROMPT = "{system_prompt}"
|
| 183 |
+
os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?globalPermissions=inference.serverless.write&tokenType=fineGrained
|
| 184 |
|
| 185 |
with Pipeline(name="sft") as pipeline:
|
| 186 |
magpie = MagpieGenerator(
|
|
|
|
| 193 |
"do_sample": True,
|
| 194 |
"max_new_tokens": 2048,
|
| 195 |
"stop_sequences": {_STOP_SEQUENCES}
|
| 196 |
+
}},
|
| 197 |
+
api_key=os.environ["HF_TOKEN"],
|
| 198 |
),
|
| 199 |
n_turns={num_turns},
|
| 200 |
num_rows={num_rows},
|
|
|
|
| 203 |
output_mappings={input_mappings},
|
| 204 |
)
|
| 205 |
keep_columns = KeepColumns(
|
| 206 |
+
columns={list(input_mappings.values())} + ["model_name"],
|
| 207 |
)
|
| 208 |
magpie.connect(keep_columns)
|
| 209 |
|