import logging

import numpy as np
import torch
import torch.multiprocessing
import torchvision.transforms as T

from utils import transform_to_pil

# Preprocessing pipeline: resize to the model's 320x320 input size and
# normalize with the standard ImageNet mean/std.
preprocess = T.Compose(
    [
        T.ToPILImage(),
        T.Resize((320, 320)),
        # T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
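# Illustrative note (not from the original file): applied to a single
# H x W x 3 uint8 array (or anything T.ToPILImage accepts), `preprocess`
# returns a normalized float tensor of shape (3, 320, 320), e.g.
#   preprocess(np.zeros((100, 100, 3), dtype=np.uint8)).shape == (3, 320, 320)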
def inference(images, model):
    """Run the segmentation model on a list of images and return per-image predictions."""
    logging.info("Inference on Images")
    x = torch.stack([preprocess(image) for image in images]).cpu()

    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)

    outputs = [
        {
            "img": x[i].detach().cpu(),
            "linear_preds": linear_pred[i].detach().cpu(),
        }
        for i in range(x.shape[0])
    ]

    # Remap class 5 (water) to class 3 (natural green).
    for output in outputs:
        output["linear_preds"] = torch.where(
            output["linear_preds"] == 5, 3, output["linear_preds"]
        )
    return outputs
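
# Illustrative sketch (not part of the original app): one way to turn a
# predicted (H, W) class map into a color PIL image for quick inspection.
# The random palette below is an arbitrary stand-in, not the app's colormap.
def prediction_to_pil(linear_preds, n_classes):
    from PIL import Image

    palette = np.random.RandomState(0).randint(0, 255, size=(n_classes, 3)).astype(np.uint8)
    class_map = linear_preds.numpy().astype(np.int64)
    return Image.fromarray(palette[class_map])
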
| if __name__ == "__main__": | |
| import hydra | |
| from model import LitUnsupervisedSegmenter | |
| from utils_gee import extract_img, transform_ee_img | |
| import os | |
| latitude = 2.98 | |
| longitude = 48.81 | |
| start_date = '2020-03-20' | |
| end_date = '2020-04-20' | |
| location = [float(latitude), float(longitude)] | |
| # Extract img numpy from earth engine and transform it to PIL img | |
| img = extract_img(location, start_date, end_date) | |
| image = transform_ee_img( | |
| img, max=0.3 | |
| ) # max value is the value from numpy file that will be equal to 255 | |
| print("image loaded") | |
    # Initialize hydra with the configs directory and compose the train config.
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model weights on CPU.
    model_path = os.path.join(os.path.dirname(__file__), "checkpoint/model/model.pt")
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))

    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")

    model.load_state_dict(saved_state_dict)
    # Switch to evaluation mode before running inference.
    model.eval()
    print("model loaded")
| # img.save("output/image.png") | |
| inference([image], model) | |
| inference([image,image], model) | |
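    # Illustrative follow-up (not in the original script): save each predicted
    # class map as a PNG using the prediction_to_pil sketch defined above.
    # The "output" directory name is an assumption.
    os.makedirs("output", exist_ok=True)
    for i, output in enumerate(outputs):
        prediction_to_pil(output["linear_preds"], nbclasses).save(f"output/pred_{i}.png")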