Dear all,
I was writing some custom python code to generate a BW PSF.
However, the PSF generate using my code is not consistent with that generated by PSF Generator in Fiji. I do not know which part of my code is wrong. Could you give me a hand of it?
Below are the my code: (you can reproduce my result with python)
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import skimage.io as io
class BWModel(nn.Module):
def __init__(
self,
kernel_size=[25, 25, 25],
kernel_norm=True,
num_integral=100,
over_sampling=2,
pixel_size_z=1, # * pixel_size_xy
):
super().__init__()
integral = torch.linspace(start=0, end=1, steps=num_integral + 1)
self.register_buffer("integral", integral)
self.dx = 1 / num_integral
self.kernel_size = torch.tensor(kernel_size)
self.kernel_norm = kernel_norm
self.Nz, Ny, Nx = self.kernel_size
self.pixel_size_z = torch.tensor(pixel_size_z)
# xy plane
# center point position
yp, xp = (Ny - 1) / 2, (Nx - 1) / 2
max_anchor = torch.ceil(torch.sqrt(((Nx - 1) - xp) ** 2 + ((Ny - 1) - yp) ** 2))
# additional one anchor for poistion 0.
self.num_anchor = int(max_anchor * over_sampling + 1)
# ----------------------------------------------------------------------
rAnchor = (
torch.linspace(
start=0, end=max_anchor * over_sampling, steps=self.num_anchor
)
/ over_sampling
)
rAnchor = rAnchor[None].repeat(self.Nz // 2 + 1, 1)
index_slice_half = torch.linspace(
start=0, end=self.Nz // 2, steps=self.Nz // 2 + 1
)[..., None]
index_slice_half = index_slice_half.repeat(1, self.num_anchor).type(torch.int)
rAnchor_flat = torch.reshape(rAnchor, shape=(-1, 1))
index_slice_flat = torch.reshape(index_slice_half, shape=(-1, 1))
self.register_buffer("rAnchor_flat", rAnchor_flat)
self.register_buffer("index_slice_flat", index_slice_flat)
# ----------------------------------------------------------------------
# rotation
R = (
torch.linspace(
start=0, end=max_anchor * over_sampling, steps=self.num_anchor
)
/ over_sampling
)
gridy = torch.linspace(start=0, end=Ny - 1, steps=Ny)
gridx = torch.linspace(start=0, end=Nx - 1, steps=Nx)
Y, X = torch.meshgrid(gridy, gridx)
rPixel = torch.sqrt((X - xp) ** 2 + (Y - yp) ** 2)
index = torch.floor(rPixel * over_sampling).type(torch.int)
index = index[None].repeat(self.Nz, 1, 1)
index_slice = torch.linspace(start=0, end=self.Nz - 1, steps=self.Nz)[
..., None, None
]
index_slice = index_slice.repeat(1, Ny, Nx).type(torch.int)
self.register_buffer("index_slice", index_slice)
self.register_buffer("index1", index)
disR = (rPixel - R[index]) * over_sampling
self.register_buffer("disR_1", disR)
self.register_buffer("disR_2", 1.0 - disR)
self.register_buffer("index2", index + 1)
def jep(self, lam, n, r, z, rho):
# n = NA / ni
z = z * self.pixel_size_z
k = 2 * torch.pi / lam
j0 = torch.special.bessel_j0(k * n * r * rho)
comp = torch.complex(torch.tensor(0.0), torch.tensor(1.0))
exp = torch.exp(-0.5 * comp * (k * (rho**2) * z * (n**2)))
return j0 * exp * rho
def get_num_params(self):
return 2
def forward(self, params):
params = torch.abs(params)
num_batch, num_channel, _ = params.shape
lam = params[..., 0][..., None, None]
n = params[..., 1][..., None, None]
sample = self.jep(
lam=lam,
n=n,
r=self.rAnchor_flat,
z=self.index_slice_flat,
rho=self.integral,
)
plane = torch.trapezoid(sample, dx=self.dx, dim=-1)
plane = torch.square(torch.abs(plane))
plane = torch.reshape(
plane,
shape=(num_batch, num_channel, self.Nz // 2 + 1, self.num_anchor),
)
# ----------------------------------------------------------------------
plane = torch.nn.functional.pad(
plane, pad=[0, 0, self.Nz // 2, 0], mode="reflect"
)
# linear interpolation
PSF = (
plane[:, :, self.index_slice, self.index2] * self.disR_1
+ plane[:, :, self.index_slice, self.index1] * self.disR_2
)
# normalization
if self.kernel_norm == True:
PSF = torch.div(PSF, torch.sum(PSF, dim=(2, 3, 4), keepdim=True))
return PSF
gen = BWModel(
kernel_size=(127, 127, 127),
kernel_norm=True,
num_integral=100,
over_sampling=2,
pixel_size_z=1,
)
# load PSF generated by Fiji PSF Generator plugin
# set the wavelength to be 600 here
psf_600_gt = io.imread(
"PSF BW.tif"
)
# generate PSF using above custom code
# when input wave length is 600
params_600 = torch.tensor([600 / 100, 1.4 / 1.5])[None, None]
psf_600 = gen(params_600)[0, 0]
# when input wave length is 400
params_400 = torch.tensor([400 / 100, 1.4 / 1.5])[None, None]
psf_400 = gen(params_400)[0, 0]
# normalize
psf_600_gt = psf_600_gt / psf_600_gt.max()
psf_600 = psf_600 / psf_600.max()
psf_400 = psf_400 / psf_400.max()
# show curve
plt.plot(psf_600_gt[48:78, 63, 63], label="600 (Fiji)", color="red")
plt.plot(psf_600[48:78, 63, 63], label="600 (my code)", color="blue")
plt.plot(psf_400[48:78, 63, 63], "--", label="400 (my code)", color="green")
plt.legend()
plt.savefig("tmp")
The PSF BW.tif file is generated form Fiji PSF Generator, as following
When the wavelength is input 600 nm (blue in figure), the result is different from that from FIJI PSF Generator (red in figure), however, when i input 400 nm to my code, the result (green dots) is consistent with it.