"""LynxKite operations for fine-tuning LLMs with Unsloth."""

import enum

from lynxkite_core import ops
from lynxkite_graph_analytics.core import Bundle, TableName, ColumnNameByTableName

# unsloth should be imported before trl/transformers so its patches take effect.
import unsloth
import unsloth.chat_templates
import trl
from datasets import load_dataset, Dataset
from transformers.training_args import OptimizerNames
from transformers.trainer_utils import SchedulerType

op = ops.op_registration("LynxKite Graph Analytics", "Unsloth")


@op("Load base model", slow=True, cache=False)
def load_base_model(
    *,
    model_name: str,
    max_seq_length: int = 2048,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    full_finetuning: bool = False,
):
    """Loads the base model and tokenizer into a bundle."""
    model, tokenizer = unsloth.FastModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        full_finetuning=full_finetuning,
    )
    return Bundle(other={"model": model, "tokenizer": tokenizer})


@op("Configure LoRA", slow=True, cache=False)
def configure_lora(bundle: Bundle, *, r=128, lora_dropout=0, random_state=1, rank_stabilized=False):
    """Attaches LoRA adapters to the model in the bundle."""
    bundle = bundle.copy()
    model = bundle.other["model"]
    bundle.other["model"] = unsloth.FastModel.get_peft_model(
        model,
        r=r,
        lora_dropout=lora_dropout,
        random_state=random_state,
        use_rslora=rank_stabilized,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=128,
        bias="none",
        use_gradient_checkpointing="unsloth",
        loftq_config=None,
    )
    return bundle


@op("Load HF dataset", slow=True, cache=False)
def load_hf_dataset(*, name: str, split="train[:10000]") -> Bundle:
    """Loads a Hugging Face dataset into the bundle as a DataFrame."""
    return Bundle(dfs={"dataset": load_dataset(name, split=split).to_pandas()})


@op("Convert to ChatML", slow=True, cache=False)
def convert_to_chatml(
    bundle: Bundle,
    *,
    table_name: TableName,
    system_column_name: ColumnNameByTableName,
    user_column_name: ColumnNameByTableName,
    assistant_column_name: ColumnNameByTableName,
    save_as: str = "conversations",
):
    """Turns system/user/assistant columns into ChatML-style conversations."""
    bundle = bundle.copy()
    ds = bundle.dfs[table_name]
    bundle.dfs[table_name][save_as] = ds.apply(
        lambda e: [
            {"role": "system", "content": e[system_column_name]},
            {"role": "user", "content": e[user_column_name]},
            {"role": "assistant", "content": e[assistant_column_name]},
        ],
        axis=1,
    )
    return bundle


@op("Apply chat template", slow=True, cache=False)
def apply_chat_template(
    bundle: Bundle,
    *,
    table_name: TableName,
    conversations_field: ColumnNameByTableName,
    save_as="text",
):
    """Renders the conversations with the tokenizer's chat template."""
    bundle = bundle.copy()
    tokenizer = bundle.other["tokenizer"]
    bundle.dfs[table_name][save_as] = bundle.dfs[table_name][conversations_field].map(
        # Gemma-style templates prepend <bos>; strip it so it is not added again
        # when the trainer tokenizes the text.
        lambda e: tokenizer.apply_chat_template(
            e, tokenize=False, add_generation_prompt=False
        ).removeprefix("<bos>"),
    )
    return bundle


@op("Train LLM", slow=True, cache=False)
def train_llm(
    bundle: Bundle,
    *,
    table_name: TableName,
    dataset_text_field: ColumnNameByTableName,
    train_on_responses_only=True,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    warmup_steps=5,
    num_train_epochs: int | None = 1,
    max_steps: int | None = None,
    learning_rate=5e-5,
    logging_steps=1,
    optim=OptimizerNames.ADAMW_8BIT,
    weight_decay=0.01,
    lr_scheduler_type=SchedulerType.LINEAR,
    seed=1,
):
    """Runs supervised fine-tuning on the templated text column."""
    model = bundle.other["model"]
    tokenizer = bundle.other["tokenizer"]
    dataset = Dataset.from_pandas(bundle.dfs[table_name])
    trainer = trl.SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        eval_dataset=None,
        args=trl.SFTConfig(
            dataset_text_field=dataset_text_field,
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            num_train_epochs=num_train_epochs or -1,
            max_steps=max_steps or -1,
            learning_rate=learning_rate,
            logging_steps=logging_steps,
            optim=optim,
            weight_decay=weight_decay,
            lr_scheduler_type=lr_scheduler_type,
            seed=seed,
            output_dir="outputs",
            report_to="none",
        ),
    )
    if train_on_responses_only:
        # Mask the prompt tokens so the loss is computed only on the model's responses.
        # (Gemma-style chat template turn markers.)
        trainer = unsloth.chat_templates.train_on_responses_only(
            trainer,
            instruction_part="<start_of_turn>user\n",
            response_part="<start_of_turn>model\n",
        )
    trainer_stats = trainer.train()
    bundle = bundle.copy()
    bundle.other["trainer_stats"] = trainer_stats
    return bundle


@op("Save model (LoRA only)", outputs=[], slow=True, cache=False)
def save_model_lora(bundle: Bundle, *, file_name: str):
    """Saves only the LoRA adapter weights and the tokenizer."""
    model = bundle.other["model"]
    tokenizer = bundle.other["tokenizer"]
    model.save_pretrained(file_name)
    tokenizer.save_pretrained(file_name)


@op("Save model (float16)", outputs=[], slow=True, cache=False)
def save_model_float16(bundle: Bundle, *, file_name: str):
    """Saves the merged model in 16-bit precision."""
    model = bundle.other["model"]
    tokenizer = bundle.other["tokenizer"]
    model.save_pretrained_merged(file_name, tokenizer, save_method="merged_16bit")


@op("Save model (int4)", outputs=[], slow=True, cache=False)
def save_model_int4(bundle: Bundle, *, file_name: str):
    """Saves the merged model in 4-bit precision."""
    model = bundle.other["model"]
    tokenizer = bundle.other["tokenizer"]
    model.save_pretrained_merged(file_name, tokenizer, save_method="merged_4bit")


class QuantizationType(enum.StrEnum):
    Q8_0 = "Q8_0"
    BF16 = "BF16"
    F16 = "F16"


@op("Save model (GGUF)", outputs=[], slow=True, cache=False)
def save_model_gguf(
    bundle: Bundle, *, file_name: str, quantization: QuantizationType = QuantizationType.Q8_0
):
    """Saves the merged model in GGUF format with the chosen quantization."""
    model = bundle.other["model"]
    tokenizer = bundle.other["tokenizer"]
    model.save_pretrained_gguf(
        file_name,
        tokenizer,
        quantization_type=quantization.value,
    )


@op("Chat with model", view="service")
def chat_with_model(bundle: Bundle):
    # TODO: Implement this.
    pass