# Bordoglor's picture
# Upload folder using huggingface_hub
# 302920f verified
# adapted from [peft's boft_dreambooth](https://github.com/huggingface/peft/tree/main/examples/boft_dreambooth)
import gc
import threading
import psutil
import torch
# Converting Bytes to Megabytes
def b2mb(x):
    """Convert a byte count to whole mebibytes (2**20 bytes).

    Args:
        x: Number of bytes. May be negative, e.g. when it is a memory delta.

    Returns:
        int: ``x / 2**20`` with the fractional part discarded (truncation
        toward zero, so ``-1`` byte maps to ``0``, not ``-1``).
    """
    return int(x / 2**20)
# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    """Context manager that measures accelerator and CPU memory over a code region.

    On entry it records the currently allocated accelerator memory and the
    process resident set size (RSS), and starts a background thread that
    continuously samples the RSS to find its peak.

    After the ``with`` block exits, these attributes are set (all converted
    with ``b2mb``, i.e. whole units of 2**20 bytes, relative to the values
    recorded on entry):

    * ``used``       - accelerator memory allocated at exit minus at entry
    * ``peaked``     - peak accelerator memory allocated minus at entry
    * ``cpu_used``   - process RSS at exit minus at entry
    * ``cpu_peaked`` - highest sampled process RSS minus at entry

    The raw byte values are kept in ``begin``, ``end``, ``peak``,
    ``cpu_begin``, ``cpu_end`` and ``cpu_peak``.
    """

    def __enter__(self):
        """Record the memory baselines and start the CPU peak monitor thread."""
        # `torch.accelerator` is missing on older torch, and
        # `current_accelerator()` can return None; fall back to "cuda" in both cases.
        accelerator = torch.accelerator.current_accelerator() if hasattr(torch, "accelerator") else None
        self.device_type = accelerator.type if accelerator is not None else "cuda"
        self.device_module = getattr(torch, self.device_type, torch.cuda)

        gc.collect()
        self.device_module.empty_cache()
        self.device_module.reset_peak_memory_stats()  # reset the peak gauge to zero
        self.begin = self.device_module.memory_allocated()

        self.process = psutil.Process()
        self.cpu_begin = self.cpu_mem_used()

        self.peak_monitoring = True
        # Keep a handle on the thread so __exit__ can wait for it to finish.
        self._peak_monitor_thread = threading.Thread(target=self.peak_monitor_func, daemon=True)
        self._peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Sample the process RSS in a tight loop, keeping the maximum in ``cpu_peak``.

        Always takes at least one sample, then stops as soon as
        ``peak_monitoring`` is false.
        """
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        """Stop monitoring and compute the usage/peak deltas.

        Returns None, so an exception raised inside the ``with`` block is
        not suppressed.
        """
        self.peak_monitoring = False
        # Wait for the monitor to finish: this guarantees `cpu_peak` exists
        # (the thread may not have run yet) and is no longer being updated
        # when it is read below.
        monitor = getattr(self, "_peak_monitor_thread", None)
        if monitor is not None and monitor.is_alive():
            monitor.join()

        gc.collect()
        self.device_module.empty_cache()
        self.end = self.device_module.memory_allocated()
        self.peak = self.device_module.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")