Habiba A. Elbehairy
Refactor Code Similarity Classifier and update Dockerfile, README, and requirements
a5cd505
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel | |
| import re | |
class CodeSimilarityClassifier(nn.Module):
    """Pretrained transformer encoder followed by an MLP head producing class logits.

    Args:
        model_name: Model identifier passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output classes (width of the final linear layer).
    """

    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Regularises the pooled sequence representation before the head.
        self.dropout = nn.Dropout(0.1)
        # Head: hidden -> hidden -> 512 -> num_labels; each hidden stage is
        # LayerNorm-ed, GELU-activated and followed by dropout.
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenised inputs.

        Args:
            input_ids: Token ids; presumably shaped (batch, seq_len) -- TODO confirm
                against the caller/tokenizer.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Logits tensor whose last dimension is ``num_labels``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Not every backbone loadable via AutoModel has a pooling layer; when
        # there is no pooler_output, use the first token's final hidden state.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]
        # The dropout declared in __init__ was previously never applied.
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
def extract_features(source_code, test_code_1, test_code_2):
    """Summarise how two test snippets relate, as one feature string.

    The patterns target GoogleTest-style macros (``TEST``/``TEST_F`` headers,
    ``EXPECT_*``/``ASSERT_*`` assertions).

    Args:
        source_code: Accepted for interface compatibility; not read here.
        test_code_1: Text of the first test.
        test_code_2: Text of the second test.

    Returns:
        A string of ``" | "``-separated fields: METADATA (SAME_FIXTURE or
        DIFFERENT_FIXTURE), FIXTURE1, FIXTURE2, NAME1, NAME2,
        COMMON_ASSERTIONS, COMMON_CALLS and ASSERTION_RATIO.
    """

    def scan(code):
        # Returns (fixture, test name, assertion hits, call names) for one snippet;
        # fixture/name are "" when the test header is not found.
        fixture_hit = re.search(r'TEST(?:_F)?\s*\(\s*(\w+)', code)
        name_hit = re.search(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', code)
        return (
            fixture_hit.group(1) if fixture_hit else "",
            name_hit.group(1) if name_hit else "",
            # Two capture groups -> each hit is a (prefix, suffix) tuple.
            re.findall(r'(EXPECT_|ASSERT_)(\w+)', code),
            # NOTE(review): any identifier followed by "(" counts, so the
            # TEST/TEST_F and assertion macros and keywords like `if` are included.
            re.findall(r'(\w+)\s*\(', code),
        )

    fixture_a, name_a, asserts_a, calls_a = scan(test_code_1)
    fixture_b, name_b, asserts_b, calls_b = scan(test_code_2)

    shared_asserts = {head + tail for head, tail in asserts_a} & {head + tail for head, tail in asserts_b}
    shared_calls = set(calls_a) & set(calls_b)

    if asserts_a and asserts_b:
        # NOTE(review): numerator counts distinct shared assertion names, the
        # denominator counts every assertion occurrence in both snippets.
        assertion_ratio = len(shared_asserts) / (len(asserts_a) + len(asserts_b))
    else:
        assertion_ratio = 0

    if fixture_a == fixture_b:
        fixture_flag = "SAME_FIXTURE"
    else:
        fixture_flag = "DIFFERENT_FIXTURE"

    return (
        f"METADATA: {fixture_flag} | "
        f"FIXTURE1: {fixture_a} | FIXTURE2: {fixture_b} | "
        f"NAME1: {name_a} | NAME2: {name_b} | "
        f"COMMON_ASSERTIONS: {len(shared_asserts)} | "
        f"COMMON_CALLS: {len(shared_calls)} | "
        f"ASSERTION_RATIO: {assertion_ratio}"
    )