Update inference.py
Browse files- inference.py +6 -5
inference.py
CHANGED
|
@@ -30,8 +30,9 @@ def save_output(mask, save_path):
|
|
| 30 |
mask_image = Image.fromarray(mask[0])
|
| 31 |
mask_image.save(save_path)
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
| 30 |
mask_image = Image.fromarray(mask[0])
|
| 31 |
mask_image.save(save_path)
|
| 32 |
|
| 33 |
+
if __name__ == "__main__":
|
| 34 |
+
weights_path = "unet_model.pth"
|
| 35 |
+
model = load_model(weights_path, device)
|
| 36 |
+
image_tensor = preprocess_image("DUTS-TE-Image/ILSVRC2012_test_00000003.jpg")
|
| 37 |
+
mask = predict(model, image_tensor, device)
|
| 38 |
+
save_output(mask, "predicted_mask.jpg")
|