|
|
|
|
|
|
|
|
import gc |
|
|
import threading |
|
|
|
|
|
import psutil |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
def b2mb(x):
    """Convert a byte count to whole mebibytes, truncating toward zero."""
    bytes_per_mib = 2**20
    return int(x / bytes_per_mib)
|
|
|
|
|
|
|
|
|
|
|
class TorchTracemalloc:
    """Context manager that tracks accelerator and CPU memory usage of a block.

    On exit it exposes, in MiB (via ``b2mb``):
      - ``used`` / ``peaked``: accelerator memory delta / peak above the entry baseline
      - ``cpu_used`` / ``cpu_peaked``: CPU RSS delta / peak above the entry baseline
    Raw byte counters (``begin``, ``end``, ``peak``, ``cpu_begin``, ``cpu_end``,
    ``cpu_peak``) are also kept as attributes.
    """

    def __enter__(self):
        # Prefer the generic accelerator API (newer torch); fall back to CUDA.
        # NOTE: current_accelerator() returns None on CPU-only builds, so guard
        # against calling `.type` on None (the original code crashed here).
        accelerator = (
            torch.accelerator.current_accelerator() if hasattr(torch, "accelerator") else None
        )
        self.device_type = accelerator.type if accelerator is not None else "cuda"
        self.device_module = getattr(torch, self.device_type, torch.cuda)

        gc.collect()
        self.device_module.empty_cache()
        # Reset so max_memory_allocated() reflects only this context's peak.
        self.device_module.reset_peak_memory_stats()
        self.begin = self.device_module.memory_allocated()

        self.process = psutil.Process()
        self.cpu_begin = self.cpu_mem_used()

        # Initialize before starting the thread so __exit__ never reads an
        # unset attribute if the context exits before the thread is scheduled.
        self.cpu_peak = -1
        self.peak_monitoring = True
        # Keep the handle so __exit__ can join for a deterministic final sample.
        self.peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self.peak_monitor_thread.daemon = True
        self.peak_monitor_thread.start()

        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        # Busy-poll RSS with no sleep: sleeping would risk missing short-lived
        # allocation spikes between samples.
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        # Wait briefly for the monitor thread so cpu_peak includes its final
        # sample; the timeout keeps exit bounded even if the thread stalls.
        monitor = getattr(self, "peak_monitor_thread", None)
        if monitor is not None:
            monitor.join(timeout=1.0)

        gc.collect()
        self.device_module.empty_cache()
        self.end = self.device_module.memory_allocated()
        self.peak = self.device_module.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
|
|
|
|
|
|