Spaces:
Runtime error
Runtime error
| from utils import * | |
| from modules import * | |
| from data import * | |
| from torch.utils.data import DataLoader | |
| import torch.nn.functional as F | |
| from datetime import datetime | |
| import hydra | |
| from omegaconf import DictConfig, OmegaConf | |
| import pytorch_lightning as pl | |
| from pytorch_lightning import Trainer | |
| from pytorch_lightning.loggers import TensorBoardLogger | |
| from pytorch_lightning.utilities.seed import seed_everything | |
| import torch.multiprocessing | |
| import seaborn as sns | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| import sys | |
| import pdb | |
| import matplotlib as mpl | |
| from skimage import measure | |
| from scipy.stats import mode as statsmode | |
| from collections import OrderedDict | |
| import unet | |
| import pdb | |
# Tensors passed between DataLoader worker processes are shared through the
# file system instead of file descriptors (presumably to avoid exhausting
# open descriptors when many workers are used).
torch.multiprocessing.set_sharing_strategy("file_system")

# (class name, display colour) pairs for the land-cover legend. They are kept
# together so the two derived tuples below stay index-aligned.
# Assumes colors[i] is the colour meant for class_names[i]; confirm against
# the plotting code.
_CLASS_COLOR_PAIRS = (
    ("Buildings", "red"),
    ("Cultivation", "palegreen"),
    ("Natural green", "green"),
    ("Wetland", "steelblue"),
    ("Water", "blue"),
    ("Infrastructure", "yellow"),
    ("Background", "lightgrey"),
)
colors = tuple(color for _, color in _CLASS_COLOR_PAIRS)
class_names = tuple(name for name, _ in _CLASS_COLOR_PAIRS)

# Boundaries 1, 2, ..., len(class_names) + 1 give one colour bin per class.
# An integer label k (1-based) is therefore drawn with colors[k - 1].
bounds = list(np.arange(1, len(class_names) + 2))
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
def retouch_label(pred_label, true_label):
    """Snap every connected region of a predicted label map to its majority true class.

    Connected regions ("blobs") of ``pred_label`` are found with
    ``skimage.measure.label``. Every pixel of a blob is then overwritten with
    the most frequent ``true_label`` value found inside that blob.

    Args:
        pred_label: integer label map to retouch. It is not modified; a copy
            is made.
        true_label: reference label map that can be indexed with the same
            boolean masks as ``pred_label``.

    Returns:
        A copy of ``pred_label`` in which each blob carries a single class value.
    """
    # "+ 0" produces a new array, so the caller's pred_label is left untouched.
    retouched_label = pred_label + 0
    # NOTE(review): with skimage's default background=0, every zero-valued
    # pixel shares blob id 0 even when the pixels are not connected. Confirm
    # this is intended.
    blobs = measure.label(retouched_label)
    for idx in np.unique(blobs):
        in_blob = blobs == idx
        # Most frequent true class in this blob. SciPy < 1.11 returns
        # 1-element arrays, while SciPy >= 1.11 returns scalars (where the old
        # "[0][0]" lookup raised IndexError). np.ravel handles both forms.
        majority = statsmode(true_label[in_blob]).mode
        retouched_label[in_blob] = np.ravel(majority)[0]
    return retouched_label
# Ordered class names per supported dataset. Any dataset name starting with
# "cityscapes" shares the "cityscapes" entry.
# NOTE(review): cocostuff27 lists "food" and "furniture" twice (indices 2/14
# and 3/15), presumably the "things" vs "stuff" supercategories. Confirm
# before using these names as unique keys.
_DATASET_CLASS_LABELS = {
    "cityscapes": (
        "road", "sidewalk", "parking", "rail track", "building", "wall",
        "fence", "guard rail", "bridge", "tunnel", "pole", "polegroup",
        "traffic light", "traffic sign", "vegetation", "terrain", "sky",
        "person", "rider", "car", "truck", "bus", "caravan", "trailer",
        "train", "motorcycle", "bicycle",
    ),
    "cocostuff27": (
        "electronic", "appliance", "food", "furniture", "indoor", "kitchen",
        "accessory", "animal", "outdoor", "person", "sports", "vehicle",
        "ceiling", "floor", "food", "furniture", "rawmaterial", "textile",
        "wall", "window", "building", "ground", "plant", "sky", "solid",
        "structural", "water",
    ),
    "voc": (
        "background", "aeroplane", "bicycle", "bird", "boat", "bottle",
        "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
        "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
        "tvmonitor",
    ),
    "potsdam": ("roads and cars", "buildings and clutter", "trees and vegetation"),
}


def get_class_labels(dataset_name):
    """Return the ordered class-name list for ``dataset_name``.

    Any name that starts with ``"cityscapes"`` maps to the Cityscapes labels.
    Otherwise the name must exactly match ``"cocostuff27"``, ``"voc"`` or
    ``"potsdam"``. A new list is returned on every call, so callers may
    mutate it.

    Raises:
        ValueError: if the dataset name is not recognised.
    """
    table_key = "cityscapes" if dataset_name.startswith("cityscapes") else dataset_name
    labels = _DATASET_CLASS_LABELS.get(table_key)
    if labels is None:
        raise ValueError("Unknown Dataset {}".format(dataset_name))
    return list(labels)
