Error: PyTorch 1.9 ImportError: cannot import name 'zero_gradients'
Error message:
ImportError: cannot import name 'zero_gradients' from 'torch.autograd.gradcheck' (/root/miniconda3/envs/d2l/lib/python3.9/site-packages/torch/autograd/gradcheck.py)
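Cause:
Starting with PyTorch 1.9, torch.autograd.gradcheck no longer provides the zero_gradients function. Code written against older PyTorch versions that still imports it (here, advertorch's fast_adaptive_boundary attack) therefore fails with this ImportError.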
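To confirm that this is the cause on a given installation, a quick check like the following can be run inside the same environment. It is a minimal sketch; it only asks the installed torch.autograd.gradcheck module whether it still exposes zero_gradients, and the variable name has_zero_gradients is just for illustration.

```python
import torch
import torch.autograd.gradcheck as gradcheck

# True on older PyTorch releases, False once zero_gradients has been removed
has_zero_gradients = hasattr(gradcheck, "zero_gradients")
print(f"torch {torch.__version__}: zero_gradients available = {has_zero_gradients}")
```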
Fix:
In ~/miniconda3/envs/d2l/lib/python3.9/site-packages/advertorch/attacks/fast_adaptive_boundary.py,
delete the line
from torch.autograd.gradcheck import zero_gradients
and add the following definition in its place:
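import collections.abc  # needed for collections.abc.Iterable below

# torch itself is assumed to be imported already at the top of this module
def zero_gradients(x):
    # Zero the gradient of a tensor, or recursively of every tensor in an iterable
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)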
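Before patching the installed package, the replacement helper can be tried on its own. The sketch below is a standalone copy of the function above plus a small hypothetical usage example (tensors a and b are made up for illustration): it gives two leaf tensors non-zero gradients via backward(), then zeroes both of them by passing a list.

```python
import collections.abc

import torch


def zero_gradients(x):
    """Zero the .grad of a tensor, or of every tensor in a (nested) iterable."""
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()  # detach the gradient tensor from any autograd graph
            x.grad.zero_()    # reset its values in place
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)


# Two leaf tensors that will receive non-zero gradients
a = torch.ones(3, requires_grad=True)
b = torch.ones(2, requires_grad=True)
(a.sum() * 2 + b.sum() * 3).backward()
print(a.grad, b.grad)  # tensor([2., 2., 2.]) tensor([3., 3.])

# A list is iterable, so the helper recurses into each tensor
zero_gradients([a, b])
print(a.grad, b.grad)  # tensor([0., 0., 0.]) tensor([0., 0.])
```

Because the tensor check comes before the iterable check, a tensor is zeroed directly rather than being iterated element by element.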
Reference: https://zhuanlan.zhihu.com/p/420312739