import gc
import threading

import psutil
import torch


# Converting Bytes to Megabytes
def b2mb(x):
    """Convert a byte count to whole MiB (``x / 2**20``), truncated toward zero."""
    return int(x / 2**20)


class TorchTracemalloc:
    """Context manager that tracks accelerator and host memory used inside a ``with`` block.

    On exit it sets the following attributes, all in MiB (via ``b2mb``) and
    relative to the values captured in ``__enter__``:

    * ``used`` / ``peaked``: accelerator memory delta and peak delta, read from
      the device module's ``memory_allocated`` / ``max_memory_allocated``.
    * ``cpu_used`` / ``cpu_peaked``: process RSS delta and peak delta. The RSS
      peak is sampled by a background thread while the block runs.
    """

    @staticmethod
    def _detect_device_type():
        """Return the current accelerator's device type, falling back to ``"cuda"``."""
        # torch.accelerator only exists in newer PyTorch releases, and
        # current_accelerator() returns None when no accelerator is available.
        accelerator_api = getattr(torch, "accelerator", None)
        if accelerator_api is None:
            return "cuda"
        device = accelerator_api.current_accelerator()
        return device.type if device is not None else "cuda"

    def __enter__(self):
        self.device_type = self._detect_device_type()
        self.device_module = getattr(torch, self.device_type, torch.cuda)
        gc.collect()
        self.device_module.empty_cache()
        self.device_module.reset_peak_memory_stats()  # reset the peak gauge to zero
        self.begin = self.device_module.memory_allocated()

        self.process = psutil.Process()
        self.cpu_begin = self.cpu_mem_used()
        # Initialise the peak before the sampler starts. Otherwise __exit__ could
        # find no cpu_peak attribute if the thread has not been scheduled yet.
        self.cpu_peak = self.cpu_begin
        self.peak_monitoring = True
        self.peak_monitor_thread = threading.Thread(
            target=self.peak_monitor_func, daemon=True
        )
        self.peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Sample RSS into ``self.cpu_peak`` until ``self.peak_monitoring`` is cleared.

        The loop takes one sample before checking the flag, so at least one
        sample is always recorded.
        """
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        # Wait for the sampler to stop so cpu_peak is final before it is read.
        self.peak_monitor_thread.join()

        gc.collect()
        self.device_module.empty_cache()
        self.end = self.device_module.memory_allocated()
        self.peak = self.device_module.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        # The end reading is taken after sampling stopped; the peak can't be below it.
        self.cpu_peak = max(self.cpu_peak, self.cpu_end)
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)