def my_app(cfg: DictConfig) -> None:
    """Train the unsupervised segmenter described by ``cfg``.

    Steps, in order:

    1. Print the config and derive the run name.
    2. Create the ``data``, ``logs`` and ``checkpoints`` directories under
       ``cfg.output_root``, then seed everything with 0.
    3. Build the augmented training dataset and the validation dataset
       (transform resolution argument 320, no extra augmentations), plus
       their loaders.
    4. Build ``LitUnsupervisedSegmenter``, a TensorBoard logger, and a
       ``Trainer`` that checkpoints every 400 steps, keeping the best two by
       ``test/cluster/mIoU``.
    5. Call ``trainer.fit``.

    Args:
        cfg: run configuration. Keys read here include ``pytorch_data_dir``,
            ``output_root``, ``log_dir``, ``dataset_name``, ``experiment_name``,
            ``res``, ``crop_type``, ``loader_crop_type``, ``num_neighbors``,
            ``batch_size``, ``num_workers``, ``submitting_to_aml``,
            ``val_freq``, ``scalar_log_freq`` and ``max_steps``. The key
            ``full_name`` is written back into it.

    NOTE(review): the ``__main__`` guard calls ``my_app()`` with no argument,
    and no ``@hydra.main`` decorator is visible on this function. That call
    raises ``TypeError`` unless a decorator is applied somewhere; confirm.
    """
    # Turn struct mode off so new keys (full_name below) can be added to cfg.
    OmegaConf.set_struct(cfg, False)
    print(OmegaConf.to_yaml(cfg))

    data_root = cfg.pytorch_data_dir
    data_dir, log_dir, checkpoint_dir = (
        join(cfg.output_root, subdir) for subdir in ("data", "logs", "checkpoints")
    )

    run_prefix = f"{cfg.log_dir}/{cfg.dataset_name}_{cfg.experiment_name}"
    timestamp = datetime.now().strftime("%b%d_%H-%M-%S")
    run_name = f"{run_prefix}_date_{timestamp}"
    cfg.full_name = run_prefix

    for directory in (data_dir, log_dir, checkpoint_dir):
        os.makedirs(directory, exist_ok=True)

    seed_everything(seed=0)
    print(data_dir)
    print(cfg.output_root)

    # Extra augmentations handed to the training dataset on top of its base transforms.
    geometric_aug = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0)),
    ])
    photometric_aug = T.Compose([
        T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        T.RandomGrayscale(0.2),
        T.RandomApply([T.GaussianBlur((5, 5))]),
    ])
    sys.stdout.flush()

    train_dataset = ContrastiveSegDataset(
        pytorch_data_dir=data_root,
        dataset_name=cfg.dataset_name,
        crop_type=cfg.crop_type,
        image_set="train",
        transform=get_transform(cfg.res, False, cfg.loader_crop_type),
        target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
        cfg=cfg,
        aug_geometric_transform=geometric_aug,
        aug_photometric_transform=photometric_aug,
        num_neighbors=cfg.num_neighbors,
        mask=True,
        pos_images=True,
        pos_labels=True,
    )

    # VOC validation gets no loader crop type; every other dataset passes "center".
    val_crop = None if cfg.dataset_name == "voc" else "center"
    val_dataset = ContrastiveSegDataset(
        pytorch_data_dir=data_root,
        dataset_name=cfg.dataset_name,
        crop_type=None,
        image_set="val",
        transform=get_transform(320, False, val_crop),
        target_transform=get_transform(320, True, val_crop),
        mask=True,
        cfg=cfg,
    )

    train_loader = DataLoader(
        train_dataset,
        cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    on_aml = cfg.submitting_to_aml
    val_loader = DataLoader(
        val_dataset,
        16 if on_aml else cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)
    tb_logger = TensorBoardLogger(join(log_dir, run_name), default_hp_metric=False)

    if on_aml:
        hw_args = dict(gpus=1, val_check_interval=250)
        interval_cap = len(train_loader)
    else:
        hw_args = dict(gpus=-1, accelerator="ddp", val_check_interval=cfg.val_freq)
        interval_cap = len(train_loader) // 4
    # If the requested interval exceeds the cap (one epoch of batches on AML,
    # a quarter epoch otherwise), drop it and use Lightning's default schedule.
    if hw_args["val_check_interval"] > interval_cap:
        del hw_args["val_check_interval"]

    checkpoint_cb = ModelCheckpoint(
        dirpath=join(checkpoint_dir, run_name),
        every_n_train_steps=400,
        save_top_k=2,
        monitor="test/cluster/mIoU",
        mode="max",
    )
    trainer = Trainer(
        log_every_n_steps=cfg.scalar_log_freq,
        logger=tb_logger,
        max_steps=cfg.max_steps,
        callbacks=[checkpoint_cb],
        **hw_args
    )
    trainer.fit(model, train_loader, val_loader)
if __name__ == "__main__":
    # prep_args() is not defined in the visible code; presumably it comes from
    # one of the star imports and adjusts sys.argv before Hydra parses it.
    # Verify in utils.
    prep_args()
    # NOTE(review): my_app is declared as my_app(cfg), and no @hydra.main
    # decorator is visible on it. This zero-argument call would raise
    # TypeError unless a decorator supplies cfg; confirm.
    my_app()