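"""Unsloth-based operations for fine-tuning LLMs in LynxKite Graph Analytics."""
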
import enum
from lynxkite_core import ops
from lynxkite_graph_analytics.core import Bundle, TableName, ColumnNameByTableName
import unsloth
import trl
from datasets import load_dataset, Dataset
import unsloth.chat_templates
from transformers.training_args import OptimizerNames
from transformers.trainer_utils import SchedulerType
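# All ops in this file are registered in the "LynxKite Graph Analytics" environment under "Unsloth".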
op = ops.op_registration("LynxKite Graph Analytics", "Unsloth")
@op("Load base model", slow=True, cache=False)
def load_base_model(
*,
model_name: str,
max_seq_length: int = 2048,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
full_finetuning: bool = False,
):
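    """Loads a pretrained model and its tokenizer with Unsloth and puts them in a Bundle."""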
model, tokenizer = unsloth.FastModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
full_finetuning=full_finetuning,
)
return Bundle(other={"model": model, "tokenizer": tokenizer})
@op("Configure LoRA", slow=True, cache=False)
def configure_lora(
    bundle: Bundle, *, r=128, lora_dropout=0, random_state=1, rank_stabilized=False
):
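    """Attaches LoRA adapters to the model so that only the adapter weights are trained."""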
bundle = bundle.copy()
model = bundle.other["model"]
bundle.other["model"] = unsloth.FastModel.get_peft_model(
model,
r=r,
lora_dropout=lora_dropout,
random_state=random_state,
use_rslora=rank_stabilized,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=128,
bias="none",
use_gradient_checkpointing="unsloth",
loftq_config=None,
)
return bundle
@op("Load HF dataset", slow=True, cache=False)
def load_hf_dataset(*, name: str, split="train[:10000]") -> Bundle:
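    """Downloads a Hugging Face dataset and stores it as a DataFrame named "dataset"."""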
return Bundle(dfs={"dataset": load_dataset(name, split=split).to_pandas()})
@op("Convert to ChatML", slow=True, cache=False)
def convert_to_chatml(
bundle: Bundle,
*,
table_name: TableName,
system_column_name: ColumnNameByTableName,
user_column_name: ColumnNameByTableName,
assistant_column_name: ColumnNameByTableName,
save_as: str = "conversations",
):
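    """Combines the system, user, and assistant columns into a ChatML-style message list."""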
bundle = bundle.copy()
ds = bundle.dfs[table_name]
bundle.dfs[table_name][save_as] = ds.apply(
lambda e: [
{"role": "system", "content": e[system_column_name]},
{"role": "user", "content": e[user_column_name]},
{"role": "assistant", "content": e[assistant_column_name]},
],
axis=1,
)
return bundle
@op("Apply chat template", slow=True, cache=False)
def apply_chat_template(
bundle: Bundle,
*,
table_name: TableName,
conversations_field: ColumnNameByTableName,
save_as="text",
):
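    """Turns each conversation into training text using the tokenizer's chat template."""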
bundle = bundle.copy()
tokenizer = bundle.other["tokenizer"]
bundle.dfs[table_name][save_as] = bundle.dfs[table_name][conversations_field].map(
lambda e: tokenizer.apply_chat_template(
e, tokenize=False, add_generation_prompt=False
).removeprefix("<bos>"),
)
return bundle
@op("Train LLM", slow=True, cache=False)
def train_llm(
bundle: Bundle,
*,
table_name: TableName,
dataset_text_field: ColumnNameByTableName,
train_on_responses_only=True,
per_device_train_batch_size=8,
gradient_accumulation_steps=1,
warmup_steps=5,
num_train_epochs: int | None = 1,
max_steps: int | None = None,
learning_rate=5e-5,
logging_steps=1,
optim=OptimizerNames.ADAMW_8BIT,
weight_decay=0.01,
lr_scheduler_type=SchedulerType.LINEAR,
seed=1,
):
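    """Runs supervised fine-tuning on the selected table with TRL's SFTTrainer.

    The trainer's statistics are added to the returned bundle as "trainer_stats".
    """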
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
dataset = Dataset.from_pandas(bundle.dfs[table_name])
trainer = trl.SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
eval_dataset=None,
args=trl.SFTConfig(
dataset_text_field=dataset_text_field,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=warmup_steps,
num_train_epochs=num_train_epochs or -1,
max_steps=max_steps or -1,
learning_rate=learning_rate,
logging_steps=logging_steps,
optim=optim,
weight_decay=weight_decay,
lr_scheduler_type=lr_scheduler_type,
seed=seed,
output_dir="outputs",
report_to="none",
),
)
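    # Compute the loss only on the model's responses, not on the user or system turns.
    # The "<start_of_turn>" markers assume a Gemma-style chat template.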
if train_on_responses_only:
trainer = unsloth.chat_templates.train_on_responses_only(
trainer,
instruction_part="<start_of_turn>user\n",
response_part="<start_of_turn>model\n",
)
trainer_stats = trainer.train()
bundle = bundle.copy()
bundle.other["trainer_stats"] = trainer_stats
return bundle
@op("Save model (LoRA only)", outputs=[], slow=True, cache=False)
def save_model_lora(bundle: Bundle, *, file_name: str):
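    """Saves only the LoRA adapter weights and the tokenizer to the given directory."""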
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained(file_name)
tokenizer.save_pretrained(file_name)
@op("Save model (float16)", outputs=[], slow=True, cache=False)
def save_model_float16(bundle: Bundle, *, file_name: str):
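    """Merges the LoRA adapters into the base model and saves it in 16-bit precision."""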
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained_merged(file_name, tokenizer, save_method="merged_16bit")
@op("Save model (int4)", outputs=[], slow=True, cache=False)
def save_model_int4(bundle: Bundle, *, file_name: str):
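    """Merges the LoRA adapters into the base model and saves it in 4-bit precision."""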
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained_merged(file_name, tokenizer, save_method="merged_4bit")
class QuantizationType(enum.StrEnum):
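    """GGUF quantization formats supported by the "Save model (GGUF)" operation."""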
Q8_0 = "Q8_0"
BF16 = "BF16"
F16 = "F16"
@op("Save model (GGUF)", outputs=[], slow=True, cache=False)
def save_model_gguf(
bundle: Bundle, *, file_name: str, quantization: QuantizationType = QuantizationType.Q8_0
):
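    """Saves the model in GGUF format (used by llama.cpp) with the chosen quantization."""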
model = bundle.other["model"]
tokenizer = bundle.other["tokenizer"]
model.save_pretrained_gguf(
file_name,
tokenizer,
quantization_type=quantization.value,
)
@op("Chat with model", view="service")
def chat_with_model(bundle: Bundle):
# TODO: Implement this.
pass