diff --git a/CPAC/longitudinal/wf/anat.py b/CPAC/longitudinal/wf/anat.py
index a49386b1a9..192a50034f 100644
--- a/CPAC/longitudinal/wf/anat.py
+++ b/CPAC/longitudinal/wf/anat.py
@@ -18,9 +18,11 @@
 """Longitudinal workflows for anatomical data."""

 import os
-from typing import Optional
+from typing import cast, Optional

+from nipype import config as nipype_config
 from nipype.interfaces import fsl
+from nipype.interfaces.utility import Merge

 from CPAC.longitudinal.preproc import subject_specific_template
 from CPAC.pipeline import nipype_pipeline_engine as pe
@@ -31,7 +33,7 @@
     connect_pipeline,
     initialize_nipype_wf,
 )
-from CPAC.pipeline.engine import ingress_output_dir, initiate_rpool
+from CPAC.pipeline.engine import ingress_output_dir, initiate_rpool, ResourcePool
 from CPAC.pipeline.nodeblock import nodeblock, NODEBLOCK_RETURN
 from CPAC.registration.registration import apply_transform
 from CPAC.utils.configuration import Configuration
@@ -132,7 +134,8 @@ def mask_longitudinal_T1w_brain(
         (
             "space-longitudinal_desc-brain_T1w",
             "from-longitudinal_to-template_mode-image_xfm",
-        )
+        ),
+        "T1w-brain-template",
     ],
     outputs=["space-template_desc-brain_T1w"],
 )
@@ -169,7 +172,7 @@ def warp_longitudinal_T1w_to_template(
     node, out = strat_pool.get_data("space-longitudinal_desc-brain_T1w")
     wf.connect(node, out, apply_xfm, "inputspec.input_image")

-    node, out = strat_pool.get_data("T1w_brain_template")
+    node, out = strat_pool.get_data("T1w-brain-template")
     wf.connect(node, out, apply_xfm, "inputspec.reference")

     node, out = strat_pool.get_data("from-longitudinal_to-template_mode-image_xfm")
@@ -188,7 +191,11 @@ def warp_longitudinal_seg_to_T1w(
     option_val="C-PAC legacy",
     inputs=[
         (
-            "from-longitudinal_to-T1w_mode-image_desc-linear_xfm",
+            "space-longitudinal_desc-brain_T1w",
+            [
+                "from-longitudinal_to-T1w_mode-image_desc-linear_xfm",
+                "from-T1w_to-longitudinal_mode-image_desc-linear_xfm",
+            ],
             "space-longitudinal_label-CSF_mask",
             "space-longitudinal_label-GM_mask",
             "space-longitudinal_label-WM_mask",
@@ -198,7 +205,8 @@ def warp_longitudinal_seg_to_T1w(
             "space-longitudinal_label-CSF_probseg",
             "space-longitudinal_label-GM_probseg",
             "space-longitudinal_label-WM_probseg",
-        )
+        ),
+        "T1w-brain-template",
     ],
     outputs=[
         "label-CSF_mask",
@@ -213,14 +221,35 @@ def warp_longitudinal_seg_to_T1w(
     ],
 )
 def warp_longitudinal_seg_to_T1w(
-    wf, cfg, strat_pool, pipe_num, opt=None
+    wf: pe.Workflow,
+    cfg: Configuration,
+    strat_pool: ResourcePool,
+    pipe_num: int,
+    opt: Optional[str] = None,
 ) -> NODEBLOCK_RETURN:
     """Transform anatomical images from longitudinal space to T1w space."""
-    xfm_prov = strat_pool.get_cpac_provenance(
-        "from-longitudinal_to-T1w_mode-image_desc-linear_xfm"
-    )
-    reg_tool = check_prov_for_regtool(xfm_prov)
-
+    if strat_pool.check_rpool("from-longitudinal_to-T1w_mode-image_desc-linear_xfm"):
+        xfm_prov = strat_pool.get_cpac_provenance(
+            "from-longitudinal_to-T1w_mode-image_desc-linear_xfm"
+        )
+        reg_tool = check_prov_for_regtool(xfm_prov)
+        xfm: tuple[pe.Node, str] = strat_pool.get_data(
+            "from-longitudinal_to-T1w_mode-image_desc-linear_xfm"
+        )
+    else:
+        xfm_prov = strat_pool.get_cpac_provenance(
+            "from-T1w_to-longitudinal_mode-image_desc-linear_xfm"
+        )
+        reg_tool = check_prov_for_regtool(xfm_prov)
+        # create inverse xfm if we don't have it
+        invt = pe.Node(interface=fsl.ConvertXFM(), name=f"convert_xfm_{pipe_num}")
+        invt.inputs.invert_xfm = True
+        wf.connect(
+            *strat_pool.get_data("from-T1w_to-longitudinal_mode-image_desc-linear_xfm"),
+            invt,
+            "in_file",
+        )
+        xfm = (invt, "out_file")

     num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"]

     num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"]
@@ -260,11 +289,10 @@ def warp_longitudinal_seg_to_T1w(
         node, out = strat_pool.get_data("space-longitudinal_desc-brain_T1w")
         wf.connect(node, out, apply_xfm, "inputspec.input_image")

-        node, out = strat_pool.get_data("T1w_brain_template")
+        node, out = strat_pool.get_data("T1w-brain-template")
         wf.connect(node, out, apply_xfm, "inputspec.reference")

-        node, out = strat_pool.get_data("from-longitudinal_to-template_mode-image_xfm")
-        wf.connect(node, out, apply_xfm, "inputspec.transform")
+        wf.connect(*xfm, apply_xfm, "inputspec.transform")

         outputs[f"label-{label}"] = (apply_xfm, "outputspec.output_image")

@@ -272,7 +300,7 @@


 def anat_longitudinal_wf(
-    subject_id: str, sub_list: list[dict], config: Configuration
+    subject_id: str, sub_list: list[dict], config: Configuration, dry_run: bool = False
 ) -> None:
     """
     Create and run longitudinal workflows for anatomical data.
@@ -285,21 +313,36 @@ def anat_longitudinal_wf(
         a list of sessions for one subject
     config
         a Configuration object containing the information for the participant pipeline
+    dry_run
+        if True, build the workflow graphs without running them
     """
+    nipype_config.update_config(
+        {
+            "execution": {
+                "crashfile_format": "txt",
+                "stop_on_first_crash": config[
+                    "pipeline_setup", "system_config", "fail_fast"
+                ],
+            }
+        }
+    )
     config["subject_id"] = subject_id
-    session_id_list: list[list] = []
+    session_id_list: list[str] = []
     """List of session IDs"""
     session_wfs = {}

     cpac_dirs = []
-    out_dir = config.pipeline_setup["output_directory"]["path"]
-
-    orig_pipe_name = config.pipeline_setup["pipeline_name"]
-
-    # Loop over the sessions to create the input for the longitudinal
-    # algorithm
-    for session in sub_list:
-        unique_id = session["unique_id"]
+    out_dir: str = config.pipeline_setup["output_directory"]["path"]
+
+    orig_pipe_name: str = config.pipeline_setup["pipeline_name"]
+
+    strats_dct: dict[str, list[tuple[pe.Node, str] | str]] = {
+        "desc-brain_T1w": [],
+        "desc-head_T1w": [],
+    }
+    for i, session in enumerate(sub_list):
+        # Loop over the sessions to create the input for the longitudinal algorithm
+        unique_id: str = str(session.get("unique_id", i))
         session_id_list.append(unique_id)

         try:
@@ -319,13 +362,12 @@ def anat_longitudinal_wf(
         except KeyError:
             input_creds_path = None

-        workflow = initialize_nipype_wf(
+        workflow: pe.Workflow = initialize_nipype_wf(
             config,
-            sub_list[0],
-            # just grab the first one for the name
-            name="anat_longitudinal_pre-preproc",
+            session,
+            name=f"anat_longitudinal_pre-preproc_{unique_id}",
         )
-
+        rpool: ResourcePool
         workflow, rpool = initiate_rpool(workflow, config, session)
         pipeline_blocks = build_anat_preproc_stack(rpool, config)
         workflow = connect_pipeline(workflow, config, rpool, pipeline_blocks)
@@ -334,158 +376,143 @@ def anat_longitudinal_wf(
         rpool.gather_pipes(workflow, config)

-        workflow.run()
-
-        cpac_dir = os.path.join(
-            out_dir, f"pipeline_{orig_pipe_name}", f"{subject_id}_{unique_id}"
-        )
-        cpac_dirs.append(os.path.join(cpac_dir, "anat"))
-
-    # Now we have all the anat_preproc set up for every session
-    # loop over the different anat preproc strategies
-    strats_brain_dct = {}
-    strats_head_dct = {}
-    for cpac_dir in cpac_dirs:
-        if os.path.isdir(cpac_dir):
-            for filename in os.listdir(cpac_dir):
-                if "T1w.nii" in filename:
-                    for tag in filename.split("_"):
"desc-" in tag and "brain" in tag: - if tag not in strats_brain_dct: - strats_brain_dct[tag] = [] - strats_brain_dct[tag].append( - os.path.join(cpac_dir, filename) - ) - if tag not in strats_head_dct: - strats_head_dct[tag] = [] - head_file = filename.replace(tag, "desc-reorient") - strats_head_dct[tag].append( - os.path.join(cpac_dir, head_file) - ) - - for strat in strats_brain_dct.keys(): - wf = initialize_nipype_wf( - config, - sub_list[0], - # just grab the first one for the name - name=f"template_node_{strat}", - ) + for key in strats_dct.keys(): + strats_dct[key].append(cast(tuple[pe.Node, str], rpool.get_data(key))) + if not dry_run: + workflow.run() + for key in strats_dct.keys(): # get the outputs from run-nodes + for index, data in enumerate(list(strats_dct[key])): + if isinstance(data, tuple): + strats_dct[key][index] = workflow.get_output_path(*data) + + wf = initialize_nipype_wf( + config, + sub_list[0], + # just grab the first one for the name + name="template_node_brain", + ) - config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}" + config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}" + + template_node_name = "longitudinal_anat_template_brain" + + # This node will generate the longitudinal template (the functions are + # in longitudinal_preproc) + # Later other algorithms could be added to calculate it, like the + # multivariate template from ANTS + # It would just require to change it here. + template_node = subject_specific_template(workflow_name=template_node_name) + + template_node.inputs.set( + avg_method=config.longitudinal_template_generation["average_method"], + dof=config.longitudinal_template_generation["dof"], + interp=config.longitudinal_template_generation["interp"], + cost=config.longitudinal_template_generation["cost"], + convergence_threshold=config.longitudinal_template_generation[ + "convergence_threshold" + ], + thread_pool=config.longitudinal_template_generation["thread_pool"], + unique_id_list=list(session_wfs.keys()), + ) - template_node_name = f"longitudinal_anat_template_{strat}" - - # This node will generate the longitudinal template (the functions are - # in longitudinal_preproc) - # Later other algorithms could be added to calculate it, like the - # multivariate template from ANTS - # It would just require to change it here. 
-        template_node = subject_specific_template(workflow_name=template_node_name)
-
-        template_node.inputs.set(
-            avg_method=config.longitudinal_template_generation["average_method"],
-            dof=config.longitudinal_template_generation["dof"],
-            interp=config.longitudinal_template_generation["legacy-specific"]["interp"],
-            cost=config.longitudinal_template_generation["legacy-specific"]["cost"],
-            convergence_threshold=config.longitudinal_template_generation[
-                "legacy-specific"
-            ]["convergence_threshold"],
-            max_iter=config.longitudinal_template_generation["max_iter"],
-            thread_pool=config.longitudinal_template_generation["legacy-specific"][
-                "thread_pool"
-            ],
-            unique_id_list=list(session_wfs.keys()),
-        )
+    num_sessions = len(strats_dct["desc-brain_T1w"])
+    merge_brains = pe.Node(Merge(num_sessions), name="merge_brains")
+    merge_skulls = pe.Node(Merge(num_sessions), name="merge_skulls")

-        template_node.inputs.input_brain_list = strats_brain_dct[strat]
-        template_node.inputs.input_skull_list = strats_head_dct[strat]
+    for i in range(num_sessions):
+        wf._connect_node_or_path(merge_brains, strats_dct, "desc-brain_T1w", i)
+        wf._connect_node_or_path(merge_skulls, strats_dct, "desc-head_T1w", i)
+    wf.connect(merge_brains, "out", template_node, "input_brain_list")
+    wf.connect(merge_skulls, "out", template_node, "input_skull_list")

-        long_id = f"longitudinal_{subject_id}_strat-{strat}"
+    long_id = f"longitudinal_{subject_id}_strat-desc-brain_T1w"

-        wf, rpool = initiate_rpool(wf, config, part_id=long_id)
+    wf, rpool = initiate_rpool(wf, config, part_id=long_id)

-        rpool.set_data(
-            "space-longitudinal_desc-brain_T1w",
-            template_node,
-            "brain_template",
-            {},
-            "",
-            template_node_name,
-        )
+    rpool.set_data(
+        "space-longitudinal_desc-brain_T1w",
+        template_node,
+        "brain_template",
+        {},
+        "",
+        template_node_name,
+    )

-        rpool.set_data(
-            "space-longitudinal_desc-brain_T1w-template",
-            template_node,
-            "brain_template",
-            {},
-            "",
-            template_node_name,
-        )
+    rpool.set_data(
+        "space-longitudinal_desc-brain_T1w-template",
+        template_node,
+        "brain_template",
+        {},
+        "",
+        template_node_name,
+    )

-        rpool.set_data(
-            "space-longitudinal_desc-reorient_T1w",
-            template_node,
-            "skull_template",
-            {},
-            "",
-            template_node_name,
-        )
+    rpool.set_data(
+        "space-longitudinal_desc-reorient_T1w",
+        template_node,
+        "skull_template",
+        {},
+        "",
+        template_node_name,
+    )

-        rpool.set_data(
-            "space-longitudinal_desc-reorient_T1w-template",
-            template_node,
-            "skull_template",
-            {},
-            "",
-            template_node_name,
-        )
+    rpool.set_data(
+        "space-longitudinal_desc-reorient_T1w-template",
+        template_node,
+        "skull_template",
+        {},
+        "",
+        template_node_name,
+    )

-        pipeline_blocks = [mask_longitudinal_T1w_brain]
+    pipeline_blocks = [mask_longitudinal_T1w_brain]

-        pipeline_blocks = build_T1w_registration_stack(rpool, config, pipeline_blocks)
+    pipeline_blocks = build_T1w_registration_stack(
+        rpool, config, pipeline_blocks, space="longitudinal"
+    )

-        pipeline_blocks = build_segmentation_stack(rpool, config, pipeline_blocks)
+    pipeline_blocks = build_segmentation_stack(rpool, config, pipeline_blocks)

-        wf = connect_pipeline(wf, config, rpool, pipeline_blocks)
+    wf = connect_pipeline(wf, config, rpool, pipeline_blocks)

-        excl = [
-            "space-longitudinal_desc-brain_T1w",
-            "space-longitudinal_desc-reorient_T1w",
"space-longitudinal_desc-brain_mask", + ] + rpool.gather_pipes(wf, config, add_excl=excl) - # this is going to run multiple times! - # once for every strategy! + if not dry_run: wf.run() - # now, just write out a copy of the above to each session - config.pipeline_setup["pipeline_name"] = orig_pipe_name - for session in sub_list: - unique_id = session["unique_id"] - - try: - creds_path = session["creds_path"] - if creds_path and "none" not in creds_path.lower(): - if os.path.exists(creds_path): - input_creds_path = os.path.abspath(creds_path) - else: - err_msg = ( - 'Credentials path: "%s" for subject "%s" ' - 'session "%s" was not found. Check this path ' - "and try again." % (creds_path, subject_id, unique_id) - ) - raise Exception(err_msg) + # now, just write out a copy of the above to each session + config.pipeline_setup["pipeline_name"] = orig_pipe_name + for session in sub_list: + unique_id = session["unique_id"] + + try: + creds_path = session["creds_path"] + if creds_path and "none" not in creds_path.lower(): + if os.path.exists(creds_path): + input_creds_path = os.path.abspath(creds_path) else: - input_creds_path = None - except KeyError: + err_msg = ( + 'Credentials path: "%s" for subject "%s" ' + 'session "%s" was not found. Check this path ' + "and try again." % (creds_path, subject_id, unique_id) + ) + raise Exception(err_msg) + else: input_creds_path = None + except KeyError: + input_creds_path = None - wf = initialize_nipype_wf(config, sub_list[0]) + wf = initialize_nipype_wf(config, sub_list[0]) - wf, rpool = initiate_rpool(wf, config, session) + wf, rpool = initiate_rpool(wf, config, session, rpool=rpool) - config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}" + config.pipeline_setup["pipeline_name"] = f"longitudinal_{orig_pipe_name}" + if "derivatives_dir" in session: rpool = ingress_output_dir( wf, config, @@ -497,42 +525,43 @@ def anat_longitudinal_wf( creds_path=input_creds_path, ) - select_node_name = f"select_{unique_id}" - select_sess = pe.Node( - Function( - input_names=["session", "output_brains", "warps"], - output_names=["brain_path", "warp_path"], - function=select_session, - ), - name=select_node_name, - ) - select_sess.inputs.session = unique_id - - wf.connect(template_node, "output_brain_list", select_sess, "output_brains") - wf.connect(template_node, "warp_list", select_sess, "warps") - - rpool.set_data( - "space-longitudinal_desc-brain_T1w", - select_sess, - "brain_path", - {}, - "", - select_node_name, - ) + select_node_name = f"FSL_select_{unique_id}" + select_sess = pe.Node( + Function( + input_names=["session", "output_brains", "warps"], + output_names=["brain_path", "warp_path"], + function=select_session, + ), + name=select_node_name, + ) + select_sess.inputs.session = unique_id - rpool.set_data( - "from-T1w_to-longitudinal_mode-image_desc-linear_xfm", - select_sess, - "warp_path", - {}, - "", - select_node_name, - ) + wf.connect(template_node, "output_brain_list", select_sess, "output_brains") + wf.connect(template_node, "warp_list", select_sess, "warps") + + rpool.set_data( + "space-longitudinal_desc-brain_T1w", + select_sess, + "brain_path", + {}, + "", + select_node_name, + ) + + rpool.set_data( + "from-T1w_to-longitudinal_mode-image_" "desc-linear_xfm", + select_sess, + "warp_path", + {}, + "", + select_node_name, + ) - config.pipeline_setup["pipeline_name"] = orig_pipe_name - excl = ["space-template_desc-brain_T1w", "space-T1w_desc-brain_mask"] + config.pipeline_setup["pipeline_name"] = orig_pipe_name + excl = 
["space-template_desc-brain_T1w", "space-T1w_desc-brain_mask"] - rpool.gather_pipes(wf, config, add_excl=excl) + rpool.gather_pipes(wf, config, add_excl=excl) + if not dry_run: wf.run() # begin single-session stuff again @@ -571,4 +600,5 @@ def anat_longitudinal_wf( # this is going to run multiple times! # once for every strategy! - wf.run() + if not dry_run: + wf.run() diff --git a/CPAC/pipeline/cpac_pipeline.py b/CPAC/pipeline/cpac_pipeline.py index 26f67c970f..ae0a52121b 100644 --- a/CPAC/pipeline/cpac_pipeline.py +++ b/CPAC/pipeline/cpac_pipeline.py @@ -25,6 +25,7 @@ import sys import time from time import strftime +from typing import Literal, Optional import yaml import nipype @@ -130,7 +131,7 @@ # pylint: disable=wrong-import-order from CPAC.pipeline import nipype_pipeline_engine as pe from CPAC.pipeline.check_outputs import check_outputs -from CPAC.pipeline.engine import initiate_rpool, NodeBlock +from CPAC.pipeline.engine import initiate_rpool, NodeBlock, ResourcePool from CPAC.pipeline.nipype_pipeline_engine.plugins import ( LegacyMultiProcPlugin, MultiProcPlugin, @@ -162,7 +163,6 @@ warp_deriv_mask_to_EPItemplate, warp_deriv_mask_to_T1template, warp_sbref_to_T1template, - warp_T1mask_to_template, warp_timeseries_to_EPItemplate, warp_timeseries_to_T1template, warp_timeseries_to_T1template_abcd, @@ -170,7 +170,7 @@ warp_timeseries_to_T1template_deriv, warp_tissuemask_to_EPItemplate, warp_tissuemask_to_T1template, - warp_wholeheadT1_to_template, + warp_to_template, ) from CPAC.reho.reho import reho, reho_space_template from CPAC.sca.sca import dual_regression, multiple_regression, SCA_AVG @@ -1061,25 +1061,30 @@ def build_anat_preproc_stack(rpool, cfg, pipeline_blocks=None): return pipeline_blocks -def build_T1w_registration_stack(rpool, cfg, pipeline_blocks=None): +def build_T1w_registration_stack( + rpool: ResourcePool, + cfg: Configuration, + pipeline_blocks: Optional[list] = None, + space: Literal["longitudinal", "T1w"] = "T1w", +): """Build the T1w registration pipeline blocks.""" if not pipeline_blocks: pipeline_blocks = [] reg_blocks = [] - if not rpool.check_rpool("from-T1w_to-template_mode-image_xfm"): + if not rpool.check_rpool(f"from-{space}_to-template_mode-image_xfm"): reg_blocks = [ [register_ANTs_anat_to_template, register_FSL_anat_to_template], overwrite_transform_anat_to_template, - warp_wholeheadT1_to_template, - warp_T1mask_to_template, + warp_to_template("wholehead", space), + warp_to_template("mask", space), ] if not rpool.check_rpool("desc-restore-brain_T1w"): reg_blocks.append(correct_restore_brain_intensity_abcd) if cfg.voxel_mirrored_homotopic_connectivity["run"]: - if not rpool.check_rpool("from-T1w_to-symtemplate_mode-image_xfm"): + if not rpool.check_rpool(f"from-{space}_to-symtemplate_mode-image_xfm"): reg_blocks.append( [ register_symmetric_ANTs_anat_to_template, diff --git a/CPAC/pipeline/cpac_runner.py b/CPAC/pipeline/cpac_runner.py index 6dc1241036..9f09623f69 100644 --- a/CPAC/pipeline/cpac_runner.py +++ b/CPAC/pipeline/cpac_runner.py @@ -236,7 +236,7 @@ def run_cpac_on_cluster(config_file, subject_list_file, cluster_files_dir): f.write(pid) -def run_T1w_longitudinal(sublist, cfg): +def run_T1w_longitudinal(sublist, cfg: Configuration, dry_run: bool = False): subject_id_dict = {} for sub in sublist: @@ -249,7 +249,7 @@ def run_T1w_longitudinal(sublist, cfg): # sessions for each participant as value for subject_id, sub_list in subject_id_dict.items(): if len(sub_list) > 1: - anat_longitudinal_wf(subject_id, sub_list, cfg) + 
diff --git a/CPAC/pipeline/cpac_runner.py b/CPAC/pipeline/cpac_runner.py
index 6dc1241036..9f09623f69 100644
--- a/CPAC/pipeline/cpac_runner.py
+++ b/CPAC/pipeline/cpac_runner.py
@@ -236,7 +236,7 @@ def run_cpac_on_cluster(config_file, subject_list_file, cluster_files_dir):
         f.write(pid)


-def run_T1w_longitudinal(sublist, cfg):
+def run_T1w_longitudinal(sublist, cfg: Configuration, dry_run: bool = False):
     subject_id_dict = {}

     for sub in sublist:
@@ -249,7 +249,7 @@ def run_T1w_longitudinal(sublist, cfg):
     # sessions for each participant as value
     for subject_id, sub_list in subject_id_dict.items():
         if len(sub_list) > 1:
-            anat_longitudinal_wf(subject_id, sub_list, cfg)
+            anat_longitudinal_wf(subject_id, sub_list, cfg, dry_run=dry_run)
         elif len(sub_list) == 1:
             warnings.warn(
                 "\n\nThere is only one anatomical session "
@@ -495,7 +495,7 @@ def run(
         hasattr(c, "longitudinal_template_generation")
         and c.longitudinal_template_generation["run"]
     ):
-        run_T1w_longitudinal(sublist, c)
+        run_T1w_longitudinal(sublist, c, dry_run=test_config)

     # TODO functional longitudinal pipeline
     """
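
`run_T1w_longitudinal` only dispatches subjects with more than one session; the grouping itself is a dict keyed on subject ID. A minimal standalone sketch of that pattern (field names are illustrative):

    sublist = [
        {"subject_id": "sub-01", "unique_id": "ses-1"},
        {"subject_id": "sub-01", "unique_id": "ses-2"},
        {"subject_id": "sub-02", "unique_id": "ses-1"},
    ]

    subject_id_dict: dict[str, list[dict]] = {}
    for sub in sublist:
        # one list of session dicts per participant
        subject_id_dict.setdefault(sub["subject_id"], []).append(sub)

    # only sub-01 has >1 session, so only sub-01 gets a longitudinal template
    multi_session = {k: v for k, v in subject_id_dict.items() if len(v) > 1}
    assert list(multi_session) == ["sub-01"]
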
""" # pylint: disable=line-too-long +from collections.abc import Mapping from copy import deepcopy from inspect import Parameter, Signature, signature import os import re -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar, Optional, TYPE_CHECKING from numpy import prod from traits.trait_base import Undefined from traits.trait_handlers import TraitListObject from nibabel import load +from nipype.interfaces.base.support import InterfaceResult from nipype.interfaces.utility import Function from nipype.pipeline import engine as pe from nipype.pipeline.engine.utils import ( @@ -76,6 +79,9 @@ from CPAC.utils.monitoring import getLogger, WFLOGGER +if TYPE_CHECKING: + pass + # set global default mem_gb DEFAULT_MEM_GB = 2.0 UNDEFINED_SIZE = (42, 42, 42, 1200) @@ -527,6 +533,25 @@ def __init__(self, name, base_dir=None, debug=False): self._nodes_cache = set() self._nested_workflows_cache = set() + def copy_input_connections(self, node1: pe.Node, node2: pe.Node) -> None: + """Copy input connections from ``node1`` to ``node2``.""" + new_connections: list[tuple[pe.Node, str, pe.Node, str]] = [] + for connection in self._graph.edges: + _out: pe.Node + _in: pe.Node + _out, _in = connection + if _in == node1: + details = self._graph.get_edge_data(*connection) + if "connect" in details: + for connect in details["connect"]: + new_connections.append((_out, connect[0], node2, connect[1])) + for connection in new_connections: + try: + self.connect(*connection) + except Exception: + # connection already exists + continue + def _configure_exec_nodes(self, graph): """Ensure that each node knows where to get inputs from.""" for node in graph.nodes(): @@ -565,6 +590,20 @@ def _configure_exec_nodes(self, graph): except (FileNotFoundError, KeyError, TypeError): self._handle_just_in_time_exception(node) + def _connect_node_or_path( + self, + node: pe.Node, + strats_dct: Mapping[str, list[tuple[pe.Node, str] | str]], + key: str, + index: int, + ) -> None: + """Set input appropriately for either a Node or a path string.""" + _input: str = f"in{index + 1}" + if isinstance(strats_dct[key][index], str): + setattr(node.inputs, _input, strats_dct[key][index]) + else: + self.connect(*strats_dct[key][index], node, _input) + def _get_dot( self, prefix=None, hierarchy=None, colored=False, simple_form=True, level=0 ): @@ -678,6 +717,22 @@ def _get_dot( WFLOGGER.debug("cross connection: %s", dotlist[-1]) return ("\n" + prefix).join(dotlist) + def get_output_path(self, node: pe.Node, out: str) -> str: + """Get an output path from an already-run Node.""" + try: + _run_node: pe.Node = next( + iter( + _ + for _ in self.run(updatehash=True).nodes + if _.fullname == node.fullname + ) + ) + except IndexError as index_error: + msg = f"Could not find {node.fullname} in {self}'s run Nodes." 
diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py
index 743285ae9d..78eda7dca3 100644
--- a/CPAC/pipeline/nipype_pipeline_engine/engine.py
+++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py
@@ -8,6 +8,7 @@
 # * Applies a random seed
 # * Supports overriding memory estimates via a log file and a buffer
 # * Adds quotation marks around strings in dotfiles
+# * Adds methods for cross-graph connections

 # ORIGINAL WORK'S ATTRIBUTION NOTICE:
 #     Copyright (c) 2009-2016, Nipype developers
@@ -50,16 +51,18 @@
 for Nipype's documentation.
 """  # pylint: disable=line-too-long

+from collections.abc import Mapping
 from copy import deepcopy
 from inspect import Parameter, Signature, signature
 import os
 import re
-from typing import Any, ClassVar, Optional
+from typing import Any, ClassVar, Optional, TYPE_CHECKING

 from numpy import prod
 from traits.trait_base import Undefined
 from traits.trait_handlers import TraitListObject
 from nibabel import load
+from nipype.interfaces.base.support import InterfaceResult
 from nipype.interfaces.utility import Function
 from nipype.pipeline import engine as pe
 from nipype.pipeline.engine.utils import (
@@ -76,6 +79,9 @@
 from CPAC.utils.monitoring import getLogger, WFLOGGER

+if TYPE_CHECKING:
+    pass
+
 # set global default mem_gb
 DEFAULT_MEM_GB = 2.0
 UNDEFINED_SIZE = (42, 42, 42, 1200)
@@ -527,6 +533,25 @@ def __init__(self, name, base_dir=None, debug=False):
         self._nodes_cache = set()
         self._nested_workflows_cache = set()

+    def copy_input_connections(self, node1: pe.Node, node2: pe.Node) -> None:
+        """Copy input connections from ``node1`` to ``node2``."""
+        new_connections: list[tuple[pe.Node, str, pe.Node, str]] = []
+        for connection in self._graph.edges:
+            _out: pe.Node
+            _in: pe.Node
+            _out, _in = connection
+            if _in == node1:
+                details = self._graph.get_edge_data(*connection)
+                if "connect" in details:
+                    for connect in details["connect"]:
+                        new_connections.append((_out, connect[0], node2, connect[1]))
+        for connection in new_connections:
+            try:
+                self.connect(*connection)
+            except Exception:
+                # connection already exists
+                continue
+
     def _configure_exec_nodes(self, graph):
         """Ensure that each node knows where to get inputs from."""
         for node in graph.nodes():
@@ -565,6 +590,20 @@ def _configure_exec_nodes(self, graph):
                     except (FileNotFoundError, KeyError, TypeError):
                         self._handle_just_in_time_exception(node)

+    def _connect_node_or_path(
+        self,
+        node: pe.Node,
+        strats_dct: Mapping[str, list[tuple[pe.Node, str] | str]],
+        key: str,
+        index: int,
+    ) -> None:
+        """Set input appropriately for either a Node or a path string."""
+        _input: str = f"in{index + 1}"
+        if isinstance(strats_dct[key][index], str):
+            setattr(node.inputs, _input, strats_dct[key][index])
+        else:
+            self.connect(*strats_dct[key][index], node, _input)
+
     def _get_dot(
         self, prefix=None, hierarchy=None, colored=False, simple_form=True, level=0
     ):
@@ -678,6 +717,22 @@ def _get_dot(
             WFLOGGER.debug("cross connection: %s", dotlist[-1])
         return ("\n" + prefix).join(dotlist)

+    def get_output_path(self, node: pe.Node, out: str) -> str:
+        """Get an output path from an already-run Node."""
+        try:
+            _run_node: pe.Node = next(
+                iter(
+                    _
+                    for _ in self.run(updatehash=True).nodes
+                    if _.fullname == node.fullname
+                )
+            )
+        except StopIteration as stop_iteration:
+            msg = f"Could not find {node.fullname} in {self}'s run Nodes."
+            raise LookupError(msg) from stop_iteration
+        _res: InterfaceResult = _run_node.run()
+        return getattr(_res.outputs, out)
+
     def _handle_just_in_time_exception(self, node):
         # pylint: disable=protected-access
         if hasattr(self, "_local_func_scans"):
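
`copy_input_connections` above walks the workflow's underlying networkx graph and replays each edge's `connect` records onto a new sink node. The traversal pattern, isolated with plain networkx (node names are illustrative; nipype stores its field-to-field wiring in exactly this kind of edge attribute):

    import networkx as nx

    g = nx.DiGraph()
    # nipype keeps the wiring in a "connect" list on each edge
    g.add_edge("upstream", "node1", connect=[("out_file", "in_file")])

    new_connections = []
    for _out, _in in g.edges:
        if _in == "node1":
            details = g.get_edge_data(_out, _in)
            for src_field, dst_field in details.get("connect", []):
                # replay the same wiring onto node2
                new_connections.append((_out, src_field, "node2", dst_field))

    assert new_connections == [("upstream", "out_file", "node2", "in_file")]
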
strat_pool.get_data("desc-head_T1w") - node, out = connect - wf.connect(node, out, apply_xfm, "inputspec.input_image") - - node, out = strat_pool.get_data("T1w-template") - wf.connect(node, out, apply_xfm, "inputspec.reference") +def warp_to_template( + warp_what: Literal["mask", "wholehead"], space_from: Literal["longitudinal", "T1w"] +) -> NodeBlockFunction: + """Get a NodeBlockFunction to transform a resource from ``space`` to template. - node, out = strat_pool.get_data("from-T1w_to-template_mode-image_xfm") - wf.connect(node, out, apply_xfm, "inputspec.transform") - - outputs = {"space-template_desc-head_T1w": (apply_xfm, "outputspec.output_image")} + The resource being warped needs to be the first list or string in the tuple + in the first position of the decorator's "inputs". + """ + _decorators = { + "mask": { + "name": f"transform_{space_from}-mask_to_T1-template", + "switch": [ + ["registration_workflows", "anatomical_registration", "run"], + ["anatomical_preproc", "run"], + ["anatomical_preproc", "brain_extraction", "run"], + ], + "inputs": [ + ( + f"space-{space_from}_desc-brain_mask", + f"from-{space_from}_to-template_mode-image_xfm", + ), + "T1w-template", + ], + "outputs": {"space-template_desc-brain_mask": {"Template": "T1w-template"}}, + }, + "wholehead": { + "name": f"transform_wholehead_{space_from}_to_T1template", + "config": ["registration_workflows", "anatomical_registration"], + "switch": ["run"], + "inputs": [ + ( + ["desc-head_T1w", "desc-reorient_T1w"], + [ + f"from-{space_from}_to-template_mode-image_xfm", + f"from-{space_from}_to-template_mode-image_xfm", + ], + "space-template_desc-head_T1w", + ), + "T1w-template", + ], + "outputs": {"space-template_desc-head_T1w": {"Template": "T1w-template"}}, + }, + } + if space_from != "T1w": + _decorators[warp_what]["inputs"][0] = ( + prepend_space(_decorators[warp_what]["inputs"][0][0], space_from), + *_decorators[warp_what]["inputs"][0][1:], + ) - return (wf, outputs) + @nodeblock(**_decorators[warp_what]) + def warp_to_template_fxn( + wf: pe.Workflow, + cfg: "Configuration", + strat_pool: "ResourcePool", + pipe_num: int, + opt: Optional[str] = None, + ) -> NODEBLOCK_RETURN: + """Transform a resource to template space.""" + xfm_prov = strat_pool.get_cpac_provenance( + f"from-{space_from}_to-template_mode-image_xfm" + ) + reg_tool = check_prov_for_regtool(xfm_prov) + num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"] -@nodeblock( - name="transform_T1mask_to_T1template", - switch=[ - ["registration_workflows", "anatomical_registration", "run"], - ["anatomical_preproc", "run"], - ["anatomical_preproc", "brain_extraction", "run"], - ], - inputs=[ - ("space-T1w_desc-brain_mask", "from-T1w_to-template_mode-image_xfm"), - "T1w-template", - ], - outputs={"space-template_desc-brain_mask": {"Template": "T1w-template"}}, -) -def warp_T1mask_to_template(wf, cfg, strat_pool, pipe_num, opt=None): - """Warp T1 mask to template.""" - xfm_prov = strat_pool.get_cpac_provenance("from-T1w_to-template_mode-image_xfm") - reg_tool = check_prov_for_regtool(xfm_prov) + num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"] - num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"] + apply_xfm = apply_transform( + f"warp_{space_from}{warp_what}_to_T1template_{pipe_num}", + reg_tool, + time_series=False, + num_cpus=num_cpus, + num_ants_cores=num_ants_cores, + ) - num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"] + if warp_what == "mask": + 
@@ -3849,115 +3853,115 @@ def apply_blip_to_timeseries_separately(wf, cfg, strat_pool, pipe_num, opt=None)
     return (wf, outputs)


-@nodeblock(
-    name="transform_whole_head_T1w_to_T1template",
-    config=["registration_workflows", "anatomical_registration"],
-    switch=["run"],
-    inputs=[
-        (
-            "desc-head_T1w",
-            "from-T1w_to-template_mode-image_xfm",
-            "space-template_desc-head_T1w",
-        ),
-        "T1w-template",
-    ],
-    outputs={"space-template_desc-head_T1w": {"Template": "T1w-template"}},
-)
-def warp_wholeheadT1_to_template(wf, cfg, strat_pool, pipe_num, opt=None):
-    """Warp T1 head to template."""
-    xfm_prov = strat_pool.get_cpac_provenance("from-T1w_to-template_mode-image_xfm")
-    reg_tool = check_prov_for_regtool(xfm_prov)
-
-    num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"]
-
-    num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"]
-
-    apply_xfm = apply_transform(
-        f"warp_wholehead_T1w_to_T1template_{pipe_num}",
-        reg_tool,
-        time_series=False,
-        num_cpus=num_cpus,
-        num_ants_cores=num_ants_cores,
-    )
-
-    if reg_tool == "ants":
-        apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[
-            "functional_registration"
-        ]["func_registration_to_template"]["ANTs_pipelines"]["interpolation"]
-    elif reg_tool == "fsl":
-        apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[
-            "functional_registration"
-        ]["func_registration_to_template"]["FNIRT_pipelines"]["interpolation"]
-
-    connect = strat_pool.get_data("desc-head_T1w")
-    node, out = connect
-    wf.connect(node, out, apply_xfm, "inputspec.input_image")
-
-    node, out = strat_pool.get_data("T1w-template")
-    wf.connect(node, out, apply_xfm, "inputspec.reference")
-
-    node, out = strat_pool.get_data("from-T1w_to-template_mode-image_xfm")
-    wf.connect(node, out, apply_xfm, "inputspec.transform")
-
-    outputs = {"space-template_desc-head_T1w": (apply_xfm, "outputspec.output_image")}
-
-    return (wf, outputs)
-
-
-@nodeblock(
-    name="transform_T1mask_to_T1template",
-    switch=[
-        ["registration_workflows", "anatomical_registration", "run"],
-        ["anatomical_preproc", "run"],
-        ["anatomical_preproc", "brain_extraction", "run"],
-    ],
-    inputs=[
-        ("space-T1w_desc-brain_mask", "from-T1w_to-template_mode-image_xfm"),
-        "T1w-template",
-    ],
-    outputs={"space-template_desc-brain_mask": {"Template": "T1w-template"}},
-)
-def warp_T1mask_to_template(wf, cfg, strat_pool, pipe_num, opt=None):
-    """Warp T1 mask to template."""
-    xfm_prov = strat_pool.get_cpac_provenance("from-T1w_to-template_mode-image_xfm")
-    reg_tool = check_prov_for_regtool(xfm_prov)
-
-    num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"]
-
-    num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"]
-
-    apply_xfm = apply_transform(
-        f"warp_T1mask_to_T1template_{pipe_num}",
-        reg_tool,
-        time_series=False,
-        num_cpus=num_cpus,
-        num_ants_cores=num_ants_cores,
-    )
-
-    apply_xfm.inputs.inputspec.interpolation = "NearestNeighbor"
-    """
-    if reg_tool == 'ants':
-        apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[
-            'functional_registration']['func_registration_to_template'][
-            'ANTs_pipelines']['interpolation']
-    elif reg_tool == 'fsl':
-        apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[
-            'functional_registration']['func_registration_to_template'][
-            'FNIRT_pipelines']['interpolation']
-    """
-    connect = strat_pool.get_data("space-T1w_desc-brain_mask")
-    node, out = connect
-    wf.connect(node, out, apply_xfm, "inputspec.input_image")
-
-    node, out = strat_pool.get_data("T1w-template")
-    wf.connect(node, out, apply_xfm, "inputspec.reference")
-
-    node, out = strat_pool.get_data("from-T1w_to-template_mode-image_xfm")
-    wf.connect(node, out, apply_xfm, "inputspec.transform")
-
-    outputs = {"space-template_desc-brain_mask": (apply_xfm, "outputspec.output_image")}
-
-    return (wf, outputs)
+def warp_to_template(
+    warp_what: Literal["mask", "wholehead"], space_from: Literal["longitudinal", "T1w"]
+) -> NodeBlockFunction:
+    """Get a NodeBlockFunction to transform a resource from ``space_from`` to template.
+
+    The resource being warped needs to be the first list or string in the tuple
+    in the first position of the decorator's "inputs".
+    """
+    _decorators = {
+        "mask": {
+            "name": f"transform_{space_from}-mask_to_T1-template",
+            "switch": [
+                ["registration_workflows", "anatomical_registration", "run"],
+                ["anatomical_preproc", "run"],
+                ["anatomical_preproc", "brain_extraction", "run"],
+            ],
+            "inputs": [
+                (
+                    f"space-{space_from}_desc-brain_mask",
+                    f"from-{space_from}_to-template_mode-image_xfm",
+                ),
+                "T1w-template",
+            ],
+            "outputs": {"space-template_desc-brain_mask": {"Template": "T1w-template"}},
+        },
+        "wholehead": {
+            "name": f"transform_wholehead_{space_from}_to_T1template",
+            "config": ["registration_workflows", "anatomical_registration"],
+            "switch": ["run"],
+            "inputs": [
+                (
+                    ["desc-head_T1w", "desc-reorient_T1w"],
+                    [
+                        f"from-{space_from}_to-template_mode-image_xfm",
+                        f"from-{space_from}_to-template_mode-image_xfm",
+                    ],
+                    "space-template_desc-head_T1w",
+                ),
+                "T1w-template",
+            ],
+            "outputs": {"space-template_desc-head_T1w": {"Template": "T1w-template"}},
+        },
+    }
+    if space_from != "T1w":
+        _decorators[warp_what]["inputs"][0] = (
+            prepend_space(_decorators[warp_what]["inputs"][0][0], space_from),
+            *_decorators[warp_what]["inputs"][0][1:],
+        )
+
+    @nodeblock(**_decorators[warp_what])
+    def warp_to_template_fxn(
+        wf: pe.Workflow,
+        cfg: "Configuration",
+        strat_pool: "ResourcePool",
+        pipe_num: int,
+        opt: Optional[str] = None,
+    ) -> NODEBLOCK_RETURN:
+        """Transform a resource to template space."""
+        xfm_prov = strat_pool.get_cpac_provenance(
+            f"from-{space_from}_to-template_mode-image_xfm"
+        )
+        reg_tool = check_prov_for_regtool(xfm_prov)
+
+        num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"]
+
+        num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"]
+
+        apply_xfm = apply_transform(
+            f"warp_{space_from}{warp_what}_to_T1template_{pipe_num}",
+            reg_tool,
+            time_series=False,
+            num_cpus=num_cpus,
+            num_ants_cores=num_ants_cores,
+        )
+
+        if warp_what == "mask":
+            apply_xfm.inputs.inputspec.interpolation = "NearestNeighbor"
+        else:
+            tool = (
+                "ANTs" if reg_tool == "ants" else "FNIRT" if reg_tool == "fsl" else None
+            )
+            if not tool:
+                msg = f"Warp {warp_what} to template not implemented for {reg_tool}."
+                raise NotImplementedError(msg)
+            apply_xfm.inputs.inputspec.interpolation = cfg.registration_workflows[
+                "functional_registration"
+            ]["func_registration_to_template"][f"{tool}_pipelines"]["interpolation"]
+
+        # the resource being warped needs to be inputs[0][0] for this
+        node, out = strat_pool.get_data(_decorators[warp_what]["inputs"][0][0])
+        wf.connect(node, out, apply_xfm, "inputspec.input_image")
+
+        node, out = strat_pool.get_data("T1w-template")
+        wf.connect(node, out, apply_xfm, "inputspec.reference")
+
+        node, out = strat_pool.get_data(f"from-{space_from}_to-template_mode-image_xfm")
+        wf.connect(node, out, apply_xfm, "inputspec.transform")
+
+        outputs = {
+            # there's only one output, so that's what we give here
+            next(iter(_decorators[warp_what]["outputs"].keys())): (
+                apply_xfm,
+                "outputspec.output_image",
+            )
+        }
+
+        return wf, outputs
+
+    return warp_to_template_fxn


 @nodeblock(
@@ -5416,8 +5420,8 @@ def warp_tissuemask_to_template(wf, cfg, strat_pool, pipe_num, xfm, template_spa

 def warp_resource_to_template(
     wf: pe.Workflow,
-    cfg,
-    strat_pool,
+    cfg: "Configuration",
+    strat_pool: "ResourcePool",
     pipe_num: int,
     input_resource: list[str] | str,
     xfm: str,
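
`warp_to_template` above is a nodeblock factory: it builds the decorator arguments from `(warp_what, space_from)` and returns a freshly decorated function, so one body serves all four block variants. The shape of that pattern, reduced to plain Python (the decorator and names here are stand-ins, not the C-PAC nodeblock machinery):

    from typing import Callable, Literal

    def nodeblock_stub(**decorator_kwargs):
        """Stand-in for @nodeblock: records its config on the function."""
        def wrap(fxn: Callable) -> Callable:
            fxn.config = decorator_kwargs
            return fxn
        return wrap

    def warp_to_template(
        warp_what: Literal["mask", "wholehead"],
        space_from: Literal["longitudinal", "T1w"],
    ) -> Callable:
        _decorators = {
            "mask": {"name": f"transform_{space_from}-mask_to_T1-template"},
            "wholehead": {"name": f"transform_wholehead_{space_from}_to_T1template"},
        }

        @nodeblock_stub(**_decorators[warp_what])
        def warp_to_template_fxn(strat_pool):
            # the closure still sees warp_what/space_from at wiring time
            return f"warping {warp_what} from {space_from}"

        return warp_to_template_fxn

    blocks = [warp_to_template(what, "longitudinal") for what in ("wholehead", "mask")]
    assert blocks[0].config["name"] == "transform_wholehead_longitudinal_to_T1template"
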
diff --git a/CPAC/registration/utils.py b/CPAC/registration/utils.py
index 4e0dc4421e..64bcb23884 100644
--- a/CPAC/registration/utils.py
+++ b/CPAC/registration/utils.py
@@ -18,6 +18,7 @@

 import os
 import subprocess
+from typing import overload

 import numpy as np
 from voluptuous import RequiredFieldInvalid
@@ -808,3 +809,18 @@ def run_c4d(input_name, output_name):
     os.system(cmd)

     return output1, output2, output3
+
+
+@overload
+def prepend_space(resource: list[str], space: str) -> list[str]: ...
+@overload
+def prepend_space(resource: str, space: str) -> str: ...
+def prepend_space(resource: str | list[str], space: str) -> str | list[str]:
+    """Given a resource or list of resources, return same but with updated space."""
+    if isinstance(resource, list):
+        return [prepend_space(_, space) for _ in resource]
+    if "space-" not in resource:
+        return f"space-{space}_{resource}"
+    pre, post = resource.split("space-", 1)
+    _old_space, post = post.split("_", 1)
+    return f"space-{space}_".join([pre, post])
diff --git a/CPAC/seg_preproc/seg_preproc.py b/CPAC/seg_preproc/seg_preproc.py
index f769cf14b3..3051943314 100644
--- a/CPAC/seg_preproc/seg_preproc.py
+++ b/CPAC/seg_preproc/seg_preproc.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2012-2023 C-PAC Developers
+# Copyright (C) 2012-2024 C-PAC Developers

 # This file is part of C-PAC.
@@ -507,27 +507,10 @@ def create_seg_preproc_antsJointLabel_method(wf_name="seg_preproc_templated_base
         "WM-path",
     ],
     outputs=[
-        "label-CSF_mask",
-        "label-GM_mask",
-        "label-WM_mask",
-        "label-CSF_desc-preproc_mask",
-        "label-GM_desc-preproc_mask",
-        "label-WM_desc-preproc_mask",
-        "label-CSF_probseg",
-        "label-GM_probseg",
-        "label-WM_probseg",
-        "label-CSF_pveseg",
-        "label-GM_pveseg",
-        "label-WM_pveseg",
-        "space-longitudinal_label-CSF_mask",
-        "space-longitudinal_label-GM_mask",
-        "space-longitudinal_label-WM_mask",
-        "space-longitudinal_label-CSF_desc-preproc_mask",
-        "space-longitudinal_label-GM_desc-preproc_mask",
-        "space-longitudinal_label-WM_desc-preproc_mask",
-        "space-longitudinal_label-CSF_probseg",
-        "space-longitudinal_label-GM_probseg",
-        "space-longitudinal_label-WM_probseg",
+        f"{long}label-{tissue}_{entity}"
+        for long in ["", "space-longitudinal_"]
+        for tissue in ["CSF", "GM", "WM"]
+        for entity in ["mask", "desc-preproc_mask", "probseg", "pveseg"]
     ],
 )
 def tissue_seg_fsl_fast(wf, cfg, strat_pool, pipe_num, opt=None):
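
To close, the behavior of `prepend_space` as defined above, shown as a short usage sketch (expected values derived by reading the function; running this requires a C-PAC installation with this patch applied):

    from CPAC.registration.utils import prepend_space

    # no space entity yet: one is prefixed
    assert prepend_space("desc-brain_mask", "longitudinal") == (
        "space-longitudinal_desc-brain_mask"
    )
    # an existing space entity is replaced
    assert prepend_space("space-T1w_desc-brain_mask", "longitudinal") == (
        "space-longitudinal_desc-brain_mask"
    )
    # lists map element-wise, matching the overloads
    assert prepend_space(["desc-head_T1w", "desc-reorient_T1w"], "longitudinal") == [
        "space-longitudinal_desc-head_T1w",
        "space-longitudinal_desc-reorient_T1w",
    ]

Note also that the rewritten seg_preproc outputs comprehension yields 24 keys (2 spaces x 3 tissues x 4 entities), which adds `space-longitudinal_label-*_pveseg` entries that the old literal list did not declare.
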