Habiba A. Elbehairy
Refactor Code Similarity Classifier and update Dockerfile, README, and requirements
a5cd505
import os
import logging
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import uvicorn
from datetime import datetime
from transformers import AutoTokenizer, AutoModel
import requests
import re
import tempfile

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# System information
DEPLOYMENT_DATE = "2025-06-22 22:15:13"
DEPLOYED_BY = "FASTESTAI"

# Get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# HuggingFace model repository path (used only for the weights file)
REPO_ID = "FastestAI/Redundant_Model"
MODEL_WEIGHTS_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin"

# Initialize FastAPI app
app = FastAPI(
    title="Test Similarity Analyzer API",
    description="API for analyzing similarity between test cases. Deployed by " + DEPLOYED_BY,
    version="1.0.0",
    docs_url="/",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define label to class mapping
label_to_class = {0: "Duplicate", 1: "Redundant", 2: "Distinct"}

# Define input models for API
class SourceCode(BaseModel):
    class_name: str
    code: str


class TestCase(BaseModel):
    id: str
    test_fixture: str
    name: str
    code: str
    target_class: str
    target_method: List[str]


class SimilarityInput(BaseModel):
    pair_id: str
    source_code: SourceCode
    test_case_1: TestCase
    test_case_2: TestCase

# Define the model class
class CodeSimilarityClassifier(torch.nn.Module):
    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        # Create a more powerful classification head
        hidden_size = self.encoder.config.hidden_size
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.LayerNorm(hidden_size),
            torch.nn.GELU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(512, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits
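
# Shape notes for CodeSimilarityClassifier.forward above (descriptive only):
# input_ids and attention_mask are (batch_size, seq_len), pooler_output is
# (batch_size, hidden_size), and the returned logits are (batch_size, num_labels).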

def extract_features(source_code, test_code_1, test_code_2):
    """Extract specific features to help the model identify similarities"""
    # Extract test fixtures
    fixture1 = re.search(r'TEST(?:_F)?\s*\(\s*(\w+)', test_code_1)
    fixture1 = fixture1.group(1) if fixture1 else ""
    fixture2 = re.search(r'TEST(?:_F)?\s*\(\s*(\w+)', test_code_2)
    fixture2 = fixture2.group(1) if fixture2 else ""

    # Extract test names
    name1 = re.search(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', test_code_1)
    name1 = name1.group(1) if name1 else ""
    name2 = re.search(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', test_code_2)
    name2 = name2.group(1) if name2 else ""

    # Extract assertions
    assertions1 = re.findall(r'(EXPECT_|ASSERT_)(\w+)', test_code_1)
    assertions2 = re.findall(r'(EXPECT_|ASSERT_)(\w+)', test_code_2)

    # Extract function/method calls
    calls1 = re.findall(r'(\w+)\s*\(', test_code_1)
    calls2 = re.findall(r'(\w+)\s*\(', test_code_2)

    # Create explicit feature section
    same_fixture = "SAME_FIXTURE" if fixture1 == fixture2 else "DIFFERENT_FIXTURE"
    common_assertions = set([a[0] + a[1] for a in assertions1]).intersection(set([a[0] + a[1] for a in assertions2]))
    common_calls = set(calls1).intersection(set(calls2))

    # Calculate assertion ratio with safety check for zero
    assertion_ratio = 0
    if assertions1 and assertions2:
        total_assertions = len(assertions1) + len(assertions2)
        if total_assertions > 0:
            assertion_ratio = len(common_assertions) / total_assertions

    features = (
        f"METADATA: {same_fixture} | "
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | "
        f"NAME1: {name1} | NAME2: {name2} | "
        f"COMMON_ASSERTIONS: {len(common_assertions)} | "
        f"COMMON_CALLS: {len(common_calls)} | "
        f"ASSERTION_RATIO: {assertion_ratio}"
    )
    return features
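
# For illustration: on the sample pair returned by get_example() below,
# extract_features() produces roughly the following feature string:
#   METADATA: SAME_FIXTURE | FIXTURE1: CalculatorTest | FIXTURE2: CalculatorTest |
#   NAME1: AddsTwoPositiveNumbers | NAME2: AddsTwoPositiveIntegers |
#   COMMON_ASSERTIONS: 1 | COMMON_CALLS: 3 | ASSERTION_RATIO: 0.5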

# Global variables for model and tokenizer
tokenizer = None
model = None


def download_model_weights(url, save_path):
    """Download model weights from URL to a local file"""
    try:
        logger.info(f"Downloading model weights from {url}...")
        response = requests.get(url, stream=True)
        if response.status_code != 200:
            logger.error(f"Failed to download: HTTP {response.status_code}")
            return False
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
        logger.info(f"Successfully downloaded model weights to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Error downloading model weights: {e}")
        return False

# Load model and tokenizer on startup
@app.on_event("startup")
async def startup_event():
    global tokenizer, model
    try:
        logger.info("=== Starting model loading process ===")

        # Step 1: Load the tokenizer from the base model
        logger.info("Loading tokenizer from microsoft/codebert-base...")
        try:
            tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
            logger.info("✅ Base tokenizer loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load tokenizer: {str(e)}")
            raise

        # Step 2: Create model with base architecture
        logger.info("Creating model architecture...")
        try:
            # Initialize with base CodeBERT
            model = CodeSimilarityClassifier(model_name="microsoft/codebert-base")
            logger.info("✅ Model architecture created successfully")
        except Exception as e:
            logger.error(f"❌ Failed to create model architecture: {str(e)}")
            raise

        # Step 3: Download and load weights
        model_path = "pytorch_model.bin"
        # First check if the file already exists; only download if it is missing
        if not os.path.exists(model_path):
            if not download_model_weights(MODEL_WEIGHTS_URL, model_path):
                logger.error("❌ Failed to download model weights")
                raise RuntimeError("Failed to download model weights")

        # Try to load the model weights
        try:
            # Check whether the file holds a bare state dict or a wrapped checkpoint
            logger.info(f"Loading weights from {model_path}...")
            checkpoint = torch.load(model_path, map_location=device)
            if isinstance(checkpoint, dict):
                if "state_dict" in checkpoint:
                    logger.info("Loading from checkpoint['state_dict']")
                    model.load_state_dict(checkpoint["state_dict"])
                elif "model_state_dict" in checkpoint:
                    logger.info("Loading from checkpoint['model_state_dict']")
                    model.load_state_dict(checkpoint["model_state_dict"])
                else:
                    # Dict is the state dict itself
                    logger.info("Loading from checkpoint directly")
                    model.load_state_dict(checkpoint)
            else:
                logger.error("❌ Unsupported model format")
                raise RuntimeError("Unsupported model format")
            logger.info("✅ Model weights loaded successfully")
        except Exception as e:
            logger.error(f"❌ Error loading model weights: {str(e)}")
            raise

        # Move model to device and set to evaluation mode
        model.to(device)
        model.eval()
        logger.info(f"✅ Model moved to {device} and set to evaluation mode")
        logger.info("=== Model loading process complete ===")
    except Exception as e:
        logger.error(f"❌ CRITICAL ERROR in startup: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        model = None
        tokenizer = None

@app.get("/health")  # route path assumed
async def health_check():
    """Health check endpoint that also returns deployment information"""
    model_status = model is not None
    tokenizer_status = tokenizer is not None
    status = "ok" if (model_status and tokenizer_status) else "error"
    return {
        "status": status,
        "model_loaded": model_status,
        "tokenizer_loaded": tokenizer_status,
        "model": REPO_ID,
        "device": str(device),
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY,
        "current_time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
    }

@app.post("/predict")  # route path assumed
async def predict(data: SimilarityInput):
    """
    Predict similarity class between two test cases for a given source class.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded correctly")
    try:
        # Apply heuristics for method and class differences
        class_1 = data.test_case_1.target_class
        class_2 = data.test_case_2.target_class
        method_1 = data.test_case_1.target_method
        method_2 = data.test_case_2.target_method

        # Check if we can determine similarity without using the model
        if class_1 and class_2 and class_1 != class_2:
            logger.info("Heuristic detection: Different target classes - Distinct")
            model_prediction = 2  # Distinct
            probs = [0.0, 0.0, 1.0]  # 100% confidence in Distinct
        elif method_1 and method_2 and not set(method_1).intersection(set(method_2)):
            logger.info("Heuristic detection: Different target methods - Distinct")
            model_prediction = 2  # Distinct
            probs = [0.0, 0.0, 1.0]  # 100% confidence in Distinct
        else:
            # No clear heuristic match, use the model.
            # Extract features to help with classification
            features = extract_features(data.source_code.code, data.test_case_1.code, data.test_case_2.code)

            # Format the input text with clear section markers as done during training
            formatted_text = (
                f"{features}\n\n"
                f"SOURCE CODE:\n{data.source_code.code.strip()}\n\n"
                f"TEST CASE 1:\n{data.test_case_1.code.strip()}\n\n"
                f"TEST CASE 2:\n{data.test_case_2.code.strip()}"
            )

            # Tokenize input
            inputs = tokenizer(
                formatted_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512
            ).to(device)

            # Model inference
            with torch.no_grad():
                logits = model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"]
                )

            # Process results
            probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
            model_prediction = torch.argmax(logits, dim=-1).item()
            logger.info(f"Model prediction: {label_to_class[model_prediction]}")

        # Map prediction to class name
        classification = label_to_class.get(model_prediction, "Unknown")
        # For API compatibility, map the model outputs (0, 1, 2) to API scores (1, 2, 3)
        api_score = model_prediction + 1

        return {
            "pair_id": data.pair_id,
            "test_case_1_name": data.test_case_1.name,
            "test_case_2_name": data.test_case_2.name,
            "similarity": {
                "score": api_score,
                "classification": classification,
            },
            "probabilities": probs
        }
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        logger.error(f"Prediction error: {str(e)}")
        logger.error(error_trace)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")

# Root and example endpoints
async def root():
    return {
        "message": "Test Similarity Analyzer API",
        "documentation": "/docs",
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY
    }


async def get_example():
    """Get an example input to test the API"""
    return SimilarityInput(
        pair_id="example-1",
        source_code=SourceCode(
            class_name="Calculator",
            code="class Calculator {\n public int add(int a, int b) {\n return a + b;\n }\n}"
        ),
        test_case_1=TestCase(
            id="test-1",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveNumbers",
            code="TEST(CalculatorTest, AddsTwoPositiveNumbers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        ),
        test_case_2=TestCase(
            id="test-2",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveIntegers",
            code="TEST(CalculatorTest, AddsTwoPositiveIntegers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        )
    )

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
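
A minimal client sketch for exercising the service once it is running. It assumes the predict handler above is reachable at POST /predict on port 7860 (the route path is an assumption, matching the assumed decorator; adjust the URL if the deployment differs), and mirrors the SimilarityInput schema defined in app.py.

import requests

# Hypothetical request payload following SourceCode / TestCase / SimilarityInput above.
payload = {
    "pair_id": "example-1",
    "source_code": {
        "class_name": "Calculator",
        "code": "class Calculator {\n public int add(int a, int b) {\n return a + b;\n }\n}",
    },
    "test_case_1": {
        "id": "test-1",
        "test_fixture": "CalculatorTest",
        "name": "testAddsTwoPositiveNumbers",
        "code": "TEST(CalculatorTest, AddsTwoPositiveNumbers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
        "target_class": "Calculator",
        "target_method": ["add"],
    },
    "test_case_2": {
        "id": "test-2",
        "test_fixture": "CalculatorTest",
        "name": "testAddsTwoPositiveIntegers",
        "code": "TEST(CalculatorTest, AddsTwoPositiveIntegers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
        "target_class": "Calculator",
        "target_method": ["add"],
    },
}

# Post the pair and print the predicted classification and class probabilities.
response = requests.post("http://localhost:7860/predict", json=payload, timeout=120)
response.raise_for_status()
result = response.json()
print(result["similarity"]["classification"], result["probabilities"])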