|  | import logging | 
					
						
						|  | import traceback | 
					
						
						|  |  | 
					
						
						|  | from datasets import load_dataset | 
					
						
						|  | from sentence_transformers import ( | 
					
						
						|  | SentenceTransformer, | 
					
						
						|  | SentenceTransformerModelCardData, | 
					
						
						|  | SentenceTransformerTrainer, | 
					
						
						|  | SentenceTransformerTrainingArguments, | 
					
						
						|  | ) | 
					
						
						|  | from sentence_transformers.evaluation import InformationRetrievalEvaluator | 
					
						
						|  | from sentence_transformers.losses import CachedMultipleNegativesRankingLoss | 
					
						
						|  | from sentence_transformers.training_args import BatchSamplers | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model = SentenceTransformer( | 
					
						
						|  | "google/embeddinggemma-300M", | 
					
						
						|  | model_card_data=SentenceTransformerModelCardData( | 
					
						
						|  | language="en", | 
					
						
						|  | license="apache-2.0", | 
					
						
						|  | model_name="EmbeddingGemma-300M trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)", | 
					
						
						|  | ), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000)) | 
					
						
						|  | eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000)) | 
					
						
						|  | test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | run_name = "embeddinggemma-300M-medical-100k" | 
					
						
						|  | args = SentenceTransformerTrainingArguments( | 
					
						
						|  |  | 
					
						
						|  | output_dir=f"models/{run_name}", | 
					
						
						|  |  | 
					
						
						|  | num_train_epochs=1, | 
					
						
						|  | per_device_train_batch_size=128, | 
					
						
						|  | per_device_eval_batch_size=128, | 
					
						
						|  | learning_rate=2e-5, | 
					
						
						|  | warmup_ratio=0.1, | 
					
						
						|  | fp16=True, | 
					
						
						|  | bf16=False, | 
					
						
						|  | batch_sampler=BatchSamplers.NO_DUPLICATES, | 
					
						
						|  | prompts={ | 
					
						
						|  | "question": model.prompts["query"], | 
					
						
						|  | "passage_text": model.prompts["document"], | 
					
						
						|  | }, | 
					
						
						|  |  | 
					
						
						|  | eval_strategy="steps", | 
					
						
						|  | eval_steps=100, | 
					
						
						|  | save_strategy="steps", | 
					
						
						|  | save_steps=100, | 
					
						
						|  | save_total_limit=2, | 
					
						
						|  | logging_steps=20, | 
					
						
						|  | run_name=run_name, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | queries = dict(enumerate(eval_dataset["question"])) | 
					
						
						|  | corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000])) | 
					
						
						|  | relevant_docs = {idx: [idx] for idx in queries} | 
					
						
						|  | dev_evaluator = InformationRetrievalEvaluator( | 
					
						
						|  | queries=queries, | 
					
						
						|  | corpus=corpus, | 
					
						
						|  | relevant_docs=relevant_docs, | 
					
						
						|  | name="miriad-eval-1kq-31kd", | 
					
						
						|  | show_progress_bar=True, | 
					
						
						|  | ) | 
					
						
						|  | dev_evaluator(model) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | trainer = SentenceTransformerTrainer( | 
					
						
						|  | model=model, | 
					
						
						|  | args=args, | 
					
						
						|  | train_dataset=train_dataset, | 
					
						
						|  | eval_dataset=eval_dataset, | 
					
						
						|  | loss=loss, | 
					
						
						|  | evaluator=dev_evaluator, | 
					
						
						|  | ) | 
					
						
						|  | trainer.train() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | dev_evaluator(model) | 
					
						
						|  |  | 
					
						
						|  | queries = dict(enumerate(test_dataset["question"])) | 
					
						
						|  | corpus = dict(enumerate(test_dataset["passage_text"] + train_dataset["passage_text"][:30_000])) | 
					
						
						|  | relevant_docs = {idx: [idx] for idx in queries} | 
					
						
						|  | test_evaluator = InformationRetrievalEvaluator( | 
					
						
						|  | queries=queries, | 
					
						
						|  | corpus=corpus, | 
					
						
						|  | relevant_docs=relevant_docs, | 
					
						
						|  | name="miriad-test-1kq-31kd", | 
					
						
						|  | show_progress_bar=True, | 
					
						
						|  | ) | 
					
						
						|  | test_evaluator(model) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | final_output_dir = f"models/{run_name}/final" | 
					
						
						|  | model.save_pretrained(final_output_dir) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | model.push_to_hub(run_name) | 
					
						
						|  | except Exception: | 
					
						
						|  | logging.error( | 
					
						
						|  | f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " | 
					
						
						|  | f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " | 
					
						
						|  | f"and saving it using `model.push_to_hub('{run_name}')`." | 
					
						
						|  | ) |