Ahmad Hathim bin Ahmad Azman commited on
Commit
8438377
Β·
1 Parent(s): f3ce8a7

fixed pytorch

Browse files
Files changed (1) hide show
  1. model_inference.py +29 -7
model_inference.py CHANGED
@@ -19,26 +19,48 @@ def ensure_model_file(filename: str):
19
  return path
20
 
21
 
22
- def load_model(path):
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  checkpoint = torch.load(path, map_location="cpu")
24
-
25
- # Recreate the same model architecture
26
  model = EnsembleBertBiLSTMRegressor(
27
  model_name_mcq="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
28
  model_name_clinical="emilyalsentzer/Bio_ClinicalBERT",
29
  hidden_dim=768,
30
- extra_dim=67 # e.g. 10 if you have 10 engineered + categorical features
31
  )
32
 
33
- # Load saved weights
34
- model.load_state_dict(checkpoint["model_state"])
35
- model.eval()
 
 
 
 
36
 
 
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  model.to(device)
 
 
 
39
  return model, device
40
 
41
 
 
42
  def predict_from_input(data, model, device, tok_mcq, tok_clin, encoder, scaler):
43
  """
44
  Predict difficulty and discrimination index for a single MCQ item.
 
19
  return path
20
 
21
 
22
+ import os
23
+ import torch
24
+ from model_architecture import EnsembleBertBiLSTMRegressor
25
+
26
+ def load_model(path: str = "assets/best_checkpoint_regression.pt"):
27
+ """
28
+ Load the trained EnsembleBertBiLSTMRegressor model using saved checkpoint weights.
29
+ Supports CPU/GPU execution.
30
+ """
31
+
32
+ if not os.path.exists(path):
33
+ raise FileNotFoundError(f"❌ Model checkpoint not found at: {path}")
34
+
35
+ print(f"βœ… Loading model weights from: {path}")
36
  checkpoint = torch.load(path, map_location="cpu")
37
+
38
+ # βœ… Recreate model architecture (must match training exactly!)
39
  model = EnsembleBertBiLSTMRegressor(
40
  model_name_mcq="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
41
  model_name_clinical="emilyalsentzer/Bio_ClinicalBERT",
42
  hidden_dim=768,
43
+ extra_dim=67 # Adjust if your engineered features size differs
44
  )
45
 
46
+ # βœ… Load weights into model
47
+ if "model_state" in checkpoint:
48
+ model.load_state_dict(checkpoint["model_state"])
49
+ elif "state_dict" in checkpoint: # support alternative saving formats
50
+ model.load_state_dict(checkpoint["state_dict"])
51
+ else:
52
+ raise KeyError("❌ No 'model_state' or 'state_dict' found in checkpoint")
53
 
54
+ # βœ… Set eval mode and move to device
55
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
  model.to(device)
57
+ model.eval()
58
+
59
+ print(f"βœ… Model loaded successfully on device: {device}")
60
  return model, device
61
 
62
 
63
+
64
  def predict_from_input(data, model, device, tok_mcq, tok_clin, encoder, scaler):
65
  """
66
  Predict difficulty and discrimination index for a single MCQ item.