add ssim observer #50
Conversation
hexane360
left a comment
There was a problem hiding this comment.
It looks like your rebase didn't really work, you overwrote many changes to the develop branch. I made a couple other comments.
The structure of the changes seem mostly fine. The naming does seem pretty confusing. First is this is multi-scale SSIM not regular SSIM. Secondly, the relative to past iterations is potentially confusing, and it'll make it hard to add in relative to ground truth measurements later. Would names like 'obj_rel_msssim' and 'probe_rel_msssim' be less ambiguous?
There was a problem hiding this comment.
It looks like these changes were mistakenly reverted from the upstream.
| def _resample_to_shape(im, target_shape: t.Tuple[int, ...], xp: t.Any): | ||
| """Resample im to target_shape, staying on the input device. first dimension untouched""" | ||
| xp_name = getattr(xp, '__name__', '') | ||
|
|
||
| if xp_name == 'numpy': | ||
| from scipy.ndimage import zoom | ||
| zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape)) | ||
| return zoom(im, zoom_factors, order=1) | ||
|
|
||
| if 'cupy' in xp_name: | ||
| from cupyx.scipy.ndimage import zoom | ||
| zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape)) | ||
| return zoom(im, zoom_factors, order=1) | ||
|
|
||
| # JAX | ||
| import jax.image | ||
| return jax.image.resize(im, target_shape, method='linear') | ||
|
|
There was a problem hiding this comment.
Have you tested this stuff returns the same results on each backend? affine_transform in phaser.utils.image is correct on all backends, you could just add a helper there for zoom which calls to affine_transform. This also isn't the preferred way to test for backends, you should only need the helper functions in phaser.utils.num
| def to_xp(self, xp: t.Any) -> Self: | ||
| return self.__class__( | ||
| self.patterns, | ||
| self.state.to_xp(xp), | ||
| self.name, self.observer, | ||
| ) | ||
|
|
| return self.inner.collect_errors(val) | ||
|
|
||
|
|
||
| __all__ = [ |
| "tifffile>=2023.8.25", | ||
| "optree>=0.13.0", | ||
| "py-pane==0.11.4", | ||
| "py-pane==0.11.3", |
| dependencies = [ | ||
| "numpy>=2.0,<2.6", # tested on 2.3 | ||
| "scipy>=1.7.0,<1.19", # tested on 1.11, 1.16 | ||
| "scikit-image>=0.19.0", |
There was a problem hiding this comment.
is this needed? you don't seem to import it
| if self.no_improvement_iter >= self.patience: | ||
| logging.info(f"Early termination: no improvement for {self.patience} iterations") | ||
| raise EarlyTermination(state, self.continue_next_engine) | ||
| self._prev_state = deepcopy(state) |
There was a problem hiding this comment.
should take this off the GPU as well. Really you can take off the GPU for SSIM calculation there's no reason this has to be super fast
| mssim : float | ||
| MS-SSIM value in [0, 1]. | ||
| """ | ||
| from phaser.utils.image import affine_transform as _affine_transform |
| def _resample(im, target_shape): | ||
| scale_y = im.shape[-2] / target_shape[-2] | ||
| scale_x = im.shape[-1] / target_shape[-1] | ||
| matrix = numpy.array([[scale_y, 0.0], [0.0, scale_x]]) | ||
| offset = numpy.array([0.5 * (scale_y - 1.0), 0.5 * (scale_x - 1.0)]) | ||
| return _affine_transform(im, matrix, offset=offset, output_shape=target_shape[-2:], order=1) |
There was a problem hiding this comment.
did you test this thoroughly? it's easy to get off by 1 or off by 1/2 pixel errors here
| xp_name = getattr(xp, '__name__', '') | ||
| sizes = [1] * (im.ndim - 2) + [size, size] | ||
|
|
||
| if xp_name == 'numpy': | ||
| from scipy.ndimage import uniform_filter | ||
| return uniform_filter(im, sizes) | ||
|
|
||
| if 'cupy' in xp_name: | ||
| from cupyx.scipy.ndimage import uniform_filter | ||
| return uniform_filter(im, sizes) | ||
|
|
||
| # JAX or other: cumsum box filter along axes -2 and -1 only (XLA-friendly) |
There was a problem hiding this comment.
This is still not the correct way to do dispatch.
The following should work:
if xp_is_jax(xp):
# jax implementation
# numpy/cupy implementation
scipy = get_scipy_module(im)
return scipy.ndimage.uniform_filter(im, sizes)And you need to make sure that the implementation returns the exact same numeric result as the scipy implementation.
| obj_now = to_numpy(xp.angle(state.object.data)) | ||
| probe_now = to_numpy(xp.abs(state.probe.data)) |
There was a problem hiding this comment.
Do you want to be doing just the phase of object and amplitude of probe? Have you compared the approaches?
| iters_without_improvement = current_iter - self._last_improvement_iter.get(key, current_iter) | ||
| if iters_without_improvement >= patience: | ||
| logging.info( | ||
| f"Early termination: {key} no improvement for {iters_without_improvement} iterations" | ||
| ) | ||
| raise EarlyTermination(state, self.continue_next_engine) |
There was a problem hiding this comment.
How should we handle the following case?
early_termination_obj_rel_msssim: 5
calc_rel_msssim: {'every': 10}Right now it looks like it'll basically ignore patience; maybe an alternate way to specify is that the patience number is the number of evaluations, not the number of iterations
|
Also, I had one other thought: Do we really want to early terminate when the SSIM doesn't change much, or when the SSIM goes above or below a given threshold? The former seems like taking the derivative of the derivative. For instance, if the SSIM is constant at 0.9, that means a constant change in the reconstruction (i.e. constant improvement). But the current implementation would have that terminate. |
The SSIMObserver tracks the structural similarity (SSIM) of the object phase and probe intensity using multi-scale SSIM. The metric is computed as the average SSIM over three pyramid downsampling levels and compares the current state with that from a user-specified number of iterations earlier.
Also, the PatienceObserver can observe obj_ssim, probe_ssim, and loss jointly and will trigger termination whichever reaches the patience limit first.
Example set up in the YAML file:
calc_ssim: {every: 10, after: 10}# same type of flag as save optionsearly_termination_loss: 10# note that the previous naming 'early_termination' is disabled.early_termination_obj_ssim: 40early_termination_probe_ssim: nullShowcase of an experimental reconstruction using calc_ssim: {every: 10}
(the actual iter counts are saved in progress.obj_ssim.iters. e.g. here the value of progress.obj_ssim.values at index 0 corresponds to iter 20 v.s. iter 10. )