def async_clear_tensor_memory(tensor):
    """Copy *tensor* to CPU and release its device-side memory.

    For CUDA tensors the device→host copy is enqueued on a dedicated side
    stream with ``non_blocking=True`` and synchronized before the source
    reference is dropped, so the copy cannot read freed storage.  For MPS
    tensors a plain (blocking) ``.cpu()`` copy is used.  Tensors already on
    CPU (or any other device type) are returned unchanged.

    Args:
        tensor: A ``torch.Tensor`` or ``None``.

    Returns:
        The CPU copy of the tensor (for CUDA/MPS inputs), the tensor itself
        (for CPU/other inputs), or ``None`` when *tensor* is ``None``.

    Note:
        ``del tensor`` here drops only this function's local reference; the
        device storage is actually freed only once the caller also releases
        its reference.  ``non_blocking=True`` overlaps the copy with other
        work only when the destination is pinned memory — otherwise PyTorch
        silently falls back to a synchronous copy.
    """
    import gc

    if tensor is None:
        print("Tensor is None, nothing to clear.")
        return None

    device = tensor.device

    if device.type == 'cuda':
        # Enqueue the device->host copy on a dedicated stream so it can
        # overlap with work on the default stream.
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            tensor_cpu = tensor.to('cpu', non_blocking=True)
        # BUGFIX: wait for the async copy to finish BEFORE releasing the
        # source tensor / emptying the cache, otherwise the copy could read
        # memory that the allocator has already reclaimed.
        stream.synchronize()
        del tensor
        # Return cached-but-unused blocks to the driver.
        torch.cuda.empty_cache()
        gc.collect()
        print("Memory cleared for CUDA tensor.")
        return tensor_cpu

    if device.type == 'mps':
        # MPS has no stream API comparable to CUDA's; a blocking copy is
        # the portable option here.
        tensor_cpu = tensor.cpu()
        del tensor
        gc.collect()
        print("Memory cleared for MPS tensor.")
        return tensor_cpu

    # Already on CPU (or an unrecognized device type): nothing to offload.
    gc.collect()
    print(f"Memory cleared for tensor on {device.type} device.")
    return tensor