Skip to content

async_clear_tensor_memory

Source code in tinybig/util/utility.py
def async_clear_tensor_memory(tensor):
    """Drop this function's reference to *tensor* and release cached device memory.

    Despite the name, nothing here runs asynchronously. ``del`` only removes
    the local name binding: if the caller (or any other object) still holds a
    reference to the tensor, its storage is NOT freed by this call. What the
    function can do is reclaim memory that is already unreferenced. It runs
    ``gc.collect()`` to break reference cycles, and then asks the device
    allocator (CUDA, or MPS when available) to release its unoccupied cached
    blocks.

    Args:
        tensor: A ``torch.Tensor`` (anything exposing ``.device.type``), or
            ``None``, in which case a message is printed and nothing else
            happens.

    Returns:
        None. Progress is reported via ``print``.
    """
    import gc

    if tensor is None:
        print("Tensor is None, nothing to clear.")
        return

    device_type = tensor.device.type

    # Drop our own reference first so the collector can reclaim the tensor
    # if this was the last reference to it.
    del tensor
    # Collect garbage *before* emptying the allocator cache; otherwise memory
    # released by the collector would remain parked in the cache.
    gc.collect()

    if device_type == 'cuda':
        # No host copy and no side stream: copying the tensor to the CPU only
        # added transient host memory and work, and the copy was discarded.
        torch.cuda.empty_cache()
        print("Memory cleared for CUDA tensor.")
    elif device_type == 'mps':
        # torch.mps (and its empty_cache) only exists in newer torch releases,
        # and the MPS backend is only usable on supported Apple hardware.
        mps = getattr(torch, "mps", None)
        if (
            mps is not None
            and hasattr(mps, "empty_cache")
            and torch.backends.mps.is_available()
        ):
            mps.empty_cache()
        print("Memory cleared for MPS tensor.")
    else:
        print(f"Memory cleared for tensor on {device_type} device.")