Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import LoraConfig, get_peft_model | |
| from trl import SFTTrainer, SFTConfig | |
| from datasets import load_dataset | |
| import torch | |
| import tarfile | |
| from huggingface_hub import HfApi | |
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Substrings (compared case-insensitively) that mark an environment variable
# name as likely to carry a credential.
_SENSITIVE_ENV_MARKERS = ("TOKEN", "SECRET", "PASSWORD", "PASSWD", "KEY", "CREDENTIAL")


def mask_sensitive_env(environ):
    """Return a copy of *environ* with credential-like values replaced by "****".

    A value is masked when its variable name is exactly ``granite`` (the
    variable this script reads the Hub token from) or contains any of
    ``_SENSITIVE_ENV_MARKERS``, ignoring case. The input mapping is not
    modified.
    """
    masked = {}
    for name, value in environ.items():
        upper_name = name.upper()
        is_sensitive = name == "granite" or any(
            marker in upper_name for marker in _SENSITIVE_ENV_MARKERS
        )
        masked[name] = "****" if is_sensitive else value
    return masked


# Debug environment variables (credentials are masked before logging).
logger.info("Environment variables: %s", mask_sensitive_env(os.environ))
# Hub identifiers of the base model and of the training dataset.
dataset_path = "mycholpath/ascii-json"
model_path = "ibm-granite/granite-3.3-8b-instruct"

# Local directory the trainer writes to, and the gzip archive built from it.
output_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
output_dir = "/app/granite-8b-finetuned-ascii"

# Hub repositories that receive the results: a model repo for the saved
# output directory and a dataset-type repo for the tarball.
artifact_repo = "mycholpath/granite-finetuned-artifacts"
model_repo = "mycholpath/granite-8b-finetuned-ascii"
# Get HF token from granite environment variable.
# The variable is expected to hold the literal text "HF_TOKEN=<token>".
_HF_TOKEN_PREFIX = "HF_TOKEN="
granite_var = os.getenv("granite")

# Tolerate stray whitespace/newlines around the value, which would otherwise
# end up inside the token.
_granite_clean = (granite_var or "").strip()

# Slice off only the leading prefix: str.replace() would also delete any later
# occurrence of "HF_TOKEN=" inside the value. An empty remainder counts as
# invalid so that a bare "HF_TOKEN=" is rejected here instead of failing later
# during authentication.
if _granite_clean.startswith(_HF_TOKEN_PREFIX):
    hf_token = _granite_clean[len(_HF_TOKEN_PREFIX):].strip()
else:
    hf_token = ""

if not hf_token:
    logger.error("granite environment variable is not set or invalid. Expected format: HF_TOKEN=<token>.")
    raise ValueError("granite environment variable is not set or invalid. Please set it in HF Space settings.")
logger.info("HF_TOKEN extracted from granite (value hidden for security)")
# Load the tokenizer for the base model, authenticating with the Hub token.
# Progress goes through the module-level ``logger`` like the rest of the
# script, rather than the root ``logging`` functions.
logger.info("Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, token=hf_token, cache_dir="/tmp/hf_cache", trust_remote_code=True
    )
    # Reuse the end-of-sequence token for padding and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'
except Exception as e:
    # Log a script-specific message, then let the original error propagate.
    logger.error("Failed to load tokenizer: %s", e)
    raise
# Load the base causal-LM weights in float16 and let ``device_map="auto"``
# decide device placement. Progress goes through the module-level ``logger``
# like the rest of the script.
# NOTE(review): the weights are float16 and SFTConfig below also sets
# fp16=True; depending on the installed PEFT/Transformers versions this
# combination can fail with "Attempting to unscale FP16 gradients" - confirm.
logger.info("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        token=hf_token,
        torch_dtype=torch.float16,
        device_map="auto",
        cache_dir="/tmp/hf_cache",
        trust_remote_code=True,
    )
except Exception as e:
    # Log a script-specific message, then let the original error propagate.
    logger.error("Failed to load model: %s", e)
    raise
# LoRA adapter settings: rank-16 adapters (alpha 32, dropout 0.05) attached to
# the "q_proj" and "v_proj" modules, with no bias parameters trained.
_lora_kwargs = {
    "task_type": "CAUSAL_LM",
    "target_modules": ["q_proj", "v_proj"],
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "bias": "none",
}
lora_config = LoraConfig(**_lora_kwargs)

# Wrap the base model so that training goes through the PEFT/LoRA wrapper.
model = get_peft_model(model, lora_config)
# Load the "train" split of the dataset, authenticating with the Hub token.
# Progress goes through the module-level ``logger`` like the rest of the
# script, rather than the root ``logging`` functions.
logger.info("Preparing to load private dataset...")
logger.info("Using HF_TOKEN from granite for private dataset authentication")
try:
    dataset = load_dataset(dataset_path, split="train", token=hf_token)
    logger.info("Dataset loaded successfully: %d examples", len(dataset))
