torch.rfft – fft-based convolution creating different output than spatial convolution
我在 Pytorch 中实现了基于 FFT 的卷积,并通过 conv2d() 函数将结果与空间卷积进行了比较。使用的卷积滤波器是平均滤波器。 conv2d() 函数由于预期的平均滤波而产生了平滑的输出,但基于 fft 的卷积返回了更模糊的输出。
我已在此处附加代码和输出 –
空间卷积-
1
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from PIL import Image, ImageOps import torch from matplotlib import pyplot as plt from torchvision.transforms import ToTensor import torch.nn.functional as F import numpy as np im = Image.open("/kaggle/input/tiger.jpg") fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]]) conv_gray_im = gray_im.unsqueeze(0).unsqueeze(0) conv_op = F.conv2d(conv_gray_im,conv_fil) conv_op = conv_op.squeeze() plt.figure() |
基于 FFT 的卷积 –
1
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
def fftshift(image): sh = image.shape x = np.arange(0, sh[2], 1) y = np.arange(0, sh[3], 1) xm, ym = np.meshgrid(x,y) shifter = (–1)**(xm + ym) shifter = torch.from_numpy(shifter) return image*shifter shift_im = fftshift(conv_gray_im) fft_op = shift_fft_conv.squeeze() |
原图-
空间卷积输出-
谁能解释一下这个问题?
您的代码的主要问题是 Torch 不处理复数,其 FFT 的输出是一个 3D 数组,第 3 维有两个值,一个用于实部,一个用于虚部。因此,乘法不会进行复数乘法。
目前在 Torch 中没有定义复数乘法(参见本期),我们必须自己定义。
一个小问题,但如果你想比较两个卷积操作也很重要,如下:
FFT 在第一个元素(图像的左上角像素)中获取其输入的原点。为避免输出偏移,您需要生成一个填充内核,其中内核的原点是左上角的像素。这很棘手,实际上…
您当前的代码:
1
2 3 |
fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
conv_fil = fil.unsqueeze(0).unsqueeze(0) padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1])) |
生成一个填充内核,其中原点以像素 (1,1) 为单位,而不是 (0,0)。它需要在每个方向上移动一个像素。 NumPy 有一个函数
1
2 3 4 5 |
fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
padded_fil = fil.unsqueeze(0).unsqueeze(0).numpy() padded_fil = np.pad(padded_fil, ((0, gray_im.shape[0]-fil.shape[0]), (0, gray_im.shape[1]-fil.shape[1]))) padded_fil = np.roll(padded_fil, –1, axis=(0, 1)) padded_fil = torch.from_numpy(padded_fil) |
最后,应用于空间域图像的
把这些东西放在一起,现在的卷积是:
1
2 3 4 5 6 7 8 |
def complex_multiplication(t1, t2): real1, imag1 = t1[:,:,0], t1[:,:,1] real2, imag2 = t2[:,:,0], t2[:,:,1] return torch.stack([real1 * real2 – imag1 * imag2, real1 * imag2 + imag1 * real2], dim = –1) fft_im = torch.rfft(gray_im, 2, onesided=False) |
请注意,您可以进行单边 FFT 以节省一点计算时间:
1
2 3 |
fft_im = torch.rfft(gray_im, 2, onesided=True)
fft_fil = torch.rfft(padded_fil, 2, onesided=True) fft_conv = torch.irfft(complex_multiplication(fft_im, fft_fil), 2, onesided=True, signal_sizes=gray_im.shape) |
这里的频域大小大约是完整 FFT 的一半,但它只是省略了冗余部分。卷积的结果不变。
原创文章,作者:ItWorker,如若转载,请注明出处:https://blog.ytso.com/267911.html