Sarthak committed
Commit 7837959
Parent(s): 37196da

chore: update dependencies and configuration for improved training

This commit switches the model configuration in `.codemap.yml` to a lighter variant of the model. It also updates `pyproject.toml` and `uv.lock`, adding new dependencies such as `jinja2`, `joblib`, `rich`, and `safetensors`, and replacing `tokenlearn` with `tokenizers`. The report has been regenerated to reflect the updated performance metrics, and the dataset configuration now supports using a pre-built optimized dataset during training.
- .codemap.yml +1 -1
- REPORT.md +18 -90
- patches/model2vec.patch +0 -39
- patches/tokenlearn.patch +0 -25
- pyproject.toml +14 -3
- src/distiller/__main__.py +52 -2
- src/distiller/analyze.py +1 -1
- src/distiller/config.py +7 -1
- src/distiller/dataset.py +659 -0
- src/distiller/distill.py +345 -194
- src/distiller/patch_utils.py +0 -276
- uv.lock +21 -55
.codemap.yml
CHANGED

@@ -5,7 +5,7 @@
 # LLM Configuration - Controls which model is used for AI operations
 llm:
   # Format: "provider:model-name", e.g., "openai:gpt-4o", "anthropic:claude-3-opus"
-  model: "google-gla:gemini-2.0-flash"
+  model: "google-gla:gemini-2.0-flash-lite"
   temperature: 0.5 # Lower for more deterministic outputs, higher for creativity
   max_input_tokens: 1000000 # Maximum tokens in input
   max_output_tokens: 10000 # Maximum tokens in responses
REPORT.md
CHANGED

@@ -28,8 +28,8 @@ This report presents a comprehensive analysis of Model2Vec distillation experime
 | code_model2vec_all_MiniLM_L6_v2 | [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) | 0.7385 | 0.7049 | 0.7910 | 🥈 2nd |
 | code_model2vec_jina_embeddings_v2_base_code | [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code) | 0.7381 | 0.6996 | 0.8130 | 🥉 3rd |
 | code_model2vec_paraphrase_MiniLM_L6_v2 | [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2) | 0.7013 | 0.6638 | 0.7665 | #4 |
-
-
+| code_model2vec_Reason_ModernColBERT | [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT) | 0.6598 | 0.6228 | 0.7260 | #5 |
+| code_model2vec_all_mpnet_base_v2_fine_tuned | [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 0.5347 | 0.4875 | 0.6200 | #6 |
 | code_model2vec_bge_m3 | [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) | 0.4863 | 0.4439 | 0.5514 | #7 |
 | code_model2vec_jina_embeddings_v3 | [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3) | 0.4755 | 0.4416 | 0.5456 | #8 |
 | code_model2vec_nomic_embed_text_v2_moe | [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe) | 0.4532 | 0.4275 | 0.5094 | #9 |

@@ -50,8 +50,8 @@ Our distilled models exhibit consistent architectural characteristics across dif
 | all_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
 | jina_embeddings_v2_base_code | 61,053 | 15.6M | 256 | 29.8MB |
 | paraphrase_MiniLM_L6_v2 | 29,525 | 7.6M | 256 | 14.4MB |
-| all_mpnet_base_v2_fine_tuned | 77,316 | 19.8M | 256 | 75.5MB |
 | Reason_ModernColBERT | 50,254 | 12.9M | 256 | 24.5MB |
+| all_mpnet_base_v2_fine_tuned | 29,528 | 7.6M | 256 | 28.8MB |
 | bge_m3 | 249,999 | 64.0M | 256 | 122.1MB |
 | jina_embeddings_v3 | 249,999 | 64.0M | 256 | 122.1MB |
 | nomic_embed_text_v2_moe | 249,999 | 64.0M | 256 | 122.1MB |

@@ -69,9 +69,9 @@ Our distilled models exhibit consistent architectural characteristics across dif
 #### Key Insights from Model Specifications:

-- **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg:
-- **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg:
-- **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg:
+- **Vocabulary Consistency**: All models use vocabulary sizes ranging from 29,525 to 249,999 tokens (avg: 101,087)
+- **Parameter Efficiency**: Models range from 7.6M to 64.0M parameters (avg: 25.9M)
+- **Storage Efficiency**: Disk usage ranges from 14.4MB to 122.1MB (avg: 50.4MB)
 - **Embedding Dimensions**: Consistent 256 dimensions across all models (optimized for efficiency)

@@ -81,85 +81,13 @@ Our distilled models exhibit consistent architectural characteristics across dif
 - **Best Teacher Model**: code_model2vec_all_mpnet_base_v2 (NDCG@10: 0.7387)
 - **Least Effective Teacher**: code_model2vec_codebert_base (NDCG@10: 0.2779)
 - **Performance Range**: 62.4% difference between best and worst
-- **Average Performance**: 0.
+- **Average Performance**: 0.5190 NDCG@10


 ## 🎯 Language Performance Radar Charts

 ### Best Model vs Peer Models Comparison

-
-
-*Comparative view showing how the best simplified distillation model performs against top peer models across programming languages.*
-
-### Individual Model Performance by Language
-
-#### code_model2vec_all_mpnet_base_v2 (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.7387
-
-
-
-#### code_model2vec_all_MiniLM_L6_v2 (Teacher: [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)) - NDCG@10: 0.7385
-
-
-
-#### code_model2vec_jina_embeddings_v2_base_code (Teacher: [jina-embeddings-v2-base-code](https://huggingface.co/jina-embeddings-v2-base-code)) - NDCG@10: 0.7381
-
-
-
-#### code_model2vec_paraphrase_MiniLM_L6_v2 (Teacher: [sentence-transformers/paraphrase-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)) - NDCG@10: 0.7013
-
-
-
-#### code_model2vec_all_mpnet_base_v2_fine_tuned (Teacher: [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)) - NDCG@10: 0.6906
-
-
-
-#### code_model2vec_Reason_ModernColBERT (Teacher: [lightonai/Reason-ModernColBERT](https://huggingface.co/lightonai/Reason-ModernColBERT)) - NDCG@10: 0.6598
-
-
-
-#### code_model2vec_bge_m3 (Teacher: [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)) - NDCG@10: 0.4863
-
-
-
-#### code_model2vec_jina_embeddings_v3 (Teacher: [jinaai/jina-embeddings-v3](https://huggingface.co/jinaai/jina-embeddings-v3)) - NDCG@10: 0.4755
-
-
-
-#### code_model2vec_nomic_embed_text_v2_moe (Teacher: [nomic-ai/nomic-embed-text-v2-moe](https://huggingface.co/nomic-ai/nomic-embed-text-v2-moe)) - NDCG@10: 0.4532
-
-
-
-#### code_model2vec_gte_Qwen2_1.5B_instruct (Teacher: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)) - NDCG@10: 0.4238
-
-
-
-#### code_model2vec_Qodo_Embed_1_1.5B (Teacher: [Qodo/Qodo-Embed-1-1.5B](https://huggingface.co/Qodo/Qodo-Embed-1-1.5B)) - NDCG@10: 0.4101
-
-
-
-#### code_model2vec_graphcodebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.3420
-
-
-
-#### code_model2vec_Linq_Embed_Mistral (Teacher: [Linq-AI-Research/Linq-Embed-Mistral](https://huggingface.co/Linq-AI-Research/Linq-Embed-Mistral)) - NDCG@10: 0.2868
-
-
-
-#### code_model2vec_codebert_base (Teacher: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base)) - NDCG@10: 0.2779
-
-
-
-
-## 🏆 Peer Model Comparison
-
-
-
-*Comparison with established code-specialized embedding models using actual evaluation results.*
-
-### Complete Model Ranking
-
 | Rank | Model | Type | NDCG@10 | MRR | Recall@5 |
 |------|-------|------|---------|-----|----------|
 | 1 | Alibaba-NLP/gte-Qwen2-1.5B-instruct | General | 0.9729 | 0.9676 | 0.9825 |

@@ -180,10 +108,10 @@ Our distilled models exhibit consistent architectural characteristics across dif
 | 16 | code_model2vec_all_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7385 | 0.7049 | 0.7910 |
 | 17 | code_model2vec_jina_embeddings_v2_base_code | **🔥 Simplified Distillation** | 0.7381 | 0.6996 | 0.8130 |
 | 18 | code_model2vec_paraphrase_MiniLM_L6_v2 | **🔥 Simplified Distillation** | 0.7013 | 0.6638 | 0.7665 |
-| 19
-| 20
-| 21
-| 22
+| 19 | code_model2vec_Reason_ModernColBERT | **🔥 Simplified Distillation** | 0.6598 | 0.6228 | 0.7260 |
+| 20 | potion-multilingual-128M | Model2Vec | 0.6124 | 0.5683 | 0.7017 |
+| 21 | huggingface/CodeBERTa-small-v1 | Code-Specific | 0.5903 | 0.5350 | 0.6779 |
+| 22 | code_model2vec_all_mpnet_base_v2_fine_tuned | **🎓 Fine-tuned Distillation** | 0.5347 | 0.4875 | 0.6200 |
 | 23 | Salesforce/codet5-base | Code-Specific | 0.4872 | 0.4500 | 0.5742 |
 | 24 | code_model2vec_bge_m3 | **🔥 Simplified Distillation** | 0.4863 | 0.4439 | 0.5514 |
 | 25 | code_model2vec_jina_embeddings_v3 | **🔥 Simplified Distillation** | 0.4755 | 0.4416 | 0.5456 |

@@ -243,12 +171,12 @@ Our distilled models exhibit consistent architectural characteristics across dif

 | Language | Best Model Performance | Average Performance | Language Difficulty |
 |----------|------------------------|--------------------|--------------------|
-| Go | 0.9780 | 0.
-| Java | 0.9921 | 0.
-| Javascript | 0.9550 | 0.
-| Php | 1.0000 | 0.
-| Python | 1.0000 | 0.
-| Ruby | 0.9493 | 0.
+| Go | 0.9780 | 0.6923 | Easy |
+| Java | 0.9921 | 0.6545 | Easy |
+| Javascript | 0.9550 | 0.5831 | Easy |
+| Php | 1.0000 | 0.6325 | Easy |
+| Python | 1.0000 | 0.8599 | Easy |
+| Ruby | 0.9493 | 0.6333 | Easy |


 ## 🎯 Conclusions and Recommendations

@@ -302,5 +230,5 @@ Based on the evaluation results across all simplified distillation models:

 ---

-*Report generated on 2025-05-31
+*Report generated on 2025-05-31 21:07:06 using automated analysis pipeline.*
 *For questions about methodology or results, please refer to the CodeSearchNet documentation.*
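The report ranks models by NDCG@10, MRR, and Recall@5. As a reference for how the headline metric behaves, here is a minimal NDCG@10 sketch in Python, assuming binary relevance labels and a single ranked result list; this is illustrative and not the project's evaluation code:

```python
import math

def ndcg_at_10(ranked_relevance: list[int]) -> float:
    """NDCG@10 for one query, with binary relevance labels (1 = relevant)."""
    top = ranked_relevance[:10]
    # DCG: relevance discounted by log2 of (rank + 1), ranks are 1-based
    dcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(top))
    # Ideal DCG: the same labels sorted with all relevant items first
    ideal = sorted(ranked_relevance, reverse=True)[:10]
    idcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0

# Example: the single relevant snippet is retrieved at rank 3
print(ndcg_at_10([0, 0, 1, 0, 0]))  # ≈ 0.5
```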
patches/model2vec.patch
DELETED

@@ -1,39 +0,0 @@
---- a/model2vec/train/base.py
-+++ b/model2vec/train/base.py
-@@ -35,7 +35,7 @@ class FinetunableStaticModel(nn.Module):
-        )
-        self.vectors = vectors.float()
-
--        self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
-+        self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=pad_id)
-        self.head = self.construct_head()
-        self.w = self.construct_weights()
-        self.tokenizer = tokenizer
---- a/model2vec/distill/distillation.py
-+++ b/model2vec/distill/distillation.py
-@@ -137,7 +137,10 @@ def distill_from_model(
-    # Get the language from the model card.
-    try:
-        info = model_info(model_name)
--        language = info.cardData.get("language", None)
-+        if info is not None and hasattr(info, 'cardData') and info.cardData is not None:
-+            language = info.cardData.get("language", None)
-+        else:
-+            language = None
-    except RepositoryNotFoundError:
-        logger.info("No model info found for the model. Setting language to None.")
-        language = None
---- a/model2vec/distill/inference.py
-+++ b/model2vec/distill/inference.py
-@@ -109,5 +109,12 @@ def create_embeddings(
-        out_tokens.extend([Token(x, False) for x in tokens])
-    out_weights = np.stack(intermediate_weights)
-
-+    # Validate token-vector consistency to prevent failures
-+    if len(out_tokens) != out_weights.shape[0]:
-+        logger.warning(f"Token-vector mismatch: {len(out_tokens)} tokens vs {out_weights.shape[0]} vectors. Truncating to prevent failure.")
-+        min_count = min(len(out_tokens), out_weights.shape[0])
-+        out_tokens = out_tokens[:min_count]
-+        out_weights = out_weights[:min_count]
-+
-    return out_tokens, out_weights
patches/tokenlearn.patch
DELETED

@@ -1,25 +0,0 @@
---- a/tokenlearn/pretrain.py
-+++ b/tokenlearn/pretrain.py
-@@ -38,7 +38,10 @@ class FinetunableStaticModel(nn.Module):
-        """Run the model using input IDs."""
-        input_ids = input_ids.view(-1)
-        input_ids = input_ids[input_ids != self.pad_token_id]
--        w = self.w[input_ids]
-+        # Fix for index out of bounds issue
-+        # Clamp input_ids to valid range to prevent IndexError during training
-+        valid_input_ids = torch.clamp(input_ids, 0, self.w.shape[0] - 1)
-+        w = self.w[valid_input_ids]
-        return self.sub_forward(w)
-
-    def forward(self, x):
-@@ -46,7 +49,10 @@ class FinetunableStaticModel(nn.Module):
-        # Add a small epsilon to avoid division by zero
-        length = zeros.sum(1) + 1e-16
--        embedded = self.embeddings(input_ids)
-+        # Fix for embedding index out of bounds issue
-+        # Clamp input_ids to valid embedding range
-+        valid_input_ids = torch.clamp(input_ids, 0, self.embeddings.num_embeddings - 1)
-+        embedded = self.embeddings(valid_input_ids)
-        # Zero out the padding
-        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
-        # Simulate actual mean
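The deleted tokenlearn patch guarded embedding lookups by clamping token ids into range. The same bounds-guarding pattern, shown standalone for reference (illustrative, not code from this repository):

```python
import torch

embeddings = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)
input_ids = torch.tensor([3, 42, 250])  # 250 is out of range for this table

# Clamp indices into [0, num_embeddings - 1] to avoid an IndexError,
# mirroring the workaround the deleted patch applied inside tokenlearn.
safe_ids = torch.clamp(input_ids, 0, embeddings.num_embeddings - 1)
vectors = embeddings(safe_ids)
print(vectors.shape)  # torch.Size([3, 8])
```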
pyproject.toml
CHANGED

@@ -19,24 +19,31 @@ dependencies = [
     "flash-attn>=2.7.4.post1",
     "hatchling>=1.27.0",
     "iso639>=0.1.4",
+    "jinja2>=3.0.0",
+    "joblib>=1.0.0",
     "kaleido==1.0.0rc13",
     "lightning>=2.5.1.post0",
     "matplotlib>=3.10.3",
-    "
+    "more-itertools>=10.5.0",
     "mteb>=1.14.15",
     "numpy>=1.26.4",
     "plotly>=6.1.1",
     "psutil>=7.0.0",
     "pydantic>=2.11.5",
     "requests>=2.32.3",
+    "rich>=10.0.0",
+    "safetensors>=0.3.0",
     "scikit-learn>=1.6.1",
     "seaborn>=0.13.2",
     "sentence-transformers>=4.1.0",
     "setuptools>=80.8.0",
+    "skops>=0.11.0",
     "smart-open[s3]>=7.1.0",
     "statsmodels>=0.14.4",
-    "
+    "tokenizers>=0.20",
     "torch>=2.7.0",
+    "transformers<=4.52.1",
+    "tqdm>=4.65.0",
     "typer>=0.16.0",
 ]

@@ -78,7 +85,9 @@ exclude = [
     "__pycache__",
     "build",
     "dist",
-    "vendor"
+    "vendor",
+    "src/distiller/model2vec",
+    "src/distiller/tokenlearn"
 ]

 [tool.ruff.lint]

@@ -114,6 +123,8 @@ ignore = [
     "E501", # Line too long
     "PLR2004",
     "RUF001",
+    "D100", # Missing docstring in public module
+    "D101", # Missing docstring in public class
 ]

 [tool.ruff.lint.mccabe]
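The dependency list drops the patched `tokenlearn` workflow in favour of the `tokenizers` package (alongside `safetensors`, `rich`, and friends). A small sketch of the kind of usage this enables, assuming an arbitrary Hugging Face tokenizer id; this is illustrative and not code from this repository:

```python
from tokenizers import Tokenizer

# Hypothetical tokenizer id; any model that ships a tokenizer.json works here.
tok = Tokenizer.from_pretrained("bert-base-uncased")

# Encode a doc-code pair in the same "query + code" text format the dataset uses.
encoding = tok.encode("Parse a JSON string\n\ndef parse(s):\n    return json.loads(s)")
print(len(encoding.ids), encoding.tokens[:5])
```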
src/distiller/__main__.py
CHANGED

@@ -17,12 +17,41 @@ def distill(
     train: Annotated[bool, typer.Option(help="Enable advanced training (CodeSearchNet fine-tuning)")] = False,
     teacher_models: Annotated[list[str] | None, typer.Option(help="Specific teacher models to distill")] = None,
     pca_dims: Annotated[int | None, typer.Option(help="PCA dimensions (uses config default if not specified)")] = None,
+    clear_cache: Annotated[
+        bool, typer.Option(help="Clear HuggingFace cache for problematic models before distillation")
+    ] = False,
+    clear_checkpoints: Annotated[
+        bool, typer.Option(help="Clear tokenlearn checkpoints to force fresh featurization and training")
+    ] = False,
+    skip_ptr: Annotated[
+        bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
+    ] = False,
+    use_optimized_dataset: Annotated[
+        bool,
+        typer.Option(
+            "--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
+        ),
+    ] = False,
+    dataset_path: Annotated[
+        str | None,
+        typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
+    ] = None,
 ) -> None:
     """Run unified Model2Vec distillation with optional training."""
     from .distill import main as distill_main

-    # Call the distill main function with arguments
-    distill_main(
+    # Call the distill main function with all arguments
+    distill_main(
+        use_beam,
+        train,
+        teacher_models,
+        pca_dims,
+        clear_cache,
+        clear_checkpoints,
+        skip_ptr,
+        use_optimized_dataset,
+        dataset_path,
+    )


 @app.command()

@@ -53,5 +82,26 @@ def analyze(
     analyze_main(results_dir or "code_model2vec/evaluation_results", model_name, output, export_csv)


+@app.command()
+def dataset(
+    max_samples_per_lang: Annotated[int, typer.Option(help="Maximum samples per language")] = 50000,
+    min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = 3,
+    max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = 100,
+    min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = 50,
+    max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = 2000,
+    output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
+    simple_format: Annotated[
+        bool, typer.Option(help="Create only simple format (not multiple training formats)")
+    ] = False,
+) -> None:
+    """Create optimized training dataset from CodeSearchNet for code search tasks."""
+    from .dataset import main as dataset_main
+
+    # Call the dataset main function with arguments
+    dataset_main(
+        max_samples_per_lang, min_doc_words, max_doc_words, min_code_chars, max_code_chars, output_dir, simple_format
+    )
+
+
 if __name__ == "__main__":
     app()
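For reference, the new options can be exercised either through the Typer CLI or by calling the underlying entry points directly. A small sketch of the direct calls; argument order follows the signatures in the diff above, and the concrete values are only examples:

```python
from distiller.dataset import main as dataset_main
from distiller.distill import main as distill_main

# Build the optimized CodeSearchNet dataset under code_model2vec/dataset.
# (max_samples_per_lang, min_doc_words, max_doc_words, min_code_chars,
#  max_code_chars, output_dir, simple_format)
dataset_main(50000, 3, 100, 50, 2000, None, False)

# Distill with training enabled, reading the pre-built optimized dataset.
# (use_beam, train, teacher_models, pca_dims, clear_cache, clear_checkpoints,
#  skip_ptr, use_optimized_dataset, dataset_path)
distill_main(False, True, None, None, False, False, False, True, None)
```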
src/distiller/analyze.py
CHANGED

@@ -510,7 +510,7 @@ class CodeSearchNetAnalyzer:

         try:
             # Try to load the model and get specifications
-            from model2vec import StaticModel
+            from distiller.model2vec import StaticModel

             model = StaticModel.from_pretrained(str(model_dir))
src/distiller/config.py
CHANGED

@@ -212,13 +212,19 @@ class DistillationConfig(BaseModel):
     # Tokenlearn-specific parameters (POTION approach)
     tokenlearn_dataset: str = "sentence-transformers/codesearchnet"  # Dataset for tokenlearn featurization
     tokenlearn_dataset_name: str = "pair"  # Use 'pair' configuration (only available config)
-    tokenlearn_text_key: str =
+    tokenlearn_text_key: str = (
+        "combined_text"  # Text field to use from the dataset ('combined_text' for doc-code pairs)
+    )
     tokenlearn_timeout_featurize: int = 21600  # 6 hour timeout for featurization (dataset needs ~5 hours)
     tokenlearn_timeout_train: int = 7200  # 2 hour timeout for training

     # Post-training configuration
     skip_post_training_regularization: bool = False  # Skip PCA + SIF re-regularization step

+    # Dataset configuration
+    use_optimized_dataset: bool = True  # Use the pre-created optimized dataset from dataset.py
+    custom_dataset_path: str | None = "code_model2vec/dataset"  # Path to custom dataset directory
+

 distillation_config = DistillationConfig()
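A short sketch of how the new flags might be consulted by the distillation code; this is illustrative wiring, the actual logic lives in `distill.py`:

```python
from pathlib import Path

from distiller.config import distillation_config
from distiller.dataset import load_optimized_dataset

if distillation_config.use_optimized_dataset:
    # Falls back to code_model2vec/dataset when no custom path is configured.
    dataset_dir = Path(distillation_config.custom_dataset_path or "code_model2vec/dataset")
    train_df = load_optimized_dataset(dataset_dir, split="train")
```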
src/distiller/dataset.py
ADDED

@@ -0,0 +1,659 @@
+"""
+Custom Dataset Generation for Code-Specialized Model Training.
+
+This module creates optimized training datasets from CodeSearchNet that are specifically
+designed to improve performance on code search evaluation tasks.
+
+Features:
+- High-quality doc-code pairs optimized for retrieval
+- Balanced sampling across programming languages
+- Multiple training formats (doc-only, code-only, combined)
+- Quality filtering and data cleaning
+- Train/test/eval splits with proper stratification
+- Efficient parquet format output
+"""
+
+import json
+import logging
+import time
+from pathlib import Path
+from typing import Annotated, Any
+
+import pandas as pd
+import typer
+from datasets import load_dataset
+from tqdm import tqdm
+
+from .config import languages_config
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+# Dataset configuration
+DATASET_OUTPUT_DIR = Path("code_model2vec/dataset")
+DEFAULT_MAX_SAMPLES_PER_LANG = 50000
+DEFAULT_MIN_DOC_WORDS = 3
+DEFAULT_MAX_DOC_WORDS = 100
+DEFAULT_MIN_CODE_CHARS = 50
+DEFAULT_MAX_CODE_CHARS = 2000
+
+
+def create_optimized_dataset(
+    max_samples_per_lang: int = DEFAULT_MAX_SAMPLES_PER_LANG,
+    min_doc_words: int = DEFAULT_MIN_DOC_WORDS,
+    max_doc_words: int = DEFAULT_MAX_DOC_WORDS,
+    min_code_chars: int = DEFAULT_MIN_CODE_CHARS,
+    max_code_chars: int = DEFAULT_MAX_CODE_CHARS,
+    output_dir: Path | None = None,
+    create_multiple_formats: bool = True,
+) -> dict[str, Any]:
+    """
+    Create optimized training dataset from CodeSearchNet for code search tasks.
+
+    Args:
+        max_samples_per_lang: Maximum samples per programming language
+        min_doc_words: Minimum words in documentation
+        max_doc_words: Maximum words in documentation
+        min_code_chars: Minimum characters in code
+        max_code_chars: Maximum characters in code
+        output_dir: Output directory for dataset
+        create_multiple_formats: Create multiple training formats
+
+    Returns:
+        Dictionary with dataset statistics and file paths
+    """
+    output_dir = DATASET_OUTPUT_DIR if output_dir is None else Path(output_dir)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    logger.info("🚀 Starting optimized CodeSearchNet dataset creation...")
+    logger.info(f"📁 Output directory: {output_dir}")
+    logger.info(f"📊 Target: {max_samples_per_lang} samples per language")
+    logger.info(f"🔍 Languages: {', '.join(languages_config.all)}")
+
+    start_time = time.time()
+    all_samples = []
+    language_stats = {}
+
+    # Process each programming language
+    for language in languages_config.all:
+        logger.info(f"\n🔄 Processing {language}...")
+
+        try:
+            # Load CodeSearchNet dataset for this language
+            dataset = load_dataset("code_search_net", language, split="train", trust_remote_code=True)
+
+            language_samples = []
+            processed_count = 0
+            quality_filtered = 0
+
+            # Process examples with quality filtering
+            for example in tqdm(dataset, desc=f"Processing {language}", unit="examples"):
+                processed_count += 1
+
+                # Extract documentation and code
+                doc_string = example.get("func_documentation_string", "").strip()
+                code_string = example.get("func_code_string", "").strip()
+                func_name = example.get("func_name", "").strip()
+
+                # Quality filters
+                if not _passes_quality_filters(
+                    doc_string, code_string, func_name, min_doc_words, max_doc_words, min_code_chars, max_code_chars
+                ):
+                    continue
+
+                quality_filtered += 1
+
+                # Create optimized training samples
+                samples = _create_training_samples(
+                    doc_string, code_string, func_name, language, create_multiple_formats
+                )
+                language_samples.extend(samples)
+
+                # Stop if we have enough samples
+                if len(language_samples) >= max_samples_per_lang:
+                    break
+
+            # Truncate to exact target size
+            language_samples = language_samples[:max_samples_per_lang]
+            all_samples.extend(language_samples)
+
+            # Track statistics
+            language_stats[language] = {
+                "processed": processed_count,
+                "quality_filtered": quality_filtered,
+                "final_samples": len(language_samples),
+                "quality_rate": quality_filtered / processed_count if processed_count > 0 else 0,
+            }
+
+            logger.info(f"✅ {language}: {len(language_samples)} samples from {quality_filtered} quality examples")
+
+        except Exception:
+            logger.exception(f"❌ Failed to process {language}")
+            language_stats[language] = {
+                "processed": 0,
+                "quality_filtered": 0,
+                "final_samples": 0,
+                "quality_rate": 0.0,
+            }
+
+    # Create DataFrame
+    logger.info(f"\n📊 Creating dataset with {len(all_samples)} total samples...")
+    df = pd.DataFrame(all_samples)
+
+    # Create stratified splits
+    train_df, test_df = _create_stratified_splits(df)
+
+    # Save datasets
+    dataset_files = _save_datasets(output_dir, train_df, test_df)
+
+    # Save metadata
+    metadata = {
+        "creation_time": time.strftime("%Y-%m-%d %H:%M:%S"),
+        "total_samples": len(all_samples),
+        "train_samples": len(train_df),
+        "test_samples": len(test_df),
+        "languages": languages_config.all,
+        "language_stats": language_stats,
+        "quality_filters": {
+            "min_doc_words": min_doc_words,
+            "max_doc_words": max_doc_words,
+            "min_code_chars": min_code_chars,
+            "max_code_chars": max_code_chars,
+        },
+        "files": dataset_files,
+        "processing_time": time.time() - start_time,
+    }
+
+    metadata_file = output_dir / "metadata.json"
+    with metadata_file.open("w") as f:
+        json.dump(metadata, f, indent=2)
+
+    logger.info(f"\n🎉 Dataset creation completed in {metadata['processing_time']:.2f} seconds!")
+    logger.info("📊 Final statistics:")
+    logger.info(f" - Total samples: {metadata['total_samples']}")
+    logger.info(f" - Train: {metadata['train_samples']}")
+    logger.info(f" - Test: {metadata['test_samples']}")
+    logger.info(f"💾 Metadata saved to: {metadata_file}")
+
+    return metadata
+
+
+def _passes_quality_filters(
+    doc_string: str,
+    code_string: str,
+    func_name: str,
+    min_doc_words: int,
+    max_doc_words: int,
+    min_code_chars: int,
+    max_code_chars: int,
+) -> bool:
+    """Apply quality filters optimized for code retrieval following RAG best practices."""
+    # Basic existence checks
+    if not doc_string or not code_string or not func_name:
+        return False
+
+    # Documentation quality filters for code retrieval
+    doc_words = len(doc_string.split())
+    if doc_words < min_doc_words or doc_words > max_doc_words:
+        return False
+
+    # Code quality filters
+    code_length = len(code_string)
+    if code_length < min_code_chars or code_length > max_code_chars:
+        return False
+
+    # Content quality filters for code retrieval
+    doc_lower = doc_string.lower()
+    code_string.lower()
+
+    # Skip low-quality documentation (expanded for code context)
+    skip_phrases = [
+        "todo",
+        "fixme",
+        "hack",
+        "temp",
+        "test",
+        "placeholder",
+        "not implemented",
+        "coming soon",
+        "tbd",
+        "xxx",
+        "broken",
+        "deprecated",
+        "legacy",
+        "old version",
+        "outdated",
+    ]
+    if any(phrase in doc_lower for phrase in skip_phrases):
+        return False
+
+    # Ensure meaningful documentation for code retrieval
+    if func_name.lower() in doc_lower and doc_words < 5:
+        return False
+
+    # Code structure validation (more comprehensive for retrieval)
+    has_function = any(
+        pattern in code_string for pattern in ["def ", "function ", "class ", "public ", "private ", "static "]
+    )
+    if not has_function:
+        return False
+
+    # Skip trivial or incomplete code
+    trivial_code_patterns = [
+        "pass",
+        "return None",
+        "return;",
+        "throw new Error",
+        "# TODO",
+        "// TODO",
+        "print(",
+        "console.log(",
+    ]
+    if any(pattern in code_string for pattern in trivial_code_patterns) and len(code_string) < 100:
+        return False
+
+    # Ensure documentation describes functionality (not just naming)
+    generic_docs = [
+        "returns a value",
+        "does something",
+        "helper function",
+        "utility method",
+        "this function",
+        "this method",
+        "returns the result",
+        "performs operation",
+    ]
+    if any(generic in doc_lower for generic in generic_docs):
+        return False
+
+    # Ensure documentation has descriptive content for retrieval
+    descriptive_words = [
+        "parse",
+        "convert",
+        "transform",
+        "calculate",
+        "validate",
+        "format",
+        "filter",
+        "sort",
+        "search",
+        "find",
+        "create",
+        "generate",
+        "process",
+        "handle",
+        "manage",
+        "update",
+        "modify",
+        "remove",
+        "delete",
+        "add",
+    ]
+    if not any(word in doc_lower for word in descriptive_words) and doc_words < 8:
+        return False
+
+    # Code-documentation alignment check (key for retrieval quality)
+    return _check_code_doc_alignment(doc_string, code_string, func_name)
+
+
+def _check_code_doc_alignment(doc_string: str, code_string: str, func_name: str) -> bool:
+    """Check if documentation and code are well-aligned for retrieval tasks."""
+    doc_lower = doc_string.lower()
+    code_lower = code_string.lower()
+
+    # Function name should relate to documentation
+    func_base = func_name.lower().replace("_", " ").replace("-", " ")
+
+    # Check for obvious mismatches
+    doc_has_return = any(word in doc_lower for word in ["return", "returns", "gives", "outputs"])
+    code_has_return = "return " in code_lower
+
+    # If doc mentions returning something, code should have returns
+    if doc_has_return and not code_has_return and len(code_string.split("\n")) > 3:
+        return False
+
+    # Check for parameter mentions alignment
+    any(word in doc_lower for word in ["parameter", "param", "argument", "input"])
+    "(" in func_name and func_name.count("(") == 1
+
+    # Basic semantic alignment
+    action_words = ["sort", "parse", "convert", "validate", "format", "filter", "search", "calculate"]
+    doc_actions = [word for word in action_words if word in doc_lower]
+    [word for word in action_words if word in code_lower or word in func_base]
+
+    # If documentation mentions specific actions, code or function name should reflect them
+    return not (doc_actions and not any(action in code_lower or action in func_base for action in doc_actions))
+
+
+def _create_training_samples(
+    doc_string: str,
+    code_string: str,
+    func_name: str,
+    language: str,
+    create_multiple_formats: bool,
+) -> list[dict[str, Any]]:
+    """Create optimized training samples for code retrieval with proper training schema."""
+    samples = []
+
+    if create_multiple_formats:
+        # Format 1: Documentation query → Code (direct evaluation format)
+        query_1 = doc_string
+        text_1 = _format_training_text(query_1, code_string, language)
+        samples.append(
+            {
+                "language": language,
+                "query": query_1,
+                "code": code_string,
+                "text": text_1,
+            }
+        )
+
+        # Format 2: How-to query (realistic developer search)
+        query_2 = _generate_how_to_query(doc_string, func_name, language)
+        text_2 = _format_training_text(query_2, code_string, language)
+        samples.append(
+            {
+                "language": language,
+                "query": query_2,
+                "code": code_string,
+                "text": text_2,
+            }
+        )
+
+        # Format 3: Functional requirement query
+        query_3 = _generate_functional_query(doc_string, func_name)
+        text_3 = _format_training_text(query_3, code_string, language)
+        samples.append(
+            {
+                "language": language,
+                "query": query_3,
+                "code": code_string,
+                "text": text_3,
+            }
+        )
+
+        # Format 4: Implementation-specific query
+        query_4 = _generate_implementation_query(doc_string, func_name, language)
+        text_4 = _format_training_text(query_4, code_string, language)
+        samples.append(
+            {
+                "language": language,
+                "query": query_4,
+                "code": code_string,
+                "text": text_4,
+            }
+        )
+
+    else:
+        # Simple format - direct documentation to code
+        query = doc_string
+        text = _format_training_text(query, code_string, language)
+        samples.append(
+            {
+                "language": language,
+                "query": query,
+                "code": code_string,
+                "text": text,
+            }
+        )
+
+    return samples
+
+
+def _format_training_text(query: str, code: str, language: str) -> str:
+    """Format query and code into a single training text chunk with markdown-style code blocks."""
+    # Clean up query but preserve internal code formatting
+    query_clean = query.strip()
+    code_clean = code.strip()
+
+    # Create training text with proper markdown format and newline separation
+    # Structure: query + empty line + markdown code block with language
+    return f"{query_clean}\n\n```{language}\n{code_clean}\n```"
+
+
+def _generate_how_to_query(doc_string: str, func_name: str, language: str) -> str:
+    """Generate realistic 'how to' queries that developers might actually search for."""
+    # Extract key action words from documentation
+    doc_lower = doc_string.lower()
+    func_lower = func_name.lower()
+
+    # Common developer query patterns
+    if "sort" in doc_lower or "sort" in func_lower:
+        return f"How to sort data in {language}"
+    if "parse" in doc_lower or "parse" in func_lower:
+        return f"How to parse data in {language}"
+    if "convert" in doc_lower or "transform" in doc_lower or "convert" in func_lower:
+        return f"How to convert data in {language}"
+    if "validate" in doc_lower or "check" in doc_lower or "validate" in func_lower:
+        return f"How to validate input in {language}"
+    if "calculate" in doc_lower or "compute" in doc_lower or "calc" in func_lower:
+        return f"How to calculate values in {language}"
+    if "format" in doc_lower or "format" in func_lower:
+        return f"How to format output in {language}"
+    if "filter" in doc_lower or "filter" in func_lower:
+        return f"How to filter data in {language}"
+    if "search" in doc_lower or "find" in doc_lower or "search" in func_lower or "find" in func_lower:
+        return f"How to search through data in {language}"
+    # Use function name for more specific queries
+    if func_name and len(func_name) > 2:
+        # Extract meaningful words from function name
+        func_words = func_name.replace("_", " ").replace("-", " ").strip()
+        if func_words:
+            return f"How to {func_words.lower()} in {language}"
+    # Fallback to more generic query
+    action = doc_string.split()[0] if doc_string.split() else "implement"
+    return f"How to {action.lower()} in {language}"
+
+
+def _generate_functional_query(doc_string: str, func_name: str) -> str:
+    """Generate functional requirement queries focusing on what the code accomplishes."""
+    # Clean up documentation to create natural query
+    doc_clean = doc_string.strip().rstrip(".")
+
+    # Transform to question format
+    if doc_clean.startswith(("Returns", "Return")):
+        return f"Function that {doc_clean.lower()}"
+    if doc_clean.startswith(("Creates", "Create")):
+        return f"Code to {doc_clean.lower()}"
+    if doc_clean.startswith(("Checks", "Check")):
+        return f"Function to {doc_clean.lower()}"
+
+    # Use function name to enhance the query if available
+    if func_name and len(func_name) > 2:
+        func_words = func_name.replace("_", " ").replace("-", " ").strip()
+        if func_words and len(doc_clean) < 30:  # Only for short docs
+            return f"Function named '{func_name}' that {doc_clean.lower()}"
+
+    return f"Implementation that {doc_clean.lower()}"
+
+
+def _generate_implementation_query(doc_string: str, func_name: str, language: str) -> str:
+    """Generate implementation-specific queries with technical details."""
+    doc_lower = doc_string.lower()
+    func_lower = func_name.lower() if func_name else ""
+
+    # Add language-specific implementation details
+    if language == "python":
+        if "list" in doc_lower or "array" in doc_lower or "list" in func_lower:
+            return f"Python function to {doc_string.lower()} using lists"
+        if "dict" in doc_lower or "hash" in doc_lower or "dict" in func_lower:
+            return f"Python function to {doc_string.lower()} using dictionaries"
+        # Include function name for context if available
+        if func_name and len(func_name) > 2:
+            return f"Python implementation of {func_name}: {doc_string.lower()}"
+        return f"Python implementation: {doc_string.lower()}"
+    if language == "java":
+        func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
+        return f"Java method to {doc_string.lower()}{func_suffix}"
+    if language == "javascript":
+        func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
+        return f"JavaScript function to {doc_string.lower()}{func_suffix}"
+    if language == "php":
+        func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
+        return f"PHP function to {doc_string.lower()}{func_suffix}"
+    if language == "ruby":
+        func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
+        return f"Ruby method to {doc_string.lower()}{func_suffix}"
+    if language == "go":
+        func_suffix = f" ({func_name})" if func_name and len(func_name) > 2 else ""
+        return f"Go function to {doc_string.lower()}{func_suffix}"
+    return f"{language} code to {doc_string.lower()}"
+
+
+def _create_stratified_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """Create stratified train/test splits preserving language distribution."""
+    # Define split ratios
+    train_ratio = 0.9
+    # test_ratio = 0.1 (remainder)
+
+    train_dfs = []
+    test_dfs = []
+
+    # Split by language to ensure balanced representation
+    for language in df["language"].unique():
+        lang_df = df[df["language"] == language].copy()
+        n_samples = len(lang_df)
+
+        # Calculate split sizes
+        n_train = int(n_samples * train_ratio)
+        # Remainder goes to test
+
+        # Shuffle and split
+        lang_df = lang_df.sample(frac=1, random_state=42).reset_index(drop=True)
+
+        train_dfs.append(lang_df[:n_train])
+        test_dfs.append(lang_df[n_train:])
+
+    # Combine and shuffle again
+    train_df = pd.concat(train_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
+    test_df = pd.concat(test_dfs, ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
+
+    logger.info("📊 Created stratified splits:")
+    logger.info(f" - Train: {len(train_df)} samples")
+    logger.info(f" - Test: {len(test_df)} samples")
+
+    return train_df, test_df
+
+
+def _save_datasets(
+    output_dir: Path,
+    train_df: pd.DataFrame,
+    test_df: pd.DataFrame,
+) -> dict[str, str]:
+    """Save datasets in parquet format with compression."""
+    dataset_files = {}
+
+    # Save each split
+    for split_name, df in [("train", train_df), ("test", test_df)]:
+        filepath = output_dir / f"{split_name}.parquet"
+        df.to_parquet(
+            filepath,
+            compression="snappy",
+            index=False,
+        )
+        dataset_files[split_name] = str(filepath)
+        logger.info(f"💾 Saved {split_name}: {len(df)} samples → {filepath}")
+
+    # Also save a combined dataset for convenience
+    combined_df = pd.concat([train_df, test_df], ignore_index=True)
+    combined_filepath = output_dir / "combined.parquet"
+    combined_df.to_parquet(combined_filepath, compression="snappy", index=False)
+    dataset_files["combined"] = str(combined_filepath)
+    logger.info(f"💾 Saved combined: {len(combined_df)} samples → {combined_filepath}")
+
+    return dataset_files
+
+
+def load_optimized_dataset(
+    output_dir: Path | None = None,
+    split: str = "train",
+) -> pd.DataFrame:
+    """
+    Load a previously created optimized dataset.
+
+    Args:
+        output_dir: Directory containing the dataset files
+        split: Which split to load ('train', 'test', 'combined')
+
+    Returns:
+        DataFrame with the requested dataset split
+    """
+    if output_dir is None:
+        output_dir = DATASET_OUTPUT_DIR
+
+    filepath = output_dir / f"{split}.parquet"
+
+    if not filepath.exists():
+        available_files = list(output_dir.glob("*.parquet"))
+        available_splits = [f.stem for f in available_files]
+        msg = f"Dataset split '{split}' not found at {filepath}. Available splits: {available_splits}"
+        raise FileNotFoundError(msg)
+
+    logger.info(f"📂 Loading {split} dataset from {filepath}")
+    df = pd.read_parquet(filepath)
+    logger.info(f"✅ Loaded {len(df)} samples")
+
+    return df
+
+
+def main(
+    max_samples_per_lang: Annotated[
+        int, typer.Option(help="Maximum samples per language")
+    ] = DEFAULT_MAX_SAMPLES_PER_LANG,
+    min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
+    max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
+    min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
+    max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
+    output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
+    simple_format: Annotated[
+        bool, typer.Option(help="Create only simple format (not multiple training formats)")
+    ] = False,
+) -> None:
+    """Create optimized training dataset from CodeSearchNet for code search tasks."""
+    logger.info("🚀 Starting optimized dataset creation command...")
+
+    # Convert output_dir to Path if provided
+    output_path = Path(output_dir) if output_dir else None
+
+    # Create the dataset
+    try:
+        metadata = create_optimized_dataset(
+            max_samples_per_lang=max_samples_per_lang,
+            min_doc_words=min_doc_words,
+            max_doc_words=max_doc_words,
+            min_code_chars=min_code_chars,
+            max_code_chars=max_code_chars,
+            output_dir=output_path,
+            create_multiple_formats=not simple_format,
+        )
+
+        logger.info("✅ Dataset creation completed successfully!")
+        logger.info(f"📁 Output directory: {metadata['files']['train']}")
+
+        # Print summary statistics
+        print("\n" + "=" * 60)
+        print("📊 DATASET CREATION SUMMARY")
+        print("=" * 60)
+        print(f"Total samples created: {metadata['total_samples']:,}")
+        print(f"Processing time: {metadata['processing_time']:.2f} seconds")
+        print("\nSplit distribution:")
+        print(f" • Train: {metadata['train_samples']:,} samples")
+        print(f" • Test: {metadata['test_samples']:,} samples")
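The module above also exposes `create_optimized_dataset` and `load_optimized_dataset` for programmatic use. A quick usage sketch, with example paths and sample counts:

```python
from pathlib import Path

from distiller.dataset import create_optimized_dataset, load_optimized_dataset

# Build a small dataset (1,000 samples per language) into a custom directory.
meta = create_optimized_dataset(max_samples_per_lang=1000, output_dir=Path("tmp/dataset"))
print(meta["train_samples"], meta["test_samples"])

# Later, load the train split back as a DataFrame with the columns
# language / query / code / text produced by _create_training_samples.
df = load_optimized_dataset(Path("tmp/dataset"), split="train")
print(df.head())
```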
max_samples_per_lang: Annotated[
|
| 603 |
+
int, typer.Option(help="Maximum samples per language")
|
| 604 |
+
] = DEFAULT_MAX_SAMPLES_PER_LANG,
|
| 605 |
+
min_doc_words: Annotated[int, typer.Option(help="Minimum words in documentation")] = DEFAULT_MIN_DOC_WORDS,
|
| 606 |
+
max_doc_words: Annotated[int, typer.Option(help="Maximum words in documentation")] = DEFAULT_MAX_DOC_WORDS,
|
| 607 |
+
min_code_chars: Annotated[int, typer.Option(help="Minimum characters in code")] = DEFAULT_MIN_CODE_CHARS,
|
| 608 |
+
max_code_chars: Annotated[int, typer.Option(help="Maximum characters in code")] = DEFAULT_MAX_CODE_CHARS,
|
| 609 |
+
output_dir: Annotated[str | None, typer.Option(help="Output directory for dataset")] = None,
|
| 610 |
+
simple_format: Annotated[
|
| 611 |
+
bool, typer.Option(help="Create only simple format (not multiple training formats)")
|
| 612 |
+
] = False,
|
| 613 |
+
) -> None:
|
| 614 |
+
"""Create optimized training dataset from CodeSearchNet for code search tasks."""
|
| 615 |
+
logger.info("🚀 Starting optimized dataset creation command...")
|
| 616 |
+
|
| 617 |
+
# Convert output_dir to Path if provided
|
| 618 |
+
output_path = Path(output_dir) if output_dir else None
|
| 619 |
+
|
| 620 |
+
# Create the dataset
|
| 621 |
+
try:
|
| 622 |
+
metadata = create_optimized_dataset(
|
| 623 |
+
max_samples_per_lang=max_samples_per_lang,
|
| 624 |
+
min_doc_words=min_doc_words,
|
| 625 |
+
max_doc_words=max_doc_words,
|
| 626 |
+
min_code_chars=min_code_chars,
|
| 627 |
+
max_code_chars=max_code_chars,
|
| 628 |
+
output_dir=output_path,
|
| 629 |
+
create_multiple_formats=not simple_format,
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
logger.info("✅ Dataset creation completed successfully!")
|
| 633 |
+
logger.info(f"📁 Output directory: {metadata['files']['train']}")
|
| 634 |
+
|
| 635 |
+
# Print summary statistics
|
| 636 |
+
print("\n" + "=" * 60)
|
| 637 |
+
print("📊 DATASET CREATION SUMMARY")
|
| 638 |
+
print("=" * 60)
|
| 639 |
+
print(f"Total samples created: {metadata['total_samples']:,}")
|
| 640 |
+
print(f"Processing time: {metadata['processing_time']:.2f} seconds")
|
| 641 |
+
print("\nSplit distribution:")
|
| 642 |
+
print(f" • Train: {metadata['train_samples']:,} samples")
|
| 643 |
+
print(f" • Test: {metadata['test_samples']:,} samples")
|
| 644 |
+
|
| 645 |
+
print("\nLanguage distribution:")
|
| 646 |
+
for lang, stats in metadata["language_stats"].items():
|
| 647 |
+
if "error" not in stats:
|
| 648 |
+
print(f" • {lang}: {stats['final_samples']:,} samples ({stats['quality_rate']:.1%} quality rate)")
|
| 649 |
+
|
| 650 |
+
print(f"\nDataset files saved to: {output_path or DATASET_OUTPUT_DIR}")
|
| 651 |
+
print("=" * 60)
|
| 652 |
+
|
| 653 |
+
except Exception as e:
|
| 654 |
+
logger.exception("❌ Dataset creation failed")
|
| 655 |
+
raise typer.Exit(1) from e
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
if __name__ == "__main__":
|
| 659 |
+
typer.run(main)
|
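For orientation, a minimal sketch of how the parquet splits written by `_save_datasets` above can be consumed downstream. It assumes the default `code_model2vec/dataset` output directory and the `text`/`language` columns used elsewhere in this commit; the paths are illustrative, not part of the module itself.

```python
from pathlib import Path

import pandas as pd

# Assumed default output location; pass --output-dir to the dataset command to change it.
dataset_dir = Path("code_model2vec/dataset")

# Each split is a snappy-compressed parquet file written by _save_datasets.
train_df = pd.read_parquet(dataset_dir / "train.parquet")
test_df = pd.read_parquet(dataset_dir / "test.parquet")

# "text" holds the formatted query + code pairs; "language" is the stratification
# key preserved by _create_stratified_splits.
print(train_df["language"].value_counts())
texts = train_df["text"].tolist()
```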
src/distiller/distill.py
CHANGED
|
@@ -28,13 +28,14 @@ import time
|
|
| 28 |
from pathlib import Path
|
| 29 |
from typing import Annotated, Any
|
| 30 |
|
|
|
|
| 31 |
import torch
|
| 32 |
import typer
|
| 33 |
from beam import function
|
| 34 |
-
from datasets import load_dataset
|
| 35 |
-
from model2vec.distill import distill
|
| 36 |
from sentence_transformers import SentenceTransformer
|
| 37 |
|
|
|
|
|
|
|
| 38 |
# Try to import flash_attn to check if it's available
|
| 39 |
from .beam_utils import (
|
| 40 |
BeamCheckpointManager,
|
|
@@ -145,25 +146,6 @@ def load_model_with_flash_attention(model_path: str, device: str = "auto") -> Se
|
|
| 145 |
# =============================================================================
|
| 146 |
|
| 147 |
|
| 148 |
-
def apply_local_patches() -> bool:
|
| 149 |
-
"""Apply patches locally without requiring Beam utilities."""
|
| 150 |
-
try:
|
| 151 |
-
try:
|
| 152 |
-
from .patch_utils import apply_all_patches
|
| 153 |
-
|
| 154 |
-
patches_applied = apply_all_patches()
|
| 155 |
-
logger.info(f"Successfully applied {patches_applied} patches via patch_utils")
|
| 156 |
-
return True
|
| 157 |
-
except ImportError:
|
| 158 |
-
logger.warning("patch_utils not available, trying direct patching")
|
| 159 |
-
|
| 160 |
-
return False
|
| 161 |
-
|
| 162 |
-
except Exception as e:
|
| 163 |
-
logger.warning(f"Failed to apply patches: {e}")
|
| 164 |
-
return False
|
| 165 |
-
|
| 166 |
-
|
| 167 |
def get_current_config_hash(enable_training: bool) -> str:
|
| 168 |
"""Generate a hash of current configuration parameters for checkpoint validation."""
|
| 169 |
import hashlib
|
|
@@ -217,22 +199,22 @@ def check_existing_final_model(teacher_name: str, enable_training: bool = False)
|
|
| 217 |
model_name = f"code_model2vec_{teacher_name}"
|
| 218 |
if enable_training:
|
| 219 |
model_name += "_fine_tuned"
|
| 220 |
-
|
| 221 |
|
| 222 |
-
if
|
| 223 |
# Check for essential model files
|
| 224 |
-
has_config = (
|
| 225 |
has_model_file = any(
|
| 226 |
[
|
| 227 |
-
(
|
| 228 |
-
(
|
| 229 |
-
(
|
| 230 |
]
|
| 231 |
)
|
| 232 |
|
| 233 |
if has_config and has_model_file:
|
| 234 |
logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
|
| 235 |
-
return str(
|
| 236 |
|
| 237 |
return None
|
| 238 |
|
|
@@ -427,11 +409,65 @@ def simple_distillation(
|
|
| 427 |
return None
|
| 428 |
|
| 429 |
|
|
|
|
|
|
|
|
|
| 430 |
def load_codesearchnet_dataset(
|
| 431 |
max_samples: int = 50000,
|
| 432 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
| 433 |
) -> list[str]:
|
| 434 |
"""Load and format the CodeSearchNet dataset for token frequency computation."""
|
|
|
|
|
|
|
| 435 |
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
|
| 436 |
logger.info(f"Limiting to {max_samples} samples for training efficiency")
|
| 437 |
logger.info(f"Languages: {', '.join(languages_config.all)}")
|
|
@@ -482,6 +518,8 @@ def load_codesearchnet_dataset(
|
|
| 482 |
|
| 483 |
try:
|
| 484 |
# Load training split for the specific language (same format as evaluate.py)
|
|
|
|
|
|
|
| 485 |
dataset = load_dataset(
|
| 486 |
codesearchnet_config.dataset_name,
|
| 487 |
language,
|
|
@@ -709,8 +747,33 @@ def compute_token_frequencies_for_sif(
|
|
| 709 |
logger.info("📊 Computing token frequencies for SIF weighting...")
|
| 710 |
|
| 711 |
try:
|
| 712 |
-
# Load
|
| 713 |
-
|
|
|
|
|
|
|
|
| 714 |
|
| 715 |
logger.info(f"📊 Computing frequencies on {len(dataset_texts)} texts...")
|
| 716 |
|
|
@@ -763,7 +826,6 @@ def apply_post_training_regularization(
|
|
| 763 |
"""
|
| 764 |
import json
|
| 765 |
|
| 766 |
-
import numpy as np
|
| 767 |
from sklearn.decomposition import PCA
|
| 768 |
|
| 769 |
logger.info("🔧 Starting post-training re-regularization (POTION Step 4)")
|
|
@@ -836,7 +898,7 @@ def apply_post_training_regularization(
|
|
| 836 |
final_embeddings = embeddings_pca.astype(np.float32)
|
| 837 |
|
| 838 |
# Create new model with updated embeddings
|
| 839 |
-
from model2vec.model import StaticModel
|
| 840 |
|
| 841 |
# Save tokenizer and config from original model
|
| 842 |
tokenizer = model.tokenizer
|
|
@@ -866,7 +928,6 @@ def tokenlearn_training(
|
|
| 866 |
3. Tokenlearn training
|
| 867 |
4. Post-training re-regularization (PCA + SIF weighting)
|
| 868 |
"""
|
| 869 |
-
import subprocess
|
| 870 |
from pathlib import Path
|
| 871 |
|
| 872 |
logger.info("🧪 Starting tokenlearn training (POTION approach)...")
|
|
@@ -914,6 +975,9 @@ def tokenlearn_training(
|
|
| 914 |
|
| 915 |
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 916 |
|
|
|
|
|
|
|
|
|
|
| 917 |
# Check if featurization already completed (checkpoint detection)
|
| 918 |
featurization_complete_marker = features_dir / ".featurization_complete"
|
| 919 |
if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
|
|
@@ -936,47 +1000,42 @@ def tokenlearn_training(
|
|
| 936 |
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 937 |
|
| 938 |
try:
|
| 939 |
-
# Use
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
"tokenlearn.featurize",
|
| 944 |
-
"--model-name",
|
| 945 |
-
str(teacher_model_name),
|
| 946 |
-
"--output-dir",
|
| 947 |
-
str(features_dir),
|
| 948 |
-
"--dataset-path",
|
| 949 |
-
str(distillation_config.tokenlearn_dataset),
|
| 950 |
-
"--dataset-name",
|
| 951 |
-
str(distillation_config.tokenlearn_dataset_name),
|
| 952 |
-
"--dataset-split",
|
| 953 |
-
"train",
|
| 954 |
-
"--key",
|
| 955 |
-
str(distillation_config.tokenlearn_text_key), # Use configured text field
|
| 956 |
-
"--batch-size",
|
| 957 |
-
"1024", # Optimized batch size for A100-40G
|
| 958 |
-
]
|
| 959 |
|
| 960 |
logger.info("🔄 Running tokenlearn featurization...")
|
| 961 |
-
logger.info(
|
| 962 |
-
|
| 963 |
-
)
|
| 964 |
-
logger.info(f"📝 Text field: {distillation_config.tokenlearn_text_key}")
|
| 965 |
-
logger.info(f"Command: {' '.join(featurize_cmd)}")
|
| 966 |
-
print(f"\n🔄 Executing: {' '.join(featurize_cmd)}\n")
|
| 967 |
-
|
| 968 |
-
result = subprocess.run( # noqa: S603
|
| 969 |
-
featurize_cmd,
|
| 970 |
-
text=True,
|
| 971 |
-
timeout=distillation_config.tokenlearn_timeout_featurize,
|
| 972 |
-
check=False,
|
| 973 |
-
)
|
| 974 |
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
|
|
|
|
|
|
|
| 980 |
|
| 981 |
logger.info("✅ Featurization completed successfully")
|
| 982 |
|
|
@@ -1025,65 +1084,74 @@ def tokenlearn_training(
|
|
| 1025 |
logger.info("🔄 No valid training checkpoint found - starting training...")
|
| 1026 |
|
| 1027 |
try:
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
"tokenlearn.train",
|
| 1032 |
-
"--model-name",
|
| 1033 |
-
str(teacher_model_name),
|
| 1034 |
-
"--data-path",
|
| 1035 |
-
str(features_dir),
|
| 1036 |
-
"--save-path",
|
| 1037 |
-
str(trained_dir),
|
| 1038 |
-
]
|
| 1039 |
|
| 1040 |
-
|
| 1041 |
-
logger.info(
|
| 1042 |
-
|
| 1043 |
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1051 |
|
| 1052 |
-
|
| 1053 |
-
|
|
|
|
|
|
|
|
|
|
| 1054 |
|
| 1055 |
-
|
| 1056 |
-
|
| 1057 |
-
|
| 1058 |
-
if result.stdout:
|
| 1059 |
-
logger.info(f"stdout: {result.stdout}")
|
| 1060 |
|
| 1061 |
-
#
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
|
|
|
|
|
|
|
|
| 1066 |
|
| 1067 |
# Create training marker to indicate we tried but failed
|
| 1068 |
training_fallback_marker = trained_dir / ".training_fallback"
|
| 1069 |
training_fallback_marker.touch()
|
| 1070 |
|
| 1071 |
-
logger.
|
| 1072 |
-
msg = f"
|
| 1073 |
-
raise RuntimeError(msg)
|
| 1074 |
-
logger.error("💥 Tokenlearn training failed with different error")
|
| 1075 |
-
msg = f"Tokenlearn training failed with return code: {result.returncode}"
|
| 1076 |
-
raise RuntimeError(msg)
|
| 1077 |
-
logger.info("✅ Tokenlearn training completed successfully")
|
| 1078 |
-
|
| 1079 |
-
# Create checkpoint marker to indicate training is complete
|
| 1080 |
-
training_complete_marker.touch()
|
| 1081 |
-
logger.info(f"💾 Created training checkpoint: {training_complete_marker}")
|
| 1082 |
|
| 1083 |
except Exception as e:
|
| 1084 |
-
logger.
|
| 1085 |
-
logger.exception("💥
|
| 1086 |
-
msg = f"
|
| 1087 |
raise RuntimeError(msg) from e
|
| 1088 |
|
| 1089 |
# Step 4: Load the trained model and apply post-training re-regularization
|
|
@@ -1098,7 +1166,7 @@ def tokenlearn_training(
|
|
| 1098 |
raise RuntimeError(msg)
|
| 1099 |
|
| 1100 |
try:
|
| 1101 |
-
from model2vec.model import StaticModel
|
| 1102 |
|
| 1103 |
# Load the trained model from tokenlearn
|
| 1104 |
trained_model_path = trained_dir / "model"
|
|
@@ -1213,12 +1281,13 @@ def distill_single_teacher(
|
|
| 1213 |
existing_final = check_existing_final_model(teacher_name, enable_training)
|
| 1214 |
if existing_final:
|
| 1215 |
logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
|
|
|
|
| 1216 |
return {
|
| 1217 |
"teacher_model": teacher_model,
|
| 1218 |
"teacher_name": teacher_name,
|
| 1219 |
"status": "skipped_existing_final",
|
| 1220 |
"final_path": existing_final,
|
| 1221 |
-
"distillation_time":
|
| 1222 |
}
|
| 1223 |
|
| 1224 |
# Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
|
|
@@ -1236,7 +1305,7 @@ def distill_single_teacher(
|
|
| 1236 |
logger.info(f"✅ Found existing base model: {teacher_name}")
|
| 1237 |
if enable_training:
|
| 1238 |
# Load base model for training
|
| 1239 |
-
from model2vec.model import StaticModel
|
| 1240 |
|
| 1241 |
base_model = StaticModel.from_pretrained(existing_base)
|
| 1242 |
elif use_beam_utilities:
|
|
@@ -1244,7 +1313,7 @@ def distill_single_teacher(
|
|
| 1244 |
if synced:
|
| 1245 |
existing_base = str(base_dir)
|
| 1246 |
if enable_training:
|
| 1247 |
-
from model2vec.model import StaticModel
|
| 1248 |
|
| 1249 |
base_model = StaticModel.from_pretrained(existing_base)
|
| 1250 |
|
|
@@ -1263,11 +1332,13 @@ def distill_single_teacher(
|
|
| 1263 |
base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
|
| 1264 |
|
| 1265 |
if base_model is None:
|
|
|
|
| 1266 |
return {
|
| 1267 |
"teacher_model": teacher_model,
|
| 1268 |
"teacher_name": teacher_name,
|
| 1269 |
"status": "failed_base_distillation",
|
| 1270 |
"error": "Simple distillation failed",
|
|
|
|
| 1271 |
}
|
| 1272 |
|
| 1273 |
# Sync base model and checkpoints to Beam
|
|
@@ -1280,71 +1351,74 @@ def distill_single_teacher(
|
|
| 1280 |
|
| 1281 |
existing_base = str(base_dir)
|
| 1282 |
|
| 1283 |
-
|
| 1284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1285 |
# Perform tokenlearn training (POTION approach)
|
| 1286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1287 |
|
| 1288 |
-
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
|
| 1292 |
-
|
| 1293 |
-
# Perform tokenlearn training (POTION approach)
|
| 1294 |
-
final_model = tokenlearn_training(
|
| 1295 |
-
base_model,
|
| 1296 |
-
teacher_st_model,
|
| 1297 |
-
checkpoint_mgr,
|
| 1298 |
-
skip_post_training_regularization=distillation_config.skip_post_training_regularization,
|
| 1299 |
-
)
|
| 1300 |
|
| 1301 |
-
|
| 1302 |
-
|
| 1303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
|
| 1308 |
-
if checkpoint_mgr:
|
| 1309 |
-
sync_checkpoints_to_beam(
|
| 1310 |
-
VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
|
| 1311 |
-
)
|
| 1312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1313 |
del teacher_st_model
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
except RuntimeError as e:
|
| 1318 |
-
# Training failed - clean up and return failure
|
| 1319 |
-
logger.exception(f"❌ Training failed for {teacher_name}")
|
| 1320 |
-
|
| 1321 |
-
# Clean up teacher model if it was loaded
|
| 1322 |
-
if "teacher_st_model" in locals():
|
| 1323 |
-
del teacher_st_model
|
| 1324 |
-
if torch.cuda.is_available():
|
| 1325 |
-
torch.cuda.empty_cache()
|
| 1326 |
-
|
| 1327 |
-
return {
|
| 1328 |
-
"teacher_model": teacher_model,
|
| 1329 |
-
"teacher_name": teacher_name,
|
| 1330 |
-
"status": "failed_training",
|
| 1331 |
-
"error": f"Training failed: {e!s}",
|
| 1332 |
-
"base_path": existing_base, # Base model was created successfully
|
| 1333 |
-
}
|
| 1334 |
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
|
| 1343 |
-
|
| 1344 |
-
}
|
| 1345 |
|
| 1346 |
-
|
|
|
|
|
|
|
|
| 1347 |
|
|
|
|
| 1348 |
return {
|
| 1349 |
"teacher_model": teacher_model,
|
| 1350 |
"teacher_name": teacher_name,
|
|
@@ -1357,11 +1431,13 @@ def distill_single_teacher(
|
|
| 1357 |
|
| 1358 |
except Exception as e:
|
| 1359 |
logger.exception(f"❌ Failed to process {teacher_model}")
|
|
|
|
| 1360 |
return {
|
| 1361 |
"teacher_model": teacher_model,
|
| 1362 |
"teacher_name": teacher_name,
|
| 1363 |
"status": "failed",
|
| 1364 |
"error": str(e),
|
|
|
|
| 1365 |
}
|
| 1366 |
|
| 1367 |
|
|
@@ -1382,13 +1458,6 @@ def run_local_distillation(
|
|
| 1382 |
if teacher_models is None:
|
| 1383 |
teacher_models = DEFAULT_TEACHER_MODELS
|
| 1384 |
|
| 1385 |
-
# Apply patches
|
| 1386 |
-
patch_success = apply_local_patches()
|
| 1387 |
-
if patch_success:
|
| 1388 |
-
logger.info("✅ Successfully applied patches")
|
| 1389 |
-
else:
|
| 1390 |
-
logger.warning("⚠️ Failed to apply patches - some models may fail")
|
| 1391 |
-
|
| 1392 |
results = {}
|
| 1393 |
successful_models = []
|
| 1394 |
|
|
@@ -1468,13 +1537,6 @@ def _beam_distill_internal(
|
|
| 1468 |
clear_cache: bool = False,
|
| 1469 |
) -> dict[str, Any]:
|
| 1470 |
"""Shared internal implementation for beam distillation."""
|
| 1471 |
-
# Apply patches
|
| 1472 |
-
patch_success = apply_local_patches()
|
| 1473 |
-
if patch_success:
|
| 1474 |
-
logger.info("✅ Successfully applied patches")
|
| 1475 |
-
else:
|
| 1476 |
-
logger.warning("⚠️ Failed to apply patches - some models may fail")
|
| 1477 |
-
|
| 1478 |
if teacher_models is None:
|
| 1479 |
teacher_models = DEFAULT_TEACHER_MODELS
|
| 1480 |
|
|
@@ -1647,6 +1709,16 @@ def main(
|
|
| 1647 |
skip_ptr: Annotated[
|
| 1648 |
bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
|
| 1649 |
] = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1650 |
) -> None:
|
| 1651 |
"""Unified distillation command with optional training."""
|
| 1652 |
logger.info("🚀 Starting unified Model2Vec distillation workflow")
|
|
@@ -1656,6 +1728,13 @@ def main(
|
|
| 1656 |
if skip_ptr and train:
|
| 1657 |
logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
|
| 1658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1659 |
logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
|
| 1660 |
logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
|
| 1661 |
|
|
@@ -1894,7 +1973,7 @@ def salesforce_model_distillation(
|
|
| 1894 |
logger.info("✅ Successfully loaded with SentenceTransformer method")
|
| 1895 |
|
| 1896 |
# Now use Model2Vec's distill_from_model function directly
|
| 1897 |
-
from model2vec.distill.distillation import distill_from_model
|
| 1898 |
|
| 1899 |
distilled_model = distill_from_model(
|
| 1900 |
model=model,
|
|
@@ -2004,7 +2083,7 @@ def baai_bge_model_distillation(
|
|
| 2004 |
return None
|
| 2005 |
|
| 2006 |
# Now use Model2Vec's distill_from_model function directly
|
| 2007 |
-
from model2vec.distill.distillation import distill_from_model
|
| 2008 |
|
| 2009 |
distilled_model = distill_from_model(
|
| 2010 |
model=model,
|
|
@@ -2090,5 +2169,77 @@ def verify_training_output(trained_dir: Path) -> bool:
|
|
| 2090 |
return False
|
| 2091 |
|
| 2092 |
|
|
|
|
|
|
|
|
| 2093 |
if __name__ == "__main__":
|
| 2094 |
typer.run(main)
|
|
|
|
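The full updated listing of `distill.py` follows below. Before it, a condensed sketch of the dataset hand-off the new code performs for tokenlearn featurization: the optimized parquet split is re-serialized as JSON lines and streamed back through `datasets.load_dataset`. File and directory names here are illustrative; the actual logic lives in `_prepare_custom_dataset_for_tokenlearn` further down.

```python
import json
from pathlib import Path

import pandas as pd
from datasets import load_dataset

# Paths mirror the commit's defaults but are assumptions for this sketch.
train_df = pd.read_parquet("code_model2vec/dataset/train.parquet")
train_json_path = Path("code_model2vec/tokenlearn/custom_dataset/train.json")
train_json_path.parent.mkdir(parents=True, exist_ok=True)

# One {"text": ...} object per line, so datasets can read it as a JSON-lines file.
with train_json_path.open("w") as f:
    for text in train_df["text"]:
        json.dump({"text": text}, f)
        f.write("\n")

# Featurization then consumes the file as a streaming dataset.
dataset = load_dataset("json", data_files=str(train_json_path), split="train", streaming=True)
```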
| 28 |
from pathlib import Path
|
| 29 |
from typing import Annotated, Any
|
| 30 |
|
| 31 |
+
import numpy as np
|
| 32 |
import torch
|
| 33 |
import typer
|
| 34 |
from beam import function
|
|
|
|
|
|
|
| 35 |
from sentence_transformers import SentenceTransformer
|
| 36 |
|
| 37 |
+
from distiller.model2vec.distill import distill
|
| 38 |
+
|
| 39 |
# Try to import flash_attn to check if it's available
|
| 40 |
from .beam_utils import (
|
| 41 |
BeamCheckpointManager,
|
|
|
|
| 146 |
# =============================================================================
|
| 147 |
|
| 148 |
|
|
|
|
|
|
|
|
| 149 |
def get_current_config_hash(enable_training: bool) -> str:
|
| 150 |
"""Generate a hash of current configuration parameters for checkpoint validation."""
|
| 151 |
import hashlib
|
|
|
|
| 199 |
model_name = f"code_model2vec_{teacher_name}"
|
| 200 |
if enable_training:
|
| 201 |
model_name += "_fine_tuned"
|
| 202 |
+
final_path = final_dir / model_name
|
| 203 |
|
| 204 |
+
if final_path.exists():
|
| 205 |
# Check for essential model files
|
| 206 |
+
has_config = (final_path / "config.json").exists()
|
| 207 |
has_model_file = any(
|
| 208 |
[
|
| 209 |
+
(final_path / "model.safetensors").exists(),
|
| 210 |
+
(final_path / "model.bin").exists(),
|
| 211 |
+
(final_path / "pytorch_model.bin").exists(),
|
| 212 |
]
|
| 213 |
)
|
| 214 |
|
| 215 |
if has_config and has_model_file:
|
| 216 |
logger.info(f"✅ Found existing final model: {teacher_name}{'_fine_tuned' if enable_training else ''}")
|
| 217 |
+
return str(final_path)
|
| 218 |
|
| 219 |
return None
|
| 220 |
|
|
|
|
| 409 |
return None
|
| 410 |
|
| 411 |
|
| 412 |
+
def load_optimized_dataset(
|
| 413 |
+
max_samples: int = 50000,
|
| 414 |
+
checkpoint_manager: BeamCheckpointManager | None = None,
|
| 415 |
+
dataset_path: str | None = None,
|
| 416 |
+
) -> list[str]:
|
| 417 |
+
"""Load our pre-created optimized dataset for tokenlearn training."""
|
| 418 |
+
from .dataset import DATASET_OUTPUT_DIR
|
| 419 |
+
from .dataset import load_optimized_dataset as load_dataset_func
|
| 420 |
+
|
| 421 |
+
# Use configuration if not provided as parameter
|
| 422 |
+
if dataset_path is None:
|
| 423 |
+
dataset_path = distillation_config.custom_dataset_path
|
| 424 |
+
|
| 425 |
+
dataset_dir = Path(dataset_path) if dataset_path else DATASET_OUTPUT_DIR
|
| 426 |
+
|
| 427 |
+
logger.info(f"🎯 Loading optimized dataset from {dataset_dir}")
|
| 428 |
+
logger.info(f"📊 Target samples: {max_samples}")
|
| 429 |
+
|
| 430 |
+
try:
|
| 431 |
+
# Load the training split of our optimized dataset
|
| 432 |
+
df = load_dataset_func(output_dir=dataset_dir, split="train")
|
| 433 |
+
|
| 434 |
+
# Extract the text column (which contains our formatted query + code)
|
| 435 |
+
texts = df["text"].tolist()
|
| 436 |
+
|
| 437 |
+
# Shuffle for better training distribution
|
| 438 |
+
import random
|
| 439 |
+
|
| 440 |
+
random.seed(42)
|
| 441 |
+
random.shuffle(texts)
|
| 442 |
+
|
| 443 |
+
# Limit to max_samples
|
| 444 |
+
if len(texts) > max_samples:
|
| 445 |
+
texts = texts[:max_samples]
|
| 446 |
+
|
| 447 |
+
logger.info(f"✅ Loaded {len(texts)} optimized training samples")
|
| 448 |
+
|
| 449 |
+
# Log language distribution
|
| 450 |
+
languages = df["language"].value_counts()
|
| 451 |
+
logger.info("📊 Language distribution:")
|
| 452 |
+
for lang, count in languages.items():
|
| 453 |
+
percentage = (count / len(df)) * 100
|
| 454 |
+
logger.info(f" {lang}: {count} samples ({percentage:.1f}%)")
|
| 455 |
+
|
| 456 |
+
return texts
|
| 457 |
+
|
| 458 |
+
except Exception as e:
|
| 459 |
+
logger.warning(f"⚠️ Failed to load optimized dataset: {e}")
|
| 460 |
+
logger.info("🔄 Falling back to original CodeSearchNet loading...")
|
| 461 |
+
return load_codesearchnet_dataset(max_samples, checkpoint_manager)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
def load_codesearchnet_dataset(
|
| 465 |
max_samples: int = 50000,
|
| 466 |
checkpoint_manager: BeamCheckpointManager | None = None,
|
| 467 |
) -> list[str]:
|
| 468 |
"""Load and format the CodeSearchNet dataset for token frequency computation."""
|
| 469 |
+
from datasets import load_dataset
|
| 470 |
+
|
| 471 |
logger.info(f"Loading CodeSearchNet dataset from {codesearchnet_config.dataset_name}")
|
| 472 |
logger.info(f"Limiting to {max_samples} samples for training efficiency")
|
| 473 |
logger.info(f"Languages: {', '.join(languages_config.all)}")
|
|
|
|
| 518 |
|
| 519 |
try:
|
| 520 |
# Load training split for the specific language (same format as evaluate.py)
|
| 521 |
+
from datasets import load_dataset
|
| 522 |
+
|
| 523 |
dataset = load_dataset(
|
| 524 |
codesearchnet_config.dataset_name,
|
| 525 |
language,
|
|
|
|
| 747 |
logger.info("📊 Computing token frequencies for SIF weighting...")
|
| 748 |
|
| 749 |
try:
|
| 750 |
+
# Load dataset to compute frequencies (limited sample for efficiency)
|
| 751 |
+
if distillation_config.use_optimized_dataset:
|
| 752 |
+
# Use the custom optimized dataset
|
| 753 |
+
from .dataset import load_optimized_dataset as load_custom_dataset
|
| 754 |
+
|
| 755 |
+
custom_dataset_dir = (
|
| 756 |
+
Path(distillation_config.custom_dataset_path)
|
| 757 |
+
if distillation_config.custom_dataset_path
|
| 758 |
+
else Path("code_model2vec/dataset")
|
| 759 |
+
)
|
| 760 |
+
|
| 761 |
+
if custom_dataset_dir.exists() and (custom_dataset_dir / "train.parquet").exists():
|
| 762 |
+
train_df = load_custom_dataset(output_dir=custom_dataset_dir, split="train")
|
| 763 |
+
# Sample a subset for frequency computation
|
| 764 |
+
sample_size = min(10000, len(train_df))
|
| 765 |
+
train_df_sample = train_df.sample(n=sample_size, random_state=42)
|
| 766 |
+
dataset_texts = train_df_sample["text"].tolist()
|
| 767 |
+
logger.info(f"📊 Using {len(dataset_texts)} samples from custom optimized dataset")
|
| 768 |
+
else:
|
| 769 |
+
# Fallback to original dataset loading
|
| 770 |
+
dataset_texts = load_codesearchnet_dataset(max_samples=10000)
|
| 771 |
+
logger.info(
|
| 772 |
+
f"📊 Custom dataset not found, using original CodeSearchNet with {len(dataset_texts)} texts"
|
| 773 |
+
)
|
| 774 |
+
else:
|
| 775 |
+
dataset_texts = load_codesearchnet_dataset(max_samples=10000)
|
| 776 |
+
logger.info(f"📊 Using original CodeSearchNet with {len(dataset_texts)} texts")
|
| 777 |
|
| 778 |
logger.info(f"📊 Computing frequencies on {len(dataset_texts)} texts...")
|
| 779 |
|
|
|
|
| 826 |
"""
|
| 827 |
import json
|
| 828 |
|
|
|
|
| 829 |
from sklearn.decomposition import PCA
|
| 830 |
|
| 831 |
logger.info("🔧 Starting post-training re-regularization (POTION Step 4)")
|
|
|
|
| 898 |
final_embeddings = embeddings_pca.astype(np.float32)
|
| 899 |
|
| 900 |
# Create new model with updated embeddings
|
| 901 |
+
from distiller.model2vec.model import StaticModel
|
| 902 |
|
| 903 |
# Save tokenizer and config from original model
|
| 904 |
tokenizer = model.tokenizer
|
|
|
|
| 928 |
3. Tokenlearn training
|
| 929 |
4. Post-training re-regularization (PCA + SIF weighting)
|
| 930 |
"""
|
|
|
|
| 931 |
from pathlib import Path
|
| 932 |
|
| 933 |
logger.info("🧪 Starting tokenlearn training (POTION approach)...")
|
|
|
|
| 975 |
|
| 976 |
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 977 |
|
| 978 |
+
# Prepare dataset for tokenlearn featurization
|
| 979 |
+
dataset_path, dataset_name, text_key = _prepare_tokenlearn_dataset(persistent_tokenlearn_dir)
|
| 980 |
+
|
| 981 |
# Check if featurization already completed (checkpoint detection)
|
| 982 |
featurization_complete_marker = features_dir / ".featurization_complete"
|
| 983 |
if featurization_complete_marker.exists() and verify_featurization_output(features_dir):
|
|
|
|
| 1000 |
logger.info(f"📊 Using teacher model: {teacher_model_name}")
|
| 1001 |
|
| 1002 |
try:
|
| 1003 |
+
# Use direct function call instead of subprocess
|
| 1004 |
+
from datasets import load_dataset
|
| 1005 |
+
|
| 1006 |
+
from distiller.tokenlearn.featurize import featurize
|
|
|
|
|
|
|
|
| 1007 |
|
| 1008 |
logger.info("🔄 Running tokenlearn featurization...")
|
| 1009 |
+
logger.info(f"📊 Dataset: {dataset_path} (config: {dataset_name})")
|
| 1010 |
+
logger.info(f"📝 Text field: {text_key}")
|
|
|
|
|
|
|
|
| 1011 |
|
| 1012 |
+
# Load the dataset
|
| 1013 |
+
if dataset_name is None:
|
| 1014 |
+
# For local JSON files, don't pass name parameter
|
| 1015 |
+
dataset = load_dataset(
|
| 1016 |
+
"json",
|
| 1017 |
+
data_files=dataset_path,
|
| 1018 |
+
split="train",
|
| 1019 |
+
streaming=True,
|
| 1020 |
+
)
|
| 1021 |
+
else:
|
| 1022 |
+
# For remote datasets with specific configurations
|
| 1023 |
+
dataset = load_dataset(
|
| 1024 |
+
dataset_path,
|
| 1025 |
+
name=dataset_name,
|
| 1026 |
+
split="train",
|
| 1027 |
+
streaming=True,
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
# Call featurization function directly
|
| 1031 |
+
featurize(
|
| 1032 |
+
dataset=iter(dataset),
|
| 1033 |
+
model=teacher_model,
|
| 1034 |
+
output_dir=str(features_dir),
|
| 1035 |
+
max_means=50000, # IMPROVEMENT: Limit means to prevent overfitting
|
| 1036 |
+
batch_size=512, # IMPROVEMENT: Smaller batch for better gradients
|
| 1037 |
+
text_key=text_key,
|
| 1038 |
+
)
|
| 1039 |
|
| 1040 |
logger.info("✅ Featurization completed successfully")
|
| 1041 |
|
|
|
|
| 1084 |
logger.info("🔄 No valid training checkpoint found - starting training...")
|
| 1085 |
|
| 1086 |
try:
|
| 1087 |
+
# Use direct function call instead of subprocess
|
| 1088 |
+
from distiller.tokenlearn.train import train_model
|
| 1089 |
+
from distiller.tokenlearn.utils import collect_means_and_texts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1090 |
|
| 1091 |
+
# IMPROVED APPROACH: Try optimized parameters first
|
| 1092 |
+
logger.info("🚀 Attempting IMPROVED tokenlearn training with optimized parameters...")
|
| 1093 |
+
logger.info("📊 Using smaller vocabulary and conservative PCA to prevent overfitting")
|
| 1094 |
|
| 1095 |
+
# Collect training data from features directory
|
| 1096 |
+
paths = sorted(features_dir.glob("*.json"))
|
| 1097 |
+
train_txt, train_vec = collect_means_and_texts(paths)
|
| 1098 |
+
|
| 1099 |
+
logger.info(f"📊 Collected {len(train_txt)} texts and {train_vec.shape[0]} vectors for training")
|
| 1100 |
+
|
| 1101 |
+
try:
|
| 1102 |
+
# Try improved parameters first
|
| 1103 |
+
trained_model = train_model(
|
| 1104 |
+
model_name=str(teacher_model_name),
|
| 1105 |
+
train_txt=train_txt,
|
| 1106 |
+
train_vec=train_vec,
|
| 1107 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 1108 |
+
vocab_size=25000, # IMPROVEMENT: Smaller vocabulary to prevent overfitting
|
| 1109 |
+
pca_dims=256, # IMPROVEMENT: Conservative PCA dimensions
|
| 1110 |
+
)
|
| 1111 |
|
| 1112 |
+
# Save the trained model
|
| 1113 |
+
trained_model.save_pretrained(str(trained_dir))
|
| 1114 |
+
logger.info("✅ IMPROVED tokenlearn training completed successfully")
|
| 1115 |
+
training_complete_marker.touch()
|
| 1116 |
+
logger.info(f"💾 Created improved training checkpoint: {training_complete_marker}")
|
| 1117 |
|
| 1118 |
+
except Exception as e:
|
| 1119 |
+
logger.warning(f"⚠️ Improved training failed: {e}")
|
| 1120 |
+
logger.info("🔄 Falling back to CONSERVATIVE tokenlearn training...")
|
|
|
|
|
|
|
| 1121 |
|
| 1122 |
+
# FALLBACK: Ultra-conservative training approach
|
| 1123 |
+
try:
|
| 1124 |
+
trained_model = train_model(
|
| 1125 |
+
model_name=str(teacher_model_name),
|
| 1126 |
+
train_txt=train_txt,
|
| 1127 |
+
train_vec=train_vec,
|
| 1128 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 1129 |
+
vocab_size=15000, # FALLBACK: Even smaller vocabulary
|
| 1130 |
+
pca_dims=128, # FALLBACK: Smaller PCA dimensions
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
# Save the trained model
|
| 1134 |
+
trained_model.save_pretrained(str(trained_dir))
|
| 1135 |
+
logger.info("✅ Conservative tokenlearn training completed successfully")
|
| 1136 |
+
training_complete_marker.touch()
|
| 1137 |
+
logger.info(f"💾 Created conservative training checkpoint: {training_complete_marker}")
|
| 1138 |
+
|
| 1139 |
+
except Exception as e2:
|
| 1140 |
+
logger.exception("❌ Conservative tokenlearn training also failed")
|
| 1141 |
+
logger.exception("💥 All training approaches failed - check output above for details")
|
| 1142 |
|
| 1143 |
# Create training marker to indicate we tried but failed
|
| 1144 |
training_fallback_marker = trained_dir / ".training_fallback"
|
| 1145 |
training_fallback_marker.touch()
|
| 1146 |
|
| 1147 |
+
logger.exception("💥 Tokenlearn training failed completely")
|
| 1148 |
+
msg = f"All tokenlearn training approaches failed: {e2}"
|
| 1149 |
+
raise RuntimeError(msg) from e2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1150 |
|
| 1151 |
except Exception as e:
|
| 1152 |
+
logger.warning("💥 All tokenlearn training approaches failed")
|
| 1153 |
+
logger.exception("💥 All training approaches failed completely - cannot proceed")
|
| 1154 |
+
msg = f"All training approaches failed: {e}"
|
| 1155 |
raise RuntimeError(msg) from e
|
| 1156 |
|
| 1157 |
# Step 4: Load the trained model and apply post-training re-regularization
|
|
|
|
| 1166 |
raise RuntimeError(msg)
|
| 1167 |
|
| 1168 |
try:
|
| 1169 |
+
from distiller.model2vec.model import StaticModel
|
| 1170 |
|
| 1171 |
# Load the trained model from tokenlearn
|
| 1172 |
trained_model_path = trained_dir / "model"
|
|
|
|
| 1281 |
existing_final = check_existing_final_model(teacher_name, enable_training)
|
| 1282 |
if existing_final:
|
| 1283 |
logger.info(f"✅ Final model already exists: {teacher_name}{'_fine_tuned' if enable_training else ''}")
|
| 1284 |
+
total_time = time.time() - start_time
|
| 1285 |
return {
|
| 1286 |
"teacher_model": teacher_model,
|
| 1287 |
"teacher_name": teacher_name,
|
| 1288 |
"status": "skipped_existing_final",
|
| 1289 |
"final_path": existing_final,
|
| 1290 |
+
"distillation_time": total_time,
|
| 1291 |
}
|
| 1292 |
|
| 1293 |
# Step 1.5: Sync existing checkpoints from Beam if using Beam utilities
|
|
|
|
| 1305 |
logger.info(f"✅ Found existing base model: {teacher_name}")
|
| 1306 |
if enable_training:
|
| 1307 |
# Load base model for training
|
| 1308 |
+
from distiller.model2vec.model import StaticModel
|
| 1309 |
|
| 1310 |
base_model = StaticModel.from_pretrained(existing_base)
|
| 1311 |
elif use_beam_utilities:
|
|
|
|
| 1313 |
if synced:
|
| 1314 |
existing_base = str(base_dir)
|
| 1315 |
if enable_training:
|
| 1316 |
+
from distiller.model2vec.model import StaticModel
|
| 1317 |
|
| 1318 |
base_model = StaticModel.from_pretrained(existing_base)
|
| 1319 |
|
|
|
|
| 1332 |
base_model = simple_distillation(teacher_model, str(base_dir), pca_dims)
|
| 1333 |
|
| 1334 |
if base_model is None:
|
| 1335 |
+
total_time = time.time() - start_time
|
| 1336 |
return {
|
| 1337 |
"teacher_model": teacher_model,
|
| 1338 |
"teacher_name": teacher_name,
|
| 1339 |
"status": "failed_base_distillation",
|
| 1340 |
"error": "Simple distillation failed",
|
| 1341 |
+
"distillation_time": total_time,
|
| 1342 |
}
|
| 1343 |
|
| 1344 |
# Sync base model and checkpoints to Beam
|
|
|
|
| 1351 |
|
| 1352 |
existing_base = str(base_dir)
|
| 1353 |
|
| 1354 |
+
# Step 3: Handle final model creation
|
| 1355 |
+
if enable_training and base_model is not None:
|
| 1356 |
+
# Perform tokenlearn training (POTION approach)
|
| 1357 |
+
logger.info(f"🧪 Starting tokenlearn training for {teacher_name}")
|
| 1358 |
+
|
| 1359 |
+
try:
|
| 1360 |
+
# Load teacher model for training
|
| 1361 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 1362 |
+
teacher_st_model = load_model_with_flash_attention(teacher_model, device)
|
| 1363 |
+
|
| 1364 |
# Perform tokenlearn training (POTION approach)
|
| 1365 |
+
final_model = tokenlearn_training(
|
| 1366 |
+
base_model,
|
| 1367 |
+
teacher_st_model,
|
| 1368 |
+
checkpoint_mgr,
|
| 1369 |
+
skip_post_training_regularization=distillation_config.skip_post_training_regularization,
|
| 1370 |
+
)
|
| 1371 |
|
| 1372 |
+
# Save final model
|
| 1373 |
+
final_dir.mkdir(parents=True, exist_ok=True)
|
| 1374 |
+
final_model.save_pretrained(str(final_dir))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1375 |
|
| 1376 |
+
# Sync final model and training checkpoints to Beam
|
| 1377 |
+
if use_beam_utilities:
|
| 1378 |
+
sync_model_to_beam(f"{teacher_name}_final", str(final_dir), use_beam_utilities)
|
| 1379 |
+
if checkpoint_mgr:
|
| 1380 |
+
sync_checkpoints_to_beam(
|
| 1381 |
+
VOLUME_CONFIG.name, f"training_{teacher_name}", directories.checkpoints
|
| 1382 |
+
)
|
| 1383 |
|
| 1384 |
+
del teacher_st_model
|
| 1385 |
+
if torch.cuda.is_available():
|
| 1386 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1387 |
|
| 1388 |
+
except RuntimeError as e:
|
| 1389 |
+
# Training failed - clean up and return failure
|
| 1390 |
+
logger.exception(f"❌ Training failed for {teacher_name}")
|
| 1391 |
+
|
| 1392 |
+
# Clean up teacher model if it was loaded
|
| 1393 |
+
if "teacher_st_model" in locals():
|
| 1394 |
del teacher_st_model
|
| 1395 |
+
if torch.cuda.is_available():
|
| 1396 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
| 1397 |
|
| 1398 |
+
total_time = time.time() - start_time
|
| 1399 |
+
return {
|
| 1400 |
+
"teacher_model": teacher_model,
|
| 1401 |
+
"teacher_name": teacher_name,
|
| 1402 |
+
"status": "failed_training",
|
| 1403 |
+
"error": f"Training failed: {e!s}",
|
| 1404 |
+
"base_path": existing_base, # Base model was created successfully
|
| 1405 |
+
"distillation_time": total_time,
|
| 1406 |
+
}
|
|
|
|
| 1407 |
|
| 1408 |
+
else:
|
| 1409 |
+
# Copy base to final (no training)
|
| 1410 |
+
logger.info(f"📁 Copying base to final for {teacher_name}")
|
| 1411 |
+
if not copy_base_to_final(teacher_name, enable_training):
|
| 1412 |
+
total_time = time.time() - start_time
|
| 1413 |
+
return {
|
| 1414 |
+
"teacher_model": teacher_model,
|
| 1415 |
+
"teacher_name": teacher_name,
|
| 1416 |
+
"status": "failed_copy_to_final",
|
| 1417 |
+
"error": "Failed to copy base to final",
|
| 1418 |
+
"distillation_time": total_time,
|
| 1419 |
+
}
|
| 1420 |
|
| 1421 |
+
total_time = time.time() - start_time
|
| 1422 |
return {
|
| 1423 |
"teacher_model": teacher_model,
|
| 1424 |
"teacher_name": teacher_name,
|
|
|
|
| 1431 |
|
| 1432 |
except Exception as e:
|
| 1433 |
logger.exception(f"❌ Failed to process {teacher_model}")
|
| 1434 |
+
total_time = time.time() - start_time
|
| 1435 |
return {
|
| 1436 |
"teacher_model": teacher_model,
|
| 1437 |
"teacher_name": teacher_name,
|
| 1438 |
"status": "failed",
|
| 1439 |
"error": str(e),
|
| 1440 |
+
"distillation_time": total_time,
|
| 1441 |
}
|
| 1442 |
|
| 1443 |
|
|
|
|
| 1458 |
if teacher_models is None:
|
| 1459 |
teacher_models = DEFAULT_TEACHER_MODELS
|
| 1460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1461 |
results = {}
|
| 1462 |
successful_models = []
|
| 1463 |
|
|
|
|
| 1537 |
clear_cache: bool = False,
|
| 1538 |
) -> dict[str, Any]:
|
| 1539 |
"""Shared internal implementation for beam distillation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1540 |
if teacher_models is None:
|
| 1541 |
teacher_models = DEFAULT_TEACHER_MODELS
|
| 1542 |
|
|
|
|
| 1709 |
skip_ptr: Annotated[
|
| 1710 |
bool, typer.Option("--skip-ptr", help="Skip post-training re-regularization (PCA + SIF weighting) step")
|
| 1711 |
] = False,
|
| 1712 |
+
use_optimized_dataset: Annotated[
|
| 1713 |
+
bool,
|
| 1714 |
+
typer.Option(
|
| 1715 |
+
"--use-optimized-dataset", help="Use the pre-created optimized dataset from code_model2vec/dataset"
|
| 1716 |
+
),
|
| 1717 |
+
] = False,
|
| 1718 |
+
dataset_path: Annotated[
|
| 1719 |
+
str | None,
|
| 1720 |
+
typer.Option("--dataset-path", help="Path to custom dataset directory (defaults to code_model2vec/dataset)"),
|
| 1721 |
+
] = None,
|
| 1722 |
) -> None:
|
| 1723 |
"""Unified distillation command with optional training."""
|
| 1724 |
logger.info("🚀 Starting unified Model2Vec distillation workflow")
|
|
|
|
| 1728 |
if skip_ptr and train:
|
| 1729 |
logger.info("⏭️ Post-training re-regularization will be skipped (PCA + SIF weighting disabled)")
|
| 1730 |
|
| 1731 |
+
# Set dataset configuration
|
| 1732 |
+
distillation_config.use_optimized_dataset = use_optimized_dataset
|
| 1733 |
+
distillation_config.custom_dataset_path = dataset_path
|
| 1734 |
+
if use_optimized_dataset and train:
|
| 1735 |
+
dataset_source = dataset_path or "code_model2vec/dataset"
|
| 1736 |
+
logger.info(f"🎯 Using optimized dataset from: {dataset_source}")
|
| 1737 |
+
|
| 1738 |
logger.info(f"🎓 Training mode: {'Tokenlearn (POTION) training' if train else 'Basic distillation only'}")
|
| 1739 |
logger.info(f"☁️ Execution: {'Beam' if use_beam else 'Local'}")
|
| 1740 |
|
|
|
|
| 1973 |
logger.info("✅ Successfully loaded with SentenceTransformer method")
|
| 1974 |
|
| 1975 |
# Now use Model2Vec's distill_from_model function directly
|
| 1976 |
+
from distiller.model2vec.distill.distillation import distill_from_model
|
| 1977 |
|
| 1978 |
distilled_model = distill_from_model(
|
| 1979 |
model=model,
|
|
|
|
| 2083 |
return None
|
| 2084 |
|
| 2085 |
# Now use Model2Vec's distill_from_model function directly
|
| 2086 |
+
from distiller.model2vec.distill.distillation import distill_from_model
|
| 2087 |
|
| 2088 |
distilled_model = distill_from_model(
|
| 2089 |
model=model,
|
|
|
|
| 2169 |
return False
|
| 2170 |
|
| 2171 |
|
| 2172 |
+
def _prepare_tokenlearn_dataset(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
|
| 2173 |
+
"""
|
| 2174 |
+
Prepare dataset for tokenlearn featurization.
|
| 2175 |
+
|
| 2176 |
+
Returns:
|
| 2177 |
+
Tuple of (dataset_path, dataset_name, text_key) for tokenlearn
|
| 2178 |
+
"""
|
| 2179 |
+
if distillation_config.use_optimized_dataset:
|
| 2180 |
+
return _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir)
|
| 2181 |
+
return _prepare_original_dataset_for_tokenlearn()
|
| 2182 |
+
|
| 2183 |
+
|
| 2184 |
+
def _prepare_custom_dataset_for_tokenlearn(tokenlearn_dir: Path) -> tuple[str, str | None, str]:
|
| 2185 |
+
"""Prepare custom optimized dataset for tokenlearn featurization."""
|
| 2186 |
+
logger.info("🎯 Preparing custom optimized dataset for tokenlearn...")
|
| 2187 |
+
|
| 2188 |
+
# Import the dataset module
|
| 2189 |
+
from .dataset import create_optimized_dataset, load_optimized_dataset
|
| 2190 |
+
|
| 2191 |
+
# Define paths
|
| 2192 |
+
custom_dataset_dir = (
|
| 2193 |
+
Path(distillation_config.custom_dataset_path)
|
| 2194 |
+
if distillation_config.custom_dataset_path
|
| 2195 |
+
else Path("code_model2vec/dataset")
|
| 2196 |
+
)
|
| 2197 |
+
tokenlearn_dataset_dir = tokenlearn_dir / "custom_dataset"
|
| 2198 |
+
|
| 2199 |
+
# Check if we need to create the custom dataset
|
| 2200 |
+
if not custom_dataset_dir.exists() or not (custom_dataset_dir / "train.parquet").exists():
|
| 2201 |
+
logger.info("📊 Custom dataset not found - creating optimized dataset...")
|
| 2202 |
+
create_optimized_dataset(
|
| 2203 |
+
max_samples_per_lang=10000, # Reasonable size for tokenlearn
|
| 2204 |
+
output_dir=custom_dataset_dir,
|
| 2205 |
+
create_multiple_formats=False, # Use simple format for tokenlearn
|
| 2206 |
+
)
|
| 2207 |
+
|
| 2208 |
+
# Load the custom dataset
|
| 2209 |
+
logger.info(f"📂 Loading custom dataset from {custom_dataset_dir}")
|
| 2210 |
+
train_df = load_optimized_dataset(output_dir=custom_dataset_dir, split="train")
|
| 2211 |
+
|
| 2212 |
+
# Prepare dataset for tokenlearn (save as JSON files that load_dataset can read)
|
| 2213 |
+
tokenlearn_dataset_dir.mkdir(parents=True, exist_ok=True)
|
| 2214 |
+
|
| 2215 |
+
# Save as JSON file that tokenlearn can load with load_dataset()
|
| 2216 |
+
train_json_path = tokenlearn_dataset_dir / "train.json"
|
| 2217 |
+
|
| 2218 |
+
# Create JSON lines format
|
| 2219 |
+
import json
|
| 2220 |
+
|
| 2221 |
+
with train_json_path.open("w") as f:
|
| 2222 |
+
for text in train_df["text"]:
|
| 2223 |
+
json.dump({"text": text}, f)
|
| 2224 |
+
f.write("\n")
|
| 2225 |
+
|
| 2226 |
+
logger.info(f"✅ Prepared custom dataset with {len(train_df)} samples for tokenlearn")
|
| 2227 |
+
logger.info(f"💾 Saved JSON dataset to {train_json_path}")
|
| 2228 |
+
|
| 2229 |
+
# Return the JSON file path directly (not directory) and no config name for JSON loading
|
| 2230 |
+
return str(train_json_path), None, "text"
|
| 2231 |
+
|
| 2232 |
+
|
| 2233 |
+
def _prepare_original_dataset_for_tokenlearn() -> tuple[str, str, str]:
|
| 2234 |
+
"""Prepare original CodeSearchNet dataset for tokenlearn featurization."""
|
| 2235 |
+
logger.info("📊 Using original CodeSearchNet dataset for tokenlearn...")
|
| 2236 |
+
|
| 2237 |
+
return (
|
| 2238 |
+
str(distillation_config.tokenlearn_dataset), # "sentence-transformers/codesearchnet"
|
| 2239 |
+
str(distillation_config.tokenlearn_dataset_name), # "pair"
|
| 2240 |
+
str(distillation_config.tokenlearn_text_key), # "combined_text"
|
| 2241 |
+
)
|
| 2242 |
+
|
| 2243 |
+
|
| 2244 |
if __name__ == "__main__":
|
| 2245 |
typer.run(main)
|
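Pulling the pieces of the updated `distill.py` together, a rough end-to-end sketch of the new subprocess-free tokenlearn path: featurize the streaming corpus with the teacher model, collect the cached means, then train the static model with the conservative settings. The function names and keyword arguments come from the diff (the vendored `distiller.tokenlearn` modules); the teacher model choice, paths, and surrounding glue are assumptions for illustration only.

```python
from pathlib import Path

import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

from distiller.tokenlearn.featurize import featurize
from distiller.tokenlearn.train import train_model
from distiller.tokenlearn.utils import collect_means_and_texts

features_dir = Path("code_model2vec/tokenlearn/features")   # illustrative paths
trained_dir = Path("code_model2vec/tokenlearn/trained")
features_dir.mkdir(parents=True, exist_ok=True)

# Assumed teacher for the sketch; any teacher from the distillation run would do.
teacher_name = "sentence-transformers/all-MiniLM-L6-v2"
teacher = SentenceTransformer(teacher_name)

dataset = load_dataset(
    "json",
    data_files="code_model2vec/tokenlearn/custom_dataset/train.json",
    split="train",
    streaming=True,
)

# Featurize the corpus with the teacher (arguments as used in the diff).
featurize(
    dataset=iter(dataset),
    model=teacher,
    output_dir=str(features_dir),
    max_means=50000,
    batch_size=512,
    text_key="text",
)

# Collect cached means/texts and train the static model with the "improved" settings.
paths = sorted(features_dir.glob("*.json"))
train_txt, train_vec = collect_means_and_texts(paths)
model = train_model(
    model_name=teacher_name,
    train_txt=train_txt,
    train_vec=train_vec,
    device="cuda" if torch.cuda.is_available() else "cpu",
    vocab_size=25000,
    pca_dims=256,
)
model.save_pretrained(str(trained_dir))
```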
src/distiller/patch_utils.py
DELETED
|
@@ -1,276 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Patch utilities for applying fixes to installed packages.
|
| 3 |
-
|
| 4 |
-
This module provides functionality to automatically apply all patches
|
| 5 |
-
from the patches directory to fix bugs in third-party libraries.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import logging
|
| 9 |
-
import subprocess
|
| 10 |
-
import sys
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
|
| 13 |
-
logger = logging.getLogger(__name__)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def find_patches_directory() -> Path:
|
| 17 |
-
"""Find the patches directory relative to the current script location."""
|
| 18 |
-
# Go up from src/distiller/ to project root, then to patches/
|
| 19 |
-
current_file = Path(__file__)
|
| 20 |
-
project_root = current_file.parent.parent.parent # Go up 3 levels: distiller -> src -> project_root
|
| 21 |
-
patches_dir = project_root / "patches"
|
| 22 |
-
|
| 23 |
-
if not patches_dir.exists():
|
| 24 |
-
# Alternative: try relative to current working directory
|
| 25 |
-
patches_dir = Path("patches")
|
| 26 |
-
|
| 27 |
-
return patches_dir
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def get_site_packages_path() -> Path:
|
| 31 |
-
"""Get the site-packages directory path."""
|
| 32 |
-
import site
|
| 33 |
-
|
| 34 |
-
# Try to get the site-packages from the current environment
|
| 35 |
-
site_packages_dirs = site.getsitepackages()
|
| 36 |
-
|
| 37 |
-
# Prefer the first site-packages directory
|
| 38 |
-
if site_packages_dirs:
|
| 39 |
-
return Path(site_packages_dirs[0])
|
| 40 |
-
|
| 41 |
-
# Fallback: try to find it relative to Python executable
|
| 42 |
-
python_path = Path(sys.executable)
|
| 43 |
-
if python_path.name == "python" or python_path.name.startswith("python"):
|
| 44 |
-
# Standard virtual environment structure
|
| 45 |
-
venv_lib = python_path.parent.parent / "lib"
|
| 46 |
-
for item in venv_lib.iterdir():
|
| 47 |
-
if item.name.startswith("python"):
|
| 48 |
-
site_packages = item / "site-packages"
|
| 49 |
-
if site_packages.exists():
|
| 50 |
-
return site_packages
|
| 51 |
-
|
| 52 |
-
# Last resort: use current directory
|
| 53 |
-
return Path()
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def apply_patch_file(patch_file: Path, target_dir: Path) -> bool:
|
| 57 |
-
"""
|
| 58 |
-
Apply a single patch file to the target directory.
|
| 59 |
-
|
| 60 |
-
Args:
|
| 61 |
-
patch_file: Path to the .patch file
|
| 62 |
-
target_dir: Target directory (usually site-packages)
|
| 63 |
-
|
| 64 |
-
Returns:
|
| 65 |
-
True if patch was applied successfully, False otherwise
|
| 66 |
-
"""
|
| 67 |
-
try:
|
| 68 |
-
logger.info(f"Applying patch: {patch_file.name}")
|
| 69 |
-
|
| 70 |
-
# Check if patch is already applied
|
| 71 |
-
if is_patch_already_applied(patch_file, target_dir):
|
| 72 |
-
logger.info(f"Patch {patch_file.name} already applied")
|
| 73 |
-
return True
|
| 74 |
-
|
| 75 |
-
# Clean any duplicate validation code before applying
|
| 76 |
-
if "model2vec.patch" in patch_file.name:
|
| 77 |
-
clean_duplicate_validation_code(target_dir)
|
| 78 |
-
|
| 79 |
-
# Use patch command with the following options:
|
| 80 |
-
# -p1: strip 1 leading directory from paths
|
| 81 |
-
# -d: change to directory before applying
|
| 82 |
-
# -f: force (don't ask questions)
|
| 83 |
-
# -N: don't reverse patches that appear to be already applied
|
| 84 |
-
result = subprocess.run( # noqa: S603
|
| 85 |
-
["patch", "-p1", "-d", str(target_dir), "-f", "-N"], # noqa: S607
|
| 86 |
-
input=patch_file.read_text(),
|
| 87 |
-
text=True,
|
| 88 |
-
capture_output=True,
|
| 89 |
-
check=False, # Don't raise exception on non-zero exit
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
if result.returncode == 0:
|
| 93 |
-
logger.info(f"Successfully applied patch: {patch_file.name}")
|
| 94 |
-
return True
|
| 95 |
-
if "already applied" in result.stderr.lower() or "reversed" in result.stderr.lower():
|
| 96 |
-
logger.info(f"Patch {patch_file.name} already applied")
|
| 97 |
-
return True
|
| 98 |
-
logger.warning(f"Failed to apply patch {patch_file.name}: {result.stderr}")
|
| 99 |
-
return False
|
| 100 |
-
|
| 101 |
-
except FileNotFoundError:
|
| 102 |
-
logger.exception("'patch' command not found. Please install patch utility.")
|
| 103 |
-
return False
|
| 104 |
-
except Exception:
|
| 105 |
-
logger.exception(f"Error applying patch {patch_file.name}")
|
| 106 |
-
return False
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def apply_all_patches() -> int:
|
| 110 |
-
"""
|
| 111 |
-
Apply all patches from the patches directory.
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
Number of patches successfully applied
|
| 115 |
-
"""
|
| 116 |
-
patches_dir = find_patches_directory()
|
| 117 |
-
|
| 118 |
-
if not patches_dir.exists():
|
| 119 |
-
logger.warning(f"Patches directory not found: {patches_dir}")
|
| 120 |
-
return 0
|
| 121 |
-
|
| 122 |
-
# Find all .patch files
|
| 123 |
-
patch_files = list(patches_dir.glob("*.patch"))
|
| 124 |
-
|
| 125 |
-
if not patch_files:
|
| 126 |
-
logger.info("No patch files found")
|
| 127 |
-
return 0
|
| 128 |
-
|
| 129 |
-
# Get target directory (site-packages)
|
| 130 |
-
target_dir = get_site_packages_path()
|
| 131 |
-
logger.info(f"Applying patches to: {target_dir}")
|
| 132 |
-
|
| 133 |
-
# Clean any existing duplicates first
|
| 134 |
-
clean_duplicate_validation_code(target_dir)
|
| 135 |
-
|
| 136 |
-
success_count = 0
|
| 137 |
-
|
| 138 |
-
# Sort patch files for consistent ordering
|
| 139 |
-
for patch_file in sorted(patch_files):
|
| 140 |
-
if apply_patch_file(patch_file, target_dir):
|
| 141 |
-
success_count += 1
|
| 142 |
-
|
| 143 |
-
logger.info(f"Applied {success_count}/{len(patch_files)} patches successfully")
|
| 144 |
-
return success_count
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
def is_patch_already_applied(patch_file: Path, target_dir: Path) -> bool:
|
| 148 |
-
"""
|
| 149 |
-
Check if a patch has already been applied by looking for specific markers.
|
| 150 |
-
|
| 151 |
-
Args:
|
| 152 |
-
patch_file: Path to the .patch file
|
| 153 |
-
target_dir: Target directory (usually site-packages)
|
| 154 |
-
|
| 155 |
-
Returns:
|
| 156 |
-
True if patch appears to be already applied, False otherwise
|
| 157 |
-
"""
|
| 158 |
-
try:
|
| 159 |
-
# For model2vec.patch, check if the validation code is already present
|
| 160 |
-
if "model2vec.patch" in patch_file.name:
|
| 161 |
-
inference_file = target_dir / "model2vec" / "distill" / "inference.py"
|
| 162 |
-
if inference_file.exists():
|
| 163 |
-
inference_content = inference_file.read_text()
|
| 164 |
-
# Check for the specific validation code we're adding
|
| 165 |
-
if (
|
| 166 |
-
"Token-vector mismatch:" in inference_content
|
| 167 |
-            and "Truncating to prevent failure" in inference_content
-        ):
-            # Also make sure it's in the right place (before return statement, not after)
-            lines = inference_content.split("\n")
-            for i, line in enumerate(lines):
-                if "return out_tokens, out_weights" in line:
-                    # Check if validation code appears before this return
-                    preceding_lines = lines[max(0, i - 10) : i]
-                    if any("Token-vector mismatch:" in pline for pline in preceding_lines):
-                        return True
-                    break
-
-        # For tokenlearn.patch, check if the indexing fix is already present
-        if "tokenlearn.patch" in patch_file.name:
-            pretrain_file = target_dir / "tokenlearn" / "pretrain.py"
-            if pretrain_file.exists():
-                pretrain_content = pretrain_file.read_text()
-                # Check for the specific fix we're adding
-                if (
-                    "Fix for index out of bounds issue" in pretrain_content
-                    and "torch.clamp(input_ids, 0, self.w.shape[0] - 1)" in pretrain_content
-                ):
-                    return True
-
-        return False
-
-    except Exception as e:
-        logger.warning(f"Error checking if patch {patch_file.name} is applied: {e}")
-        return False
-
-
-def clean_duplicate_validation_code(target_dir: Path) -> bool:
-    """
-    Clean up duplicate validation code that might have been added by multiple patch applications.
-
-    Args:
-        target_dir: Target directory (usually site-packages)
-
-    Returns:
-        True if cleanup was successful, False otherwise
-    """
-    try:
-        inference_file = target_dir / "model2vec" / "distill" / "inference.py"
-        if not inference_file.exists():
-            return True
-
-        content = inference_file.read_text()
-        lines = content.split("\n")
-
-        # Find all instances of the validation code
-        validation_indices = []
-        for i, line in enumerate(lines):
-            if "Token-vector mismatch:" in line:
-                validation_indices.append(i)
-
-        if len(validation_indices) <= 1:
-            return True  # No duplicates or no validation code
-
-        # Keep only the validation code that appears before a return statement
-        lines_to_keep = []
-        skip_until = -1
-
-        for i, line in enumerate(lines):
-            if i <= skip_until:
-                continue
-
-            # If this is validation code
-            if "Token-vector mismatch:" in line:
-                # Look ahead to see if there's a return statement nearby
-                has_return_after = False
-                for j in range(i, min(len(lines), i + 20)):
-                    if "return out_tokens, out_weights" in lines[j]:
-                        has_return_after = True
-                        break
-
-                # Keep this validation block only if it's followed by a return
-                if has_return_after:
-                    lines_to_keep.append(line)
-                else:
-                    # Skip this validation block (it's a duplicate)
-                    # Find the end of this validation block
-                    for j in range(i + 1, len(lines)):
-                        if lines[j].strip() == "" or not lines[j].startswith(" "):
-                            skip_until = j - 1
-                            break
-            else:
-                lines_to_keep.append(line)
-
-        # Write back the cleaned content
-        cleaned_content = "\n".join(lines_to_keep)
-        inference_file.write_text(cleaned_content)
-        logger.info("Cleaned duplicate validation code from inference.py")
-        return True
-
-    except Exception as e:
-        logger.warning(f"Error cleaning duplicate validation code: {e}")
-        return False
-
-
-def main() -> None:
-    """Main function for standalone execution."""
-    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-
-    print("Applying all patches...")
-    success_count = apply_all_patches()
-    print(f"Done. Applied {success_count} patches.")
-
-
-if __name__ == "__main__":
-    main()
|
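The removed `patch_utils.py` logic above decides whether a patch is already in place by searching the installed sources for sentinel strings (for example the "Fix for index out of bounds issue" comment and the `torch.clamp(input_ids, 0, self.w.shape[0] - 1)` line) and only re-patching when they are absent. A minimal sketch of that idempotency pattern, with hypothetical marker-to-path mappings and only the standard library:

```python
from pathlib import Path

# Hypothetical marker -> file mapping, mirroring the sentinel checks in the
# deleted helper: a patch counts as "already applied" when its marker string
# is present in the target source file, so repeated runs stay no-ops.
PATCH_MARKERS = {
    "Fix for index out of bounds issue": Path("tokenlearn") / "pretrain.py",
    "Token-vector mismatch:": Path("model2vec") / "distill" / "inference.py",
}


def patch_already_applied(site_packages: Path, marker: str) -> bool:
    """Return True when the marker is already present in the patched file."""
    target = site_packages / PATCH_MARKERS[marker]
    return target.exists() and marker in target.read_text(encoding="utf-8")
```

Checking for a marker before patching is what keeps repeated runs from stacking duplicate blocks, which is exactly the failure mode `clean_duplicate_validation_code` existed to repair.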
uv.lock
CHANGED
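With the `model2vec` and `tokenlearn` entries dropped from the lock below, the dependencies previously pulled in through those packages (`jinja2`, `joblib`, `more-itertools`, `rich`, `safetensors`, `skops`, `tokenizers`, `tqdm`, `transformers`) are now declared directly by the project. An illustrative way to list what is declared, assuming the dependency list lives under `[project]` in `pyproject.toml` and Python 3.11+ for the standard-library `tomllib`:

```python
import tomllib  # standard-library TOML parser (Python 3.11+)
from pathlib import Path

# List the direct dependencies declared in pyproject.toml; uv resolves these
# into the `dependencies` and `requires-dist` entries shown in the lock below.
pyproject = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8"))
for requirement in sorted(pyproject["project"]["dependencies"]):
    print(requirement)
```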
@@ -774,24 +774,31 @@ dependencies = [
     { name = "flash-attn" },
     { name = "hatchling" },
     { name = "iso639" },
+    { name = "jinja2" },
+    { name = "joblib" },
     { name = "kaleido" },
     { name = "lightning" },
     { name = "matplotlib" },
-    { name = "
+    { name = "more-itertools" },
     { name = "mteb" },
     { name = "numpy" },
     { name = "plotly" },
     { name = "psutil" },
     { name = "pydantic" },
     { name = "requests" },
+    { name = "rich" },
+    { name = "safetensors" },
     { name = "scikit-learn" },
     { name = "seaborn" },
     { name = "sentence-transformers" },
     { name = "setuptools" },
+    { name = "skops" },
     { name = "smart-open", extra = ["s3"] },
     { name = "statsmodels" },
-    { name = "
+    { name = "tokenizers" },
     { name = "torch" },
+    { name = "tqdm" },
+    { name = "transformers" },
     { name = "typer" },
 ]

@@ -813,24 +820,31 @@ requires-dist = [
     { name = "flash-attn", specifier = ">=2.7.4.post1" },
     { name = "hatchling", specifier = ">=1.27.0" },
     { name = "iso639", specifier = ">=0.1.4" },
+    { name = "jinja2", specifier = ">=3.0.0" },
+    { name = "joblib", specifier = ">=1.0.0" },
     { name = "kaleido", specifier = "==1.0.0rc13" },
     { name = "lightning", specifier = ">=2.5.1.post0" },
     { name = "matplotlib", specifier = ">=3.10.3" },
-    { name = "
+    { name = "more-itertools", specifier = ">=10.5.0" },
     { name = "mteb", specifier = ">=1.14.15" },
     { name = "numpy", specifier = ">=1.26.4" },
     { name = "plotly", specifier = ">=6.1.1" },
     { name = "psutil", specifier = ">=7.0.0" },
     { name = "pydantic", specifier = ">=2.11.5" },
     { name = "requests", specifier = ">=2.32.3" },
+    { name = "rich", specifier = ">=10.0.0" },
+    { name = "safetensors", specifier = ">=0.3.0" },
     { name = "scikit-learn", specifier = ">=1.6.1" },
     { name = "seaborn", specifier = ">=0.13.2" },
     { name = "sentence-transformers", specifier = ">=4.1.0" },
     { name = "setuptools", specifier = ">=80.8.0" },
+    { name = "skops", specifier = ">=0.11.0" },
     { name = "smart-open", extras = ["s3"], specifier = ">=7.1.0" },
     { name = "statsmodels", specifier = ">=0.14.4" },
+    { name = "tokenizers", specifier = ">=0.20" },
     { name = "torch", specifier = ">=2.7.0" },
+    { name = "tqdm", specifier = ">=4.65.0" },
+    { name = "transformers", specifier = "<=4.52.1" },
     { name = "typer", specifier = ">=0.16.0" },
 ]

@@ -1187,38 +1201,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 },
 ]

-[[package]]
-name = "model2vec"
-version = "0.5.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "jinja2" },
-    { name = "joblib" },
-    { name = "numpy" },
-    { name = "rich" },
-    { name = "safetensors" },
-    { name = "setuptools" },
-    { name = "tokenizers" },
-    { name = "tqdm" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/93/18/c546916657e47e52b6e25b231803903bcf4e7ef2497fe41e9869236d7dee/model2vec-0.5.0.tar.gz", hash = "sha256:0771fd99d5c58fac631a2faa233759a8cec7a3be6e9aeeeeeca2d5e7048d1c7b", size = 2665840 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/66/ab/5263bc4605e9960fece76b710c01fef33859dc6ae72832d5987db75eed63/model2vec-0.5.0-py3-none-any.whl", hash = "sha256:12f14a18556975c037961a836a702388876bfec1ff76176f056884d219735271", size = 44578 },
-]
-
-[package.optional-dependencies]
-distill = [
-    { name = "scikit-learn" },
-    { name = "torch" },
-    { name = "transformers" },
-]
-train = [
-    { name = "lightning" },
-    { name = "scikit-learn" },
-    { name = "skops" },
-    { name = "torch" },
-]
-
 [[package]]
 name = "more-itertools"
 version = "10.7.0"

@@ -2492,22 +2474,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481 },
 ]

-[[package]]
-name = "tokenlearn"
-version = "0.2.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "datasets" },
-    { name = "model2vec", extra = ["distill"] },
-    { name = "more-itertools" },
-    { name = "sentence-transformers" },
-    { name = "torch" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/58/b6/f9587ea271a9a7464cd25025b65f471d49bbceb48cc90742a89ac085edfd/tokenlearn-0.2.0.tar.gz", hash = "sha256:7a8faa0f51a510d185a40bef197a88116464adb8ce85ffd12c1d6905369c2375", size = 149042 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/40/3d/1c2b2e80ffd929bb8e7930d6a48e3b4252676cdc6c0c38f13a6f0f374b9c/tokenlearn-0.2.0-py3-none-any.whl", hash = "sha256:7a05e2800420eb2914c30e7377adeb14822c63585a0b9ed018bc82735dae1f29", size = 11970 },
-]
-
 [[package]]
 name = "torch"
 version = "2.7.0"

@@ -2580,7 +2546,7 @@ wheels = [
 
 [[package]]
 name = "transformers"
-version = "4.52.
+version = "4.52.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "filelock" },

@@ -2594,9 +2560,9 @@ dependencies = [
     { name = "tokenizers" },
     { name = "tqdm" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/
+sdist = { url = "https://files.pythonhosted.org/packages/4a/de/f3f3a0649dc522aeff55a5739e06e132c875c53701307a2ddd7ce7528ec5/transformers-4.52.1.tar.gz", hash = "sha256:c380d583ed9c7ebe3e30ca5e55ec1249db39eb9ee277f8e74dab1abc6a03c938", size = 8944009 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/
+    { url = "https://files.pythonhosted.org/packages/b8/1e/2b00e5021c3545d4a0ae32f3d332ae29e62a6259092f1468976e7b9d4adb/transformers-4.52.1-py3-none-any.whl", hash = "sha256:604b2bb357c480dc5883b7944e8562c967f6b06f63dfb6a1c4665d13d067148f", size = 10459023 },
 ]
 
 [[package]]
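The lock now resolves `transformers` to 4.52.1, matching the new `<=4.52.1` specifier in `requires-dist`. A small sanity check that an installed environment honours that pin, assuming the `packaging` distribution is importable (it usually is, as a transitive dependency of `transformers`):

```python
from importlib.metadata import version

from packaging.version import Version

# Verify the resolved transformers build satisfies the <=4.52.1 constraint
# recorded in uv.lock; raises AssertionError if a newer build slipped in.
installed = version("transformers")
assert Version(installed) <= Version("4.52.1"), f"unexpected transformers {installed}"
print(f"transformers {installed} satisfies <=4.52.1")
```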