[Repost] Debugging memory / GPU memory leaks in PyTorch
def debug_memory():
    import collections, gc, resource, torch
    # Peak resident set size of this process so far (KB on Linux).
    print('maxrss = {}'.format(
        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
    # Count every live tensor, grouped by (device, dtype, shape).
    tensors = collections.Counter(
        (str(o.device), o.dtype, tuple(o.shape))
        for o in gc.get_objects() if torch.is_tensor(o))
    # Sort by string form, since torch.dtype objects are not orderable.
    for line in sorted(tensors.items(), key=lambda kv: str(kv[0])):
        print('{}\t{}'.format(*line))
Call the function above inside the training loop to track the count and size of the tensors occupying GPU memory. A tensor group whose count keeps growing from iteration to iteration points directly at the variable that is abnormally holding on to memory.
Below is an example (no GPU is involved here; everything is on the CPU):
>>> z = [torch.randn(i).long() for i in range(10)]
>>> debug_memory()
('cpu', torch.float32, (3, 3))	2
('cpu', torch.int64, (0,))	1
('cpu', torch.int64, (1,))	1
('cpu', torch.int64, (2,))	1
('cpu', torch.int64, (3,))	1
('cpu', torch.int64, (4,))	1
('cpu', torch.int64, (5,))	1
('cpu', torch.int64, (6,))	1
('cpu', torch.int64, (7,))	1
('cpu', torch.int64, (8,))	1
('cpu', torch.int64, (9,))	1
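As a minimal sketch of the training-loop usage described above (the model, optimizer, loss function, and history list here are hypothetical placeholders, not from the original post), call debug_memory() once per epoch and compare its output between epochs; an entry whose count keeps climbing is the leak:

import torch
import torch.nn as nn

# Illustrative placeholders; any real model and data behave the same way.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
data, targets = torch.randn(100, 10), torch.randn(100, 1)

history = []
for epoch in range(3):
    optimizer.zero_grad()
    loss = loss_fn(model(data), targets)
    loss.backward()
    optimizer.step()
    # Classic leak: appending `loss` keeps the scalar tensor and its
    # autograd graph alive; `loss.item()` would release them.
    history.append(loss)
    debug_memory()

Running this, the counter for ('cpu', torch.float32, ()) grows by one every epoch, which is exactly the signature debug_memory() is meant to expose; replacing history.append(loss) with history.append(loss.item()) keeps the counts stable.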