From d694a092b08f2e817e84c81d1dc67e5f87a3bfe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20C=C3=A1rdenas?= Date: Fri, 13 Sep 2024 23:28:33 -0500 Subject: [PATCH] feat: #410 add `groundtruth` functionality for comparison - added support for `groundtruth` image comparison in `iterative_recon_alg`. - implemented `groundtruth` handling in error measurement and initialization. - updated documentation to reflect new parameter and its usage. --- Python/tigre/algorithms/iterative_recon_alg.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/Python/tigre/algorithms/iterative_recon_alg.py b/Python/tigre/algorithms/iterative_recon_alg.py index b882b97b..5ff501e7 100644 --- a/Python/tigre/algorithms/iterative_recon_alg.py +++ b/Python/tigre/algorithms/iterative_recon_alg.py @@ -92,6 +92,10 @@ class IterativeReconAlg(object): OS_SART_TV FISTA + :keyword groundtruth: (np.ndarray, optional) + Ground truth image for comparison with the reconstruction. + Default is None. + Usage -------- >>> import numpy as np @@ -134,6 +138,8 @@ def __init__(self, proj, geo, angles, niter, **kwargs): self.geo = geo self.niter = niter + self.groundtruth = kwargs.get("groundtruth", None) + self.geo.check_geo(angles) options = dict( @@ -167,6 +173,7 @@ def __init__(self, proj, geo, angles, niter, **kwargs): "regularisation", "tviter", "tvlambda", + "groundtruth", "hyper", "fista_p", "fista_q", @@ -203,6 +210,9 @@ def __init__(self, proj, geo, angles, niter, **kwargs): self.Quameasopts = ( [self.Quameasopts] if isinstance(self.Quameasopts, str) else self.Quameasopts ) + if self.groundtruth is not None: + if "error_norm" not in self.Quameasopts: + self.Quameasopts.append("error_norm") setattr(self, "lq", np.zeros([len(self.Quameasopts), niter])) # quameasoptslist else: setattr(self, "lq", np.zeros([0, niter])) # quameasoptslist @@ -356,8 +366,13 @@ def minimizeAwTV(self, res_prev, dtvg): return AwminTV(res_prev, dtvg, self.numiter_tv, self.delta, self.gpuids) def error_measurement(self, res_prev, iter): + if self.groundtruth is not None: + comparison_img = self.groundtruth + else: + comparison_img = res_prev + if self.Quameasopts is not None: - self.lq[:, iter] = MQ(self.res, res_prev, self.Quameasopts) + self.lq[:, iter] = MQ(self.res, comparison_img, self.Quameasopts) if self.computel2: # compute l2 borm for b-Ax errornow = im3DNORM(