Skip to content

add ssim observer #50

Open
wdwzyyg wants to merge 6 commits intohexane360:developfrom
wdwzyyg:develop
Open

add ssim observer #50
wdwzyyg wants to merge 6 commits intohexane360:developfrom
wdwzyyg:develop

Conversation

@wdwzyyg
Copy link
Copy Markdown

@wdwzyyg wdwzyyg commented Apr 10, 2026

The SSIMObserver tracks the structural similarity (SSIM) of the object phase and probe intensity using multi-scale SSIM. The metric is computed as the average SSIM over three pyramid downsampling levels and compares the current state with that from a user-specified number of iterations earlier.

Also, the PatienceObserver can observe obj_ssim, probe_ssim, and loss jointly and will trigger termination whichever reaches the patience limit first.

Example set up in the YAML file:
calc_ssim: {every: 10, after: 10} # same type of flag as save options

early_termination_loss: 10 # note that the previous naming 'early_termination' is disabled.

early_termination_obj_ssim: 40

early_termination_probe_ssim: null

Showcase of an experimental reconstruction using calc_ssim: {every: 10}

Screenshot 2026-04-10 at 9 26 10 AM Screenshot 2026-04-10 at 9 26 30 AM

(the actual iter counts are saved in progress.obj_ssim.iters. e.g. here the value of progress.obj_ssim.values at index 0 corresponds to iter 20 v.s. iter 10. )

Copy link
Copy Markdown
Owner

@hexane360 hexane360 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like your rebase didn't really work, you overwrote many changes to the develop branch. I made a couple other comments.

The structure of the changes seem mostly fine. The naming does seem pretty confusing. First is this is multi-scale SSIM not regular SSIM. Secondly, the relative to past iterations is potentially confusing, and it'll make it hard to add in relative to ground truth measurements later. Would names like 'obj_rel_msssim' and 'probe_rel_msssim' be less ambiguous?

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like these changes were mistakenly reverted from the upstream.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these too

Comment thread phaser/hooks/solver.py
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here

Comment thread phaser/utils/analysis.py Outdated
Comment on lines +300 to +317
def _resample_to_shape(im, target_shape: t.Tuple[int, ...], xp: t.Any):
"""Resample im to target_shape, staying on the input device. first dimension untouched"""
xp_name = getattr(xp, '__name__', '')

if xp_name == 'numpy':
from scipy.ndimage import zoom
zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape))
return zoom(im, zoom_factors, order=1)

if 'cupy' in xp_name:
from cupyx.scipy.ndimage import zoom
zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape))
return zoom(im, zoom_factors, order=1)

# JAX
import jax.image
return jax.image.resize(im, target_shape, method='linear')

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested this stuff returns the same results on each backend? affine_transform in phaser.utils.image is correct on all backends, you could just add a helper there for zoom which calls to affine_transform. This also isn't the preferred way to test for backends, you should only need the helper functions in phaser.utils.num

Comment thread phaser/state.py
Comment on lines -246 to -252
def to_xp(self, xp: t.Any) -> Self:
return self.__class__(
self.patterns,
self.state.to_xp(xp),
self.name, self.observer,
)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more overwritten changes

Comment thread phaser/types.py
return self.inner.collect_errors(val)


__all__ = [
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more overwritten changes

Comment thread pyproject.toml Outdated
"tifffile>=2023.8.25",
"optree>=0.13.0",
"py-pane==0.11.4",
"py-pane==0.11.3",
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overwritten change

Comment thread pyproject.toml Outdated
dependencies = [
"numpy>=2.0,<2.6", # tested on 2.3
"scipy>=1.7.0,<1.19", # tested on 1.11, 1.16
"scikit-image>=0.19.0",
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed? you don't seem to import it

Comment thread phaser/observer.py Outdated
if self.no_improvement_iter >= self.patience:
logging.info(f"Early termination: no improvement for {self.patience} iterations")
raise EarlyTermination(state, self.continue_next_engine)
self._prev_state = deepcopy(state)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should take this off the GPU as well. Really you can take off the GPU for SSIM calculation there's no reason this has to be super fast

Comment thread phaser/utils/analysis.py
mssim : float
MS-SSIM value in [0, 1].
"""
from phaser.utils.image import affine_transform as _affine_transform
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Comment thread phaser/utils/analysis.py
Comment on lines +384 to +389
def _resample(im, target_shape):
scale_y = im.shape[-2] / target_shape[-2]
scale_x = im.shape[-1] / target_shape[-1]
matrix = numpy.array([[scale_y, 0.0], [0.0, scale_x]])
offset = numpy.array([0.5 * (scale_y - 1.0), 0.5 * (scale_x - 1.0)])
return _affine_transform(im, matrix, offset=offset, output_shape=target_shape[-2:], order=1)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you test this thoroughly? it's easy to get off by 1 or off by 1/2 pixel errors here

Comment thread phaser/utils/analysis.py
Comment on lines +311 to +322
xp_name = getattr(xp, '__name__', '')
sizes = [1] * (im.ndim - 2) + [size, size]

if xp_name == 'numpy':
from scipy.ndimage import uniform_filter
return uniform_filter(im, sizes)

if 'cupy' in xp_name:
from cupyx.scipy.ndimage import uniform_filter
return uniform_filter(im, sizes)

# JAX or other: cumsum box filter along axes -2 and -1 only (XLA-friendly)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still not the correct way to do dispatch.
The following should work:

if xp_is_jax(xp):
    # jax implementation

# numpy/cupy implementation
scipy = get_scipy_module(im)
return scipy.ndimage.uniform_filter(im, sizes)

And you need to make sure that the implementation returns the exact same numeric result as the scipy implementation.

Comment thread phaser/observer.py
Comment on lines +289 to +290
obj_now = to_numpy(xp.angle(state.object.data))
probe_now = to_numpy(xp.abs(state.probe.data))
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to be doing just the phase of object and amplitude of probe? Have you compared the approaches?

Comment thread phaser/observer.py
Comment on lines +244 to +249
iters_without_improvement = current_iter - self._last_improvement_iter.get(key, current_iter)
if iters_without_improvement >= patience:
logging.info(
f"Early termination: {key} no improvement for {iters_without_improvement} iterations"
)
raise EarlyTermination(state, self.continue_next_engine)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should we handle the following case?

early_termination_obj_rel_msssim: 5
calc_rel_msssim: {'every': 10}

Right now it looks like it'll basically ignore patience; maybe an alternate way to specify is that the patience number is the number of evaluations, not the number of iterations

@hexane360
Copy link
Copy Markdown
Owner

Also, I had one other thought: Do we really want to early terminate when the SSIM doesn't change much, or when the SSIM goes above or below a given threshold? The former seems like taking the derivative of the derivative. For instance, if the SSIM is constant at 0.9, that means a constant change in the reconstruction (i.e. constant improvement). But the current implementation would have that terminate.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants