Convolution
Comparison of:
- scipy_fft
- jax_fft
- scipy_convolve2d
[2]:
import numpy as np
import matplotlib.pyplot as plt
from wolfhece.wolf_array import header_wolf, WolfArray
from datetime import datetime as dt, timedelta
[3]:
def set_a_f(n, n_filter):
    """Create a square data array and a square filter array.

    Both arrays are built from a ``header_wolf`` whose shape is square and
    whose resolution is set to 0.1 in both directions.

    :param n: side length (number of cells) of the data array
    :param n_filter: side length (number of cells) of the filter array
    :return: tuple ``(a, f)`` of ``WolfArray`` -- data array, then filter
    """

    def _square_header(side):
        # Square header with the 0.1 x 0.1 resolution shared by both arrays.
        header = header_wolf()
        header.shape = (side, side)
        header.set_resolution(0.1, 0.1)
        return header

    # Build both headers first, then the arrays (data before filter).
    data_header = _square_header(n)
    filter_header = _square_header(n_filter)
    return WolfArray(srcheader=data_header), WolfArray(srcheader=filter_header)
[4]:
# (array side, filter side) pairs to benchmark; the filter side is always
# one tenth of the array side.
# The (10000, 1000) case is left disabled.
to_test = [(side, side // 10) for side in (100, 200, 500, 1000, 2000, 5000)]

# One (data array, filter array) pair per benchmark size.
# NOTE: the name ``all`` shadows the builtin; it is kept because the
# following cells iterate over it.
all = [set_a_f(*sizes) for sizes in to_test]
[ ]:
# Fill every (data, filter) pair with test values and plot them side by side.
for a, f in all:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Data array: set a centred square block (half the side length) to 2.0.
    n = a.shape[0]
    lo, hi = n // 2 - n // 4, n // 2 + n // 4
    a.array[lo:hi, lo:hi] = 2.0
    a.plot_matplotlib((fig, axes[0]))

    # Filter: circular kernel with exponential falloff,
    #   exp(-((i - c)**2 + (j - c)**2) / n_filter**2)  with  c = n_filter / 2.
    # Computed with NumPy broadcasting instead of a Python-level double loop,
    # which needed n_filter**2 scalar np.exp calls.
    # NOTE(review): c = n_filter / 2 sits half a cell past the geometric
    # centre (n_filter - 1) / 2 of the index grid -- confirm this is intended.
    n_filter = f.shape[0]
    offsets = np.arange(n_filter) - n_filter / 2
    f.array[:, :] = np.exp(
        -(offsets[:, None] ** 2 + offsets[None, :] ** 2) / n_filter**2
    )
    f.plot_matplotlib((fig, axes[1]))

    plt.show()






[ ]:
from time import perf_counter


def _timed_convolve(array, kernel, method):
    """Convolve ``array`` with ``kernel`` and measure how long the call takes.

    :param array: object exposing ``convolve(kernel, method=..., inplace=...)``
    :param kernel: filter passed as first argument to ``convolve``
    :param method: value forwarded as the ``method`` keyword
    :return: tuple ``(result, elapsed)`` where ``elapsed`` is a ``timedelta``
    """
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # durations; datetime.now() is a wall clock that can jump if the system
    # time is adjusted.
    begin = perf_counter()
    result = array.convolve(kernel, method=method, inplace=False)
    return result, timedelta(seconds=perf_counter() - begin)


# Run each convolution method on every (data, filter) pair, plot the results
# side by side, then plot the measured runtimes as a bar chart.
for a, f in all:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # scipy fft
    scipy_fft, scipy_runtime = _timed_convolve(a, f, 'scipyfft')
    scipy_fft.plot_matplotlib((fig, axes[0]))
    axes[0].set_title('scipy fft')

    # jax fft
    # NOTE(review): the measured time may include one-off JAX start-up or
    # compilation cost on the first call -- confirm before comparing numbers.
    jax_fft, jax_runtime = _timed_convolve(a, f, 'jaxfft')
    jax_fft.plot_matplotlib((fig, axes[1]))
    axes[1].set_title('jax fft')

    if a.shape[0] <= 500:
        # classic == scipy.convolve2d; only run for arrays up to 500 cells
        # per side (presumably because it gets too slow beyond that).
        classic, classic_runtime = _timed_convolve(a, f, 'classic')
        classic.plot_matplotlib((fig, axes[2]))
        axes[2].set_title('classic')
    else:
        # Method skipped: report a zero runtime and hide the unused axes.
        classic_runtime = timedelta(0)
        axes[2].axis('off')

    # Runtime comparison for this array size, in seconds.
    plt.figure(figsize=(10, 5))
    plt.bar(
        [0, 1, 2],
        [
            scipy_runtime.total_seconds(),
            jax_runtime.total_seconds(),
            classic_runtime.total_seconds(),
        ],
        tick_label=['scipy', 'jax', 'classic'],
    )
    plt.show()











