Caden Shokat
		
	commited on
		
		
					Commit 
							
							·
						
						eb13318
	
1
								Parent(s):
							
							13f746f
								
push model to hub
Browse files- src/training/train.py +5 -3
 
    	
        src/training/train.py
    CHANGED
    
    | 
         @@ -5,12 +5,13 @@ from sentence_transformers.losses import MultipleNegativesRankingLoss 
     | 
|
| 5 | 
         
             
            from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
         
     | 
| 6 | 
         
             
            from sentence_transformers.training_args import BatchSamplers
         
     | 
| 7 | 
         
             
            from sentence_transformers.losses import MatryoshkaLoss
         
     | 
| 8 | 
         
            -
             
     | 
| 9 | 
         
             
            from src.utils.config import CFG
         
     | 
| 10 | 
         
             
            from src.utils.paths import TRAIN_JSON, TEST_JSON
         
     | 
| 11 | 
         
             
            from src.eval.ir_eval import build_eval
         
     | 
| 12 | 
         | 
| 13 | 
         
             
            def main():
         
     | 
| 
         | 
|
| 14 | 
         
             
                device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
         
     | 
| 15 | 
         | 
| 16 | 
         
             
                # base model with SDPA
         
     | 
| 
         @@ -66,8 +67,9 @@ def main(): 
     | 
|
| 66 | 
         
             
                trainer.train()
         
     | 
| 67 | 
         
             
                trainer.save_model()
         
     | 
| 68 | 
         | 
| 69 | 
         
            -
                
         
     | 
| 70 | 
         
            -
             
     | 
| 
         | 
|
| 71 | 
         | 
| 72 | 
         
             
            if __name__ == "__main__":
         
     | 
| 73 | 
         
             
                main()
         
     | 
| 
         | 
|
| 5 | 
         
             
            from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
         
     | 
| 6 | 
         
             
            from sentence_transformers.training_args import BatchSamplers
         
     | 
| 7 | 
         
             
            from sentence_transformers.losses import MatryoshkaLoss
         
     | 
| 8 | 
         
            +
            from huggingface_hub import login
         
     | 
| 9 | 
         
             
            from src.utils.config import CFG
         
     | 
| 10 | 
         
             
            from src.utils.paths import TRAIN_JSON, TEST_JSON
         
     | 
| 11 | 
         
             
            from src.eval.ir_eval import build_eval
         
     | 
| 12 | 
         | 
| 13 | 
         
             
            def main():
         
     | 
| 14 | 
         
            +
                HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
         
     | 
| 15 | 
         
             
                device = "cuda" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
         
     | 
| 16 | 
         | 
| 17 | 
         
             
                # base model with SDPA
         
     | 
| 
         | 
|
| 67 | 
         
             
                trainer.train()
         
     | 
| 68 | 
         
             
                trainer.save_model()
         
     | 
| 69 | 
         | 
| 70 | 
         
            +
                if HF_TOKEN:
         
     | 
| 71 | 
         
            +
                    login(token=HF_TOKEN)
         
     | 
| 72 | 
         
            +
                    trainer.model.push_to_hub(CFG.output_dir)
         
     | 
| 73 | 
         | 
| 74 | 
         
             
            if __name__ == "__main__":
         
     | 
| 75 | 
         
             
                main()
         
     |