Trouble when generating a Born and Wolf PSF model

Dear all,
I wrote some custom Python code to generate a Born & Wolf (BW) PSF.
However, the PSF generated by my code is not consistent with the one generated by PSF Generator in Fiji, and I cannot find which part of my code is wrong. Could you give me a hand with it?

Below is my code (you can reproduce my result with Python):

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import skimage.io as io


class BWModel(nn.Module):
    def __init__(
        self,
        kernel_size=(25, 25, 25),
        kernel_norm=True,
        num_integral=100,
        over_sampling=2,
        pixel_size_z=1,  # z step in units of the xy pixel size
    ):
        super().__init__()
        integral = torch.linspace(start=0, end=1, steps=num_integral + 1)
        self.register_buffer("integral", integral)

        self.dx = 1 / num_integral
        self.kernel_size = torch.tensor(kernel_size)
        self.kernel_norm = kernel_norm
        self.Nz, Ny, Nx = self.kernel_size
        self.pixel_size_z = torch.tensor(pixel_size_z)

        # xy plane
        # center point position
        yp, xp = (Ny - 1) / 2, (Nx - 1) / 2
        max_anchor = torch.ceil(torch.sqrt(((Nx - 1) - xp) ** 2 + ((Ny - 1) - yp) ** 2))

        # one additional anchor for position 0
        self.num_anchor = int(max_anchor * over_sampling + 1)

        # ----------------------------------------------------------------------
        rAnchor = (
            torch.linspace(
                start=0, end=max_anchor * over_sampling, steps=self.num_anchor
            )
            / over_sampling
        )
        rAnchor = rAnchor[None].repeat(self.Nz // 2 + 1, 1)
        index_slice_half = torch.linspace(
            start=0, end=self.Nz // 2, steps=self.Nz // 2 + 1
        )[..., None]
        index_slice_half = index_slice_half.repeat(1, self.num_anchor).type(torch.int)

        rAnchor_flat = torch.reshape(rAnchor, shape=(-1, 1))
        index_slice_flat = torch.reshape(index_slice_half, shape=(-1, 1))

        self.register_buffer("rAnchor_flat", rAnchor_flat)
        self.register_buffer("index_slice_flat", index_slice_flat)

        # ----------------------------------------------------------------------
        # rotation
        R = (
            torch.linspace(
                start=0, end=max_anchor * over_sampling, steps=self.num_anchor
            )
            / over_sampling
        )
        gridy = torch.linspace(start=0, end=Ny - 1, steps=Ny)
        gridx = torch.linspace(start=0, end=Nx - 1, steps=Nx)
        Y, X = torch.meshgrid(gridy, gridx, indexing="ij")

        rPixel = torch.sqrt((X - xp) ** 2 + (Y - yp) ** 2)
        # int64 (long) indices work for advanced indexing in every torch version
        index = torch.floor(rPixel * over_sampling).long()
        index = index[None].repeat(self.Nz, 1, 1)
        index_slice = torch.linspace(start=0, end=self.Nz - 1, steps=self.Nz)[
            ..., None, None
        ]
        index_slice = index_slice.repeat(1, Ny, Nx).long()

        self.register_buffer("index_slice", index_slice)
        self.register_buffer("index1", index)
        disR = (rPixel - R[index]) * over_sampling
        self.register_buffer("disR_1", disR)
        self.register_buffer("disR_2", 1.0 - disR)
        self.register_buffer("index2", index + 1)

    def jep(self, lam, n, r, z, rho):
        # n = NA / ni
        z = z * self.pixel_size_z
        k = 2 * torch.pi / lam
        j0 = torch.special.bessel_j0(k * n * r * rho)
        comp = torch.complex(torch.tensor(0.0), torch.tensor(1.0))
        exp = torch.exp(-0.5 * comp * (k * (rho**2) * z * (n**2)))
        return j0 * exp * rho

    def get_num_params(self):
        return 2

    def forward(self, params):
        params = torch.abs(params)
        num_batch, num_channel, _ = params.shape
        lam = params[..., 0][..., None, None]
        n = params[..., 1][..., None, None]

        sample = self.jep(
            lam=lam,
            n=n,
            r=self.rAnchor_flat,
            z=self.index_slice_flat,
            rho=self.integral,
        )

        plane = torch.trapezoid(sample, dx=self.dx, dim=-1)
        plane = torch.square(torch.abs(plane))
        plane = torch.reshape(
            plane,
            shape=(num_batch, num_channel, self.Nz // 2 + 1, self.num_anchor),
        )

        # ----------------------------------------------------------------------
        plane = torch.nn.functional.pad(
            plane, pad=[0, 0, self.Nz // 2, 0], mode="reflect"
        )

        # linear interpolation
        PSF = (
            plane[:, :, self.index_slice, self.index2] * self.disR_1
            + plane[:, :, self.index_slice, self.index1] * self.disR_2
        )

        # normalization
        if self.kernel_norm:
            PSF = torch.div(PSF, torch.sum(PSF, dim=(2, 3, 4), keepdim=True))

        return PSF


gen = BWModel(
    kernel_size=(127, 127, 127),
    kernel_norm=True,
    num_integral=100,
    over_sampling=2,
    pixel_size_z=1,
)


# load the PSF generated by the Fiji PSF Generator plugin
# (wavelength set to 600 nm in the plugin)
psf_600_gt = io.imread(
    "PSF BW.tif"
)

# generate PSFs using the custom code above
# wavelength 600 nm in pixel units (assuming 100 nm pixels); n = NA / ni = 1.4 / 1.5
params_600 = torch.tensor([600 / 100, 1.4 / 1.5])[None, None]
psf_600 = gen(params_600)[0, 0]

# wavelength 400 nm, same pixel units and same n
params_400 = torch.tensor([400 / 100, 1.4 / 1.5])[None, None]
psf_400 = gen(params_400)[0, 0]

# normalize
psf_600_gt = psf_600_gt / psf_600_gt.max()
psf_600 = psf_600 / psf_600.max()
psf_400 = psf_400 / psf_400.max()

# plot axial profiles through the PSF center (x = y = 63)
plt.plot(psf_600_gt[48:78, 63, 63], label="600 (Fiji)", color="red")
plt.plot(psf_600[48:78, 63, 63], label="600 (my code)", color="blue")
plt.plot(psf_400[48:78, 63, 63], "--", label="400 (my code)", color="green")
plt.legend()
plt.savefig("tmp")

The PSF BW.tif file was generated with the Fiji PSF Generator plugin.

With an input wavelength of 600 nm (blue in the figure), my result differs from the Fiji PSF Generator result (red). However, when I input 400 nm to my code, the result (green, dashed) is consistent with Fiji.
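To put a number on the mismatch instead of judging it by eye, one option is to compare the full width at half maximum (FWHM) of the axial profiles. The `fwhm` helper below is a hypothetical sketch, not taken from Fiji or any package mentioned in this thread: it linearly interpolates the two half-maximum crossings of a single-peaked 1D profile, and it is checked here on a Gaussian, whose FWHM is 2√(2 ln 2) σ ≈ 2.355 σ.

```python
import numpy as np


def fwhm(profile, spacing=1.0):
    """FWHM of a single-peaked 1D profile, using linear interpolation
    between samples at the two half-maximum crossings."""
    p = np.asarray(profile, dtype=float)
    peak = int(np.argmax(p))
    half = p[peak] / 2.0

    # left crossing: walk left from the peak until the next sample is below half
    i = peak
    while i > 0 and p[i - 1] >= half:
        i -= 1
    if i == 0:
        raise ValueError("profile never drops below half maximum on the left")
    left = (i - 1) + (half - p[i - 1]) / (p[i] - p[i - 1])

    # right crossing: the same walk towards the end of the array
    j = peak
    while j < p.size - 1 and p[j + 1] >= half:
        j += 1
    if j == p.size - 1:
        raise ValueError("profile never drops below half maximum on the right")
    right = j + (p[j] - half) / (p[j] - p[j + 1])

    return (right - left) * spacing


# sanity check on a sampled Gaussian
z = np.arange(-50, 51)
sigma = 5.0
gaussian = np.exp(-(z**2) / (2 * sigma**2))
print(fwhm(gaussian), 2 * np.sqrt(2 * np.log(2)) * sigma)
```

With the arrays from the script above, `fwhm(psf_600_gt[:, 63, 63]) / fwhm(psf_600[:, 63, 63].numpy())` gives the axial width ratio; a value near 1.5 or its inverse would point at a wavelength or refractive-index factor.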

Looks like a scaling factor differs somewhere (I guess that's obvious 🙂). I'm afraid I won't have time soon to take a close look through your code to see where the differences might be, but I would mention that there are a number of existing Python packages that do this, and you could look through them for comparison.
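One quick way to locate such a scaling factor is the in-focus lateral scale. At z = 0 the BW integral reduces to ∫₀¹ J0(vρ) ρ dρ = J1(v)/v with v = k·NA·r, so with k = 2π/λ the first dark ring should sit at r ≈ 0.61 λ/NA. The sketch below evaluates the in-focus profile by trapezoidal integration, as the torch code does; it uses NumPy instead of torch, computes J0 from its integral representation to avoid a SciPy dependency, and `bessel_j0` and `infocus_intensity` are hypothetical helper names. Lengths are in µm, with λ = 0.6 µm and NA = 1.4 as example values.

```python
import numpy as np


def _trapz(y, dx):
    # trapezoidal rule along the last axis with uniform spacing dx
    return dx * (y.sum(axis=-1) - 0.5 * (y[..., 0] + y[..., -1]))


def bessel_j0(x, n_theta=101):
    # J0(x) = (1/pi) * integral_0^pi cos(x * sin(theta)) dtheta
    theta = np.linspace(0.0, np.pi, n_theta)
    vals = np.cos(np.asarray(x, dtype=float)[..., None] * np.sin(theta))
    return _trapz(vals, theta[1] - theta[0]) / np.pi


def infocus_intensity(r, lam, na, n_rho=101):
    # |int_0^1 J0(k * na * r * rho) * rho drho|^2 at z = 0, with k = 2*pi/lam
    k = 2.0 * np.pi / lam
    rho = np.linspace(0.0, 1.0, n_rho)
    arg = k * na * np.asarray(r, dtype=float)[..., None] * rho
    amp = _trapz(bessel_j0(arg) * rho, rho[1] - rho[0])
    return amp**2


r = np.linspace(0.0, 0.4, 401)  # radial positions in um
intensity = infocus_intensity(r, lam=0.6, na=1.4)
r_zero = r[np.argmin(intensity)]
print(f"first dark ring at r = {r_zero:.3f} um, 0.61*lam/NA = {0.61 * 0.6 / 1.4:.3f} um")

# dividing both the wavelength and NA by ni = 1.5 leaves k * NA unchanged
same = np.allclose(intensity, infocus_intensity(r, lam=0.4, na=1.4 / 1.5))
print("(0.6, 1.4) vs (0.4, 1.4/1.5) identical:", same)
```

The last comparison shows that (λ, NA) and (λ/1.5, NA/1.5) give the same in-focus profile, which is one way a 600 vs 400 discrepancy can arise when the code takes NA/ni as input.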

Since it looks like you want this as a torch model, you might also look through this module in microsim, which tries to abstract the NumPy API away as much as possible (allowing various backends, including JAX, CuPy, and possibly torch, though I haven't tested that as much). It should give you a clear idea of exactly which methods you need torch to provide.

At the very least, you could use some of those as a “third” vote to see whether they agree with the Fiji generator.

Thank you very much @talley, I will check these packages.

Dear @talley, I think I have found where the problem is.
I wrote my code according to the equation shown on the BIG • PSF Generator page, where k = 2π/λ.
I think λ here refers to the wavelength in the immersion layer (in my example above, oil, with ni ≈ 1.5), not the wavelength in vacuum. That is why an input of 400 nm to my code reproduces the Fiji PSF Generator result for 600 nm (600/400 = 1.5).
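The claim can be checked algebraically. Take the integrand that `jep` evaluates, with n = NA/n_i, and assume that Fiji's wavelength parameter is the vacuum wavelength λ0. Setting k = 2π n_i/λ0 (i.e. λ = λ0/n_i, the wavelength in the immersion medium) gives:

```latex
\begin{aligned}
U(r,z) &= \int_0^1 J_0(k\,n\,r\,\rho)\,
          \exp\!\left(-\tfrac{1}{2}\, i\, k\, \rho^2 z\, n^2\right)\rho\,d\rho,
          \qquad n = \frac{NA}{n_i},\\
k\,n &= \frac{2\pi n_i}{\lambda_0}\cdot\frac{NA}{n_i} = k_0\,NA,
\qquad
k\,n^2 = \frac{2\pi n_i}{\lambda_0}\cdot\frac{NA^2}{n_i^2} = k_0\,\frac{NA^2}{n_i},
\qquad k_0 = \frac{2\pi}{\lambda_0},\\
U(r,z) &= \int_0^1 J_0(k_0\,NA\,r\,\rho)\,
          \exp\!\left(-\tfrac{1}{2}\, i\, k_0\, \rho^2 z\, \frac{NA^2}{n_i}\right)\rho\,d\rho .
\end{aligned}
```

So with n = NA/n_i, feeding λ0/n_i (600/1.5 = 400 nm here) matches a formula written with the vacuum wavenumber k0, NA in the Bessel argument and NA²/n_i in the phase. Which of the two forms PSF Generator uses internally is an assumption here; the algebra only shows that they agree.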
I tried other immersion refractive indices, and my custom code gives results consistent with those from PSF Generator.
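A numerical version of the same check, looping over a few immersion indices, is sketched below. `integrand_user` reproduces the expression in `jep` (n = NA/n_i, k = 2π/λ), while `integrand_vacuum` is a hypothetical vacuum-wavelength form (k0 = 2π/λ0, NA in the Bessel argument, NA²/n_i in the phase). NumPy replaces torch, J0 comes from its integral representation, NA = 1.2 is an example value chosen so that NA < n_i for every index tried, and lengths are in pixels with λ0 = 600/100 as in the script.

```python
import numpy as np


def j0(x, n_theta=101):
    # J0(x) = (1/pi) * integral_0^pi cos(x * sin(theta)) dtheta, trapezoid rule
    theta = np.linspace(0.0, np.pi, n_theta)
    vals = np.cos(np.asarray(x, dtype=float)[..., None] * np.sin(theta))
    dt = theta[1] - theta[0]
    return dt * (vals.sum(axis=-1) - 0.5 * (vals[..., 0] + vals[..., -1])) / np.pi


def integrand_user(lam, n, r, z, rho):
    # same expression as jep() in the post, with n = NA / ni and k = 2*pi / lam
    k = 2 * np.pi / lam
    return j0(k * n * r * rho) * np.exp(-0.5j * k * rho**2 * z * n**2) * rho


def integrand_vacuum(lam0, na, ni, r, z, rho):
    # hypothetical vacuum form: k0 = 2*pi / lam0, NA in J0, NA**2 / ni in the phase
    k0 = 2 * np.pi / lam0
    return j0(k0 * na * r * rho) * np.exp(-0.5j * k0 * rho**2 * z * na**2 / ni) * rho


rho = np.linspace(0.0, 1.0, 101)
lam0, na, r, z = 600 / 100, 1.2, 2.5, 3.0  # wavelength and lengths in pixels
for ni in (1.33, 1.46, 1.515):
    vacuum = integrand_vacuum(lam0, na, ni, r, z, rho)
    scaled = integrand_user(lam0 / ni, na / ni, r, z, rho)  # wavelength in the medium
    unscaled = integrand_user(lam0, na / ni, r, z, rho)  # vacuum wavelength
    print(ni, np.allclose(scaled, vacuum), np.allclose(unscaled, vacuum))
```

Only the call with λ0/n_i reproduces the vacuum form, for every n_i, which matches the 600 → 400 observation for n_i = 1.5.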
Am I right?
(Sorry, I cannot derive the BW model mathematically; that is beyond my ability.)