Convolution

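Comparison of the results and runtimes of three convolution methods available through `WolfArray.convolve`:

- `scipyfft` – FFT-based convolution using SciPy
- `jaxfft` – FFT-based convolution using JAX
- `classic` – direct convolution (`scipy.signal.convolve2d`), run only for grids up to 500 × 500 cells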
[2]:
import numpy as np
import matplotlib.pyplot as plt
from wolfhece.wolf_array import header_wolf, WolfArray

from datetime import datetime as dt, timedelta
[3]:
def set_a_f(n, n_filter, resolution=0.1):
    """Create a square data array and a square filter array on the same grid.

    Both arrays get their own freshly built header with the same cell size,
    so the later convolution works on a consistent grid. Their values are
    filled in by the caller.

    :param n: number of cells along each side of the data array.
    :param n_filter: number of cells along each side of the filter array.
    :param resolution: cell size applied to both axes of both arrays
        (default 0.1, the value that used to be hard-coded).
    :return: tuple ``(a, f)`` -- the data ``WolfArray`` and the filter
        ``WolfArray``.
    """
    def _square_array(size):
        # Describe a size x size grid with the requested cell size, then
        # build a WolfArray from that header.
        header = header_wolf()
        header.shape = (size, size)
        header.set_resolution(resolution, resolution)
        return WolfArray(srcheader=header)

    a = _square_array(n)
    f = _square_array(n_filter)
    return a, f
[4]:
# Benchmark sizes as (data-array side, filter side) in cells; the filter
# side is one tenth of the data-array side in every case below.
to_test = [
    (100, 10),
    (200, 20),
    (500, 50),
    (1000, 100),
    (2000, 200),
    (5000, 500),
    # (10000, 1000),  # NOTE(review): disabled, presumably too slow/large -- confirm
]

# One (data array, filter array) pair per size.
# NOTE(review): `all` shadows the builtin all(); the following cells iterate
# over this name, so rename it in every cell at once if you change it.
all = [set_a_f(*sizes) for sizes in to_test]
[ ]:
# Fill each data array with a central square block and each filter with a
# radially decaying kernel, then show both side by side.
for a, f in all:
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Data: value 2.0 on the central square whose side is half the grid side.
    n = a.shape[0]
    lo, hi = n // 2 - n // 4, n // 2 + n // 4
    a.array[lo:hi, lo:hi] = 2.0
    a.plot_matplotlib((fig, axes[0]))

    # Filter: exp(-r**2 / n_filter**2), where r is the distance (in cells) from
    # (n_filter/2, n_filter/2) -- a circular, Gaussian-shaped kernel.
    # Computed with vectorised NumPy instead of a Python double loop over
    # every cell (250 000 interpreted iterations for n_filter = 500).
    # NOTE(review): for even n_filter the centre n_filter/2 is not the
    # geometric middle ((n_filter-1)/2), so the kernel is slightly asymmetric;
    # confirm this matches the origin convention used by convolve().
    n_filter = f.shape[0]
    rows, cols = np.indices((n_filter, n_filter))
    centre = n_filter / 2
    f.array[:, :] = np.exp(-((rows - centre) ** 2 + (cols - centre) ** 2) / n_filter**2)
    f.plot_matplotlib((fig, axes[1]))

    plt.show()
../_images/tutorials_convolve_4_0.png
../_images/tutorials_convolve_4_1.png
../_images/tutorials_convolve_4_2.png
../_images/tutorials_convolve_4_3.png
../_images/tutorials_convolve_4_4.png
../_images/tutorials_convolve_4_5.png
[ ]:
from time import perf_counter


def _timed_convolve(data, kernel, method):
    """Convolve ``data`` with ``kernel`` using ``method``; return (result, seconds).

    ``time.perf_counter`` is used because it is monotonic and has the highest
    available resolution. ``datetime.now()`` follows the adjustable wall clock,
    whose resolution is platform-dependent and can be too coarse for the
    fast runs on the smallest grids.
    """
    start = perf_counter()
    result = data.convolve(kernel, method=method, inplace=False)
    return result, perf_counter() - start


# The direct ('classic') method is only run for grids up to this side length.
CLASSIC_MAX_N = 500

# For each (data, filter) pair: run each convolution method, plot its result,
# then draw a bar chart of the measured runtimes in seconds.
for a, f in all:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # scipy fft
    scipy_fft, scipy_runtime = _timed_convolve(a, f, 'scipyfft')
    scipy_fft.plot_matplotlib((fig, axes[0]))
    axes[0].set_title('scipy fft')

    # jax fft
    # NOTE(review): if the 'jaxfft' path is JIT-compiled, each new array shape
    # pays a compilation cost that is included in this timing.
    jax_fft, jax_runtime = _timed_convolve(a, f, 'jaxfft')
    jax_fft.plot_matplotlib((fig, axes[1]))
    axes[1].set_title('jax fft')

    labels = ['scipy', 'jax']
    runtimes = [scipy_runtime, jax_runtime]

    if a.shape[0] <= CLASSIC_MAX_N:
        # classic == scipy.convolve2d
        classic, classic_runtime = _timed_convolve(a, f, 'classic')
        classic.plot_matplotlib((fig, axes[2]))
        axes[2].set_title('classic')
        labels.append('classic')
        runtimes.append(classic_runtime)
    else:
        # Skipped: blank the panel and leave it off the bar chart instead of
        # drawing a misleading 0 s bar.
        axes[2].axis('off')

    plt.figure(figsize=(10, 5))
    plt.bar(range(len(runtimes)), runtimes, tick_label=labels)
    plt.ylabel('runtime [s]')
    plt.show()
../_images/tutorials_convolve_5_0.png
../_images/tutorials_convolve_5_1.png
../_images/tutorials_convolve_5_2.png
../_images/tutorials_convolve_5_3.png
../_images/tutorials_convolve_5_4.png
../_images/tutorials_convolve_5_5.png
../_images/tutorials_convolve_5_6.png
../_images/tutorials_convolve_5_7.png
../_images/tutorials_convolve_5_8.png
../_images/tutorials_convolve_5_9.png
../_images/tutorials_convolve_5_10.png
../_images/tutorials_convolve_5_11.png