Spaces:
Runtime error
Runtime error
| import torch.multiprocessing | |
| import torchvision.transforms as T | |
| import numpy as np | |
| from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image | |
| from utils_gee import get_image | |
| from dateutil.relativedelta import relativedelta | |
| from model import LitUnsupervisedSegmenter | |
| import datetime | |
| import matplotlib as mpl | |
| from joblib import Parallel, cpu_count, delayed | |
| import logging | |
| from inference import inference | |
| import streamlit as st | |
| import cv2 | |
def inference_on_location(model, longitude=2.98, latitude=48.81, start_date=2020, end_date=2022, how="year"):
    """Segment one location over a range of years and plot the results.

    The period from January 1st of ``start_date`` to January 1st of ``end_date``
    is cut into consecutive windows of length ``how``. One image per window is
    fetched with ``get_image``, and all images are segmented in a single
    ``inference`` call. A biodiversity score is then computed from each output's
    ``"linear_preds"``.

    Args:
        model: segmentation model forwarded to ``inference`` (the ``__main__``
            block passes a ``LitUnsupervisedSegmenter``).
        longitude (float | str): longitude of the landscape, converted with ``float``.
        latitude (float | str): latitude of the landscape, converted with ``float``.
        start_date (int | str): first year of the period, converted with ``int``.
        end_date (int | str): year whose January 1st closes the period, converted
            with ``int``; must be strictly greater than ``start_date``.
        how (str): window length, one of ``"month"``, ``"2months"`` or ``"year"``.

    Returns:
        The figure produced by ``plot_imgs_labels``. It is built from the window
        boundary dates, the original images, the labeled images, the per-image
        score details and the biodiversity scores.

    Raises:
        ValueError: if ``how`` is not a supported interval, or if ``end_date`` is
            not strictly after ``start_date``.
    """
    logging.info("Running Inference on location")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    logging.info(f"start date : {start_date} & end_date : {end_date}")
    logging.info(f"Prediction on intervale : {how}")

    # A "year" is 12 months. Stepping by 11 made each window start one month
    # earlier than the previous one and overshoot end_date.
    delta_month = {"month": 1, "2months": 2, "year": 12}.get(how)
    if delta_month is None:
        raise ValueError("Wrong interval")

    # Both bounds go through int() so string years are accepted for either one.
    start_year, end_year = int(start_date), int(end_date)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if end_year <= start_year:
        raise ValueError("end date must be stricly higher than start date")

    location = [float(latitude), float(longitude)]

    # Window boundaries: Jan 1st of start_year, then every delta_month months.
    # Every supported step divides 12, so Jan 1st of end_year is hit exactly.
    period_end = datetime.datetime(end_year, 1, 1)
    boundaries = [datetime.datetime(start_year, 1, 1)]
    while boundaries[-1] < period_end:
        boundaries.append(boundaries[-1] + relativedelta(months=delta_month))
    dates = [boundary.strftime("%Y-%m-%d") for boundary in boundaries]

    # Fetch one image per (dates[i], dates[i + 1]) window from Earth Engine.
    # Threads are used, presumably because each request is network-bound.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, window_start, window_end)
        for window_start, window_end in zip(dates, dates[1:])
    )
    # np.array stacks the images into one batch; this presumably requires every
    # image to share the same shape.
    outputs = inference(np.array(all_image), model)

    logging.info("Calculating Biodiversity Scores...")
    scores, scores_details = map(
        list,
        zip(*[compute_biodiv_score(output["linear_preds"].detach().numpy()) for output in outputs]),
    )
    logging.info(f"Calculated Biodiversity Score : {scores}")

    # transform_to_pil returns (image, label, labeled image); the label is unused here.
    imgs, _, labeled_imgs = map(list, zip(*[transform_to_pil(output) for output in outputs]))
    images = [np.asarray(img) for img in imgs]
    labeled_imgs = [np.asarray(img) for img in labeled_imgs]
    return plot_imgs_labels(dates, images, labeled_imgs, scores_details, scores)
def inference_on_location_and_month(model, longitude=2.98, latitude=48.81, start_date='2020-03-20'):
    """Segment one location for the month starting at ``start_date`` and plot the result.

    Args:
        model: segmentation model forwarded to ``inference`` (the ``__main__``
            block builds a ``LitUnsupervisedSegmenter``).
        longitude (float | str): longitude of the landscape, converted with ``float``.
        latitude (float | str): latitude of the landscape, converted with ``float``.
        start_date (str): first day of the window, formatted ``%Y-%m-%d``. The
            window ends one calendar month later.

    Returns:
        The figure produced by ``plot_image``. It is built from ``start_date``,
        the original image, the labeled image, the score details and the
        biodiversity score.

    Raises:
        ValueError: if ``start_date`` does not match ``%Y-%m-%d``.
    """
    logging.info("Running Inference on location and month")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    location = [float(latitude), float(longitude)]

    # Extract the image for [start_date, start_date + 1 month] from Earth Engine.
    window_end = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
    end_date = window_end.strftime("%Y-%m-%d")
    img_test = get_image(location, start_date, end_date)

    # inference() works on a batch: wrap the single image and keep the first output.
    output = inference(np.array([img_test]), model)[0]

    logging.info("Calculating Biodiversity Score...")
    score, score_details = compute_biodiv_score(output["linear_preds"].detach().numpy())
    logging.info(f"Calculated Biodiversity Score : {score}")

    # transform_to_pil returns (image, label, labeled image); the label is unused here.
    img, _, labeled_img = transform_to_pil(output)
    return plot_image([start_date], [np.asarray(img)], [np.asarray(labeled_img)], [score_details], [score])
def main():
    """Configure logging, load the trained segmenter and run ``inference_on_location`` once.

    Logs go to ``biomap.log`` (UTF-8) and to stdout. The Hydra config
    ``configs/my_train_config.yml`` provides the number of classes and the model
    settings. The weights come from ``biomap/checkpoint/model/model.pt`` and are
    mapped onto the CPU.
    """
    import sys

    import hydra

    # The encoding must be set on the FileHandler itself: logging.basicConfig
    # ignores `encoding=` when `handlers=` is given (and rejects it before 3.9).
    handlers = [
        logging.FileHandler(filename="biomap.log", encoding="utf-8"),
        logging.StreamHandler(stream=sys.stdout),
    ]
    logging.basicConfig(
        handlers=handlers,
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Load the model
    model = LitUnsupervisedSegmenter(cfg.dir_dataset_n_classes, cfg)
    logging.info("Model Initialiazed")
    model_path = "biomap/checkpoint/model/model.pt"
    # torch.load unpickles the file: only point it at trusted checkpoints.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model Weights Loaded" if False else "Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    logging.info("Model Loaded")
    # NOTE(review): model.eval() is never called here; confirm inference() does it.

    # inference_on_location_and_month(model)
    inference_on_location(model)


if __name__ == "__main__":
    main()