except Exception as e:
    # Log a script-specific message, then let the original error propagate.
    logger.error("Failed to load dataset: %s", e)
    raise
def formatting_prompts_func(example):
    """Build the training text(s) ``"<prompt>\\n<completion>"`` for SFTTrainer.

    Args:
        example: Either a single record (``example["prompt"]`` and
            ``example["completion"]`` are scalars) or a batch in
            dict-of-lists form (both values are equally long sequences).
            NOTE(review): TRL releases that map ``formatting_func`` with
            ``batched=True`` pass the batch form - confirm against the
            installed TRL version.

    Returns:
        A list of formatted strings: one element for a single record, one
        element per row for a batch.
    """
    prompts = example["prompt"]
    completions = example["completion"]

    # Batch form: pair the columns row by row. Formatting the lists directly
    # would collapse the whole batch into one string made of list reprs.
    if isinstance(prompts, (list, tuple)):
        return [
            f"{prompt}\n{completion}"
            for prompt, completion in zip(prompts, completions)
        ]

    return [f"{prompts}\n{completions}"]
# Training arguments, expressed through SFTConfig and grouped by concern.
sft_config = SFTConfig(
    # Where checkpoints are written and how often to save / log.
    output_dir=output_dir,
    save_steps=50,
    logging_steps=10,
    # Schedule: 5 epochs, batches of 4 per device accumulated over 4 steps.
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    # Optimizer, learning-rate schedule and gradient clipping.
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    max_grad_norm=0.3,
    # Mixed-precision (fp16) training.
    fp16=True,
    # Evaluation is disabled; the eval batch size is still set explicitly.
    eval_strategy="no",
    per_device_eval_batch_size=4,
    # Text handling: no text column is named (a formatting function is given
    # to the trainer instead), no packing, sequences capped at 768.
    dataset_text_field=None,
    max_seq_length=768,
    packing=False,
)
# Build the SFT trainer from the LoRA-wrapped model, tokenizer, dataset and
# formatting function, then run training. Progress goes through the
# module-level ``logger`` like the rest of the script.
logger.info("Starting training...")
try:
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        eval_dataset=None,
        formatting_func=formatting_prompts_func,
        args=sft_config,
    )
except Exception as e:
    # Log a script-specific message, then let the original error propagate.
    logger.error("Failed to initialize SFTTrainer: %s", e)
    raise

# Training itself is not wrapped: any failure propagates with its own traceback.
trainer.train()
# Save the trained model and the tokenizer into the output directory.
# Progress goes through the module-level ``logger`` like the rest of the
# script, rather than the root ``logging`` functions.
logger.info("Saving fine-tuned model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

# Create tarball for local retrieval
try:
    with tarfile.open(output_tarball, "w:gz") as tar:
        # Store entries under the directory's base name rather than its
        # absolute path.
        tar.add(output_dir, arcname=os.path.basename(output_dir))
    logger.info("Model tarball created: %s", output_tarball)
except Exception as e:
    # Log a script-specific message, then let the original error propagate.
    logger.error("Failed to create model tarball: %s", e)
    raise
# Upload model to HF Hub
# This step is best-effort: a failure is logged and the script continues to
# the tarball upload.
try:
    api = HfApi()
    logger.info(f"Creating model repository: {model_repo}")
    api.create_repo(
        repo_id=model_repo,
        repo_type="model",
        token=hf_token,
        private=True,
        exist_ok=True,
    )
    logger.info(f"Uploading model to {model_repo}")
    api.upload_folder(
        folder_path=output_dir,
        repo_id=model_repo,
        repo_type="model",
        token=hf_token,
        create_pr=False,
    )
    logger.info(f"Fine-tuned model uploaded to {model_repo}")
except Exception as e:
    # The exception is not re-raised, so log it together with its traceback;
    # otherwise the cause of the failure would be lost.
    logger.exception("Failed to upload model to HF Hub: %s", e)
    logger.warning("Continuing to tarball upload despite model upload failure")
# Upload tarball to HF Hub dataset repository
# Unlike the model upload, a failure here is fatal and is re-raised.
try:
    api = HfApi()
    logger.info(f"Creating dataset repository: {artifact_repo}")
    api.create_repo(
        repo_id=artifact_repo,
        repo_type="dataset",
        token=hf_token,
        private=True,
        exist_ok=True,
    )
    logger.info(f"Uploading tarball to {artifact_repo}")
    api.upload_file(
        path_or_fileobj=output_tarball,
        path_in_repo="granite-8b-finetuned-ascii.tar.gz",
        repo_id=artifact_repo,
        # The comma after this argument was missing, which made the whole
        # script a SyntaxError.
        repo_type="dataset",
        token=hf_token,
    )
    logger.info(f"Tarball uploaded to {artifact_repo}/granite-8b-finetuned-ascii.tar.gz")
except Exception as e:
    # Log a script-specific message, then let the original error propagate.
    logger.error(f"Failed to upload tarball to HF Hub: {str(e)}")
    raise