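"""LynxKite operations for fine-tuning LLMs with Unsloth.

The ops below cover the usual workflow: load a base model, attach LoRA adapters,
load and format a chat dataset, fine-tune with TRL's SFTTrainer, and save or export
the result (LoRA-only, merged float16/int4, or GGUF).
"""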
import enum
from lynxkite_core import ops
from lynxkite_graph_analytics.core import Bundle, TableName, ColumnNameByTableName

# Unsloth is imported before trl so that its performance patches are applied.
import unsloth
import trl
from datasets import load_dataset, Dataset
import unsloth.chat_templates
from transformers.training_args import OptimizerNames
from transformers.trainer_utils import SchedulerType

op = ops.op_registration("LynxKite Graph Analytics", "Unsloth")


@op("Load base model", slow=True, cache=False)
def load_base_model(
    *,
    model_name: str,
    max_seq_length: int = 2048,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    full_finetuning: bool = False,
):
    """Loads a pretrained model and its tokenizer with Unsloth and puts them in a Bundle."""
    model, tokenizer = unsloth.FastModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        full_finetuning=full_finetuning,
    )
    return Bundle(other={"model": model, "tokenizer": tokenizer})
@op("Configure LoRA", slow=True, cache=False)
def configure_lora(bundle: Bundle, *, r=128, lora_dropout=0, random_state=1, rank_stabilized=False):
bundle = bundle.copy()
model = bundle.other["model"]
bundle.other["model"] = unsloth.FastModel.get_peft_model(
model,
r=r,
lora_dropout=lora_dropout,
random_state=random_state,
use_rslora=rank_stabilized,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=128,
bias="none",
use_gradient_checkpointing="unsloth",
loftq_config=None,
)
return bundle
@op("Load HF dataset", slow=True, cache=False)
def load_hf_dataset(*, name: str, split="train[:10000]") -> Bundle:
return Bundle(dfs={"dataset": load_dataset(name, split=split).to_pandas()})
@op("Convert to ChatML", slow=True, cache=False)
def convert_to_chatml(
bundle: Bundle,
*,
table_name: TableName,
system_column_name: ColumnNameByTableName,
user_column_name: ColumnNameByTableName,
assistant_column_name: ColumnNameByTableName,
save_as: str = "conversations",
):
bundle = bundle.copy()
ds = bundle.dfs[table_name]
bundle.dfs[table_name][save_as] = ds.apply(
lambda e: [
{"role": "system", "content": e[system_column_name]},
{"role": "user", "content": e[user_column_name]},
{"role": "assistant", "content": e[assistant_column_name]},
],
axis=1,
)
return bundle
@op("Apply chat template", slow=True, cache=False)
def apply_chat_template(
bundle: Bundle,
*,
table_name: TableName,
conversations_field: ColumnNameByTableName,
save_as="text",
):
bundle = bundle.copy()
tokenizer = bundle.other["tokenizer"]
bundle.dfs[table_name][save_as] = bundle.dfs[table_name][conversations_field].map(
lambda e: tokenizer.apply_chat_template(
e, tokenize=False, add_generation_prompt=False
).removeprefix("<bos>"),
)
return bundle
@op("Train LLM", slow=True, cache=False)
def train_llm(
bundle: Bundle,
*,
table_name: TableName,
dataset_text_field: ColumnNameByTableName,
train_on_responses_only=True,
per_device_train_batch_size=8,
gradient_accumulation_steps=1,
warmup_steps=5,
num_train_epochs: int | None = 1,
max_steps: int | None = None,
learning_rate=5e-5,
logging_steps=1,
optim=OptimizerNames.ADAMW_8BIT,
weight_decay=0.01,
lr_scheduler_type=SchedulerType.LINEAR,
seed=1,
):
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
dataset = Dataset.from_pandas(bundle.dfs[table_name])
trainer = trl.SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
eval_dataset=None,
args=trl.SFTConfig(
dataset_text_field=dataset_text_field,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=warmup_steps,
num_train_epochs=num_train_epochs or -1,
max_steps=max_steps or -1,
learning_rate=learning_rate,
logging_steps=logging_steps,
optim=optim,
weight_decay=weight_decay,
lr_scheduler_type=lr_scheduler_type,
seed=seed,
output_dir="outputs",
report_to="none",
),
)
if train_on_responses_only:
trainer = unsloth.chat_templates.train_on_responses_only(
trainer,
instruction_part="<start_of_turn>user\n",
response_part="<start_of_turn>model\n",
)
trainer_stats = trainer.train()
bundle = bundle.copy()
bundle.other["trainer_stats"] = trainer_stats
return bundle
@op("Save model (LoRA only)", outputs=[], slow=True, cache=False)
def save_model_lora(bundle: Bundle, *, file_name: str):
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained(file_name)
tokenizer.save_pretrained(file_name)
@op("Save model (float16)", outputs=[], slow=True, cache=False)
def save_model_float16(bundle: Bundle, *, file_name: str):
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained_merged(file_name, tokenizer, save_method="merged_16bit")
@op("Save model (int4)", outputs=[], slow=True, cache=False)
def save_model_int4(bundle: Bundle, *, file_name: str):
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained_merged(file_name, tokenizer, save_method="merged_4bit")
class QuantizationType(enum.StrEnum):
Q8_0 = "Q8_0"
BF16 = "BF16"
F16 = "F16"
@op("Save model (GGUF)", outputs=[], slow=True, cache=False)
def save_model_gguf(
bundle: Bundle, *, file_name: str, quantization: QuantizationType = QuantizationType.Q8_0
):
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained_gguf(
file_name,
tokenizer,
quantization_type=quantization.value,
)
@op("Chat with model", view="service")
def chat_with_model(bundle: Bundle):
# TODO: Implement this.
pass