Commit 253ca80

🔀 Merge changes from #2160

shnizzedy committed Nov 15, 2024
1 parent ae5be2c commit 253ca80
Showing 8 changed files with 438 additions and 334 deletions.
418 changes: 224 additions & 194 deletions CPAC/longitudinal/wf/anat.py

Large diffs are not rendered by default.

21 changes: 13 additions & 8 deletions CPAC/pipeline/cpac_pipeline.py
@@ -25,6 +25,7 @@
import sys
import time
from time import strftime
from typing import Literal, Optional

import yaml
import nipype
@@ -130,7 +131,7 @@
# pylint: disable=wrong-import-order
from CPAC.pipeline import nipype_pipeline_engine as pe
from CPAC.pipeline.check_outputs import check_outputs
from CPAC.pipeline.engine import initiate_rpool, NodeBlock
from CPAC.pipeline.engine import initiate_rpool, NodeBlock, ResourcePool
from CPAC.pipeline.nipype_pipeline_engine.plugins import (
LegacyMultiProcPlugin,
MultiProcPlugin,
@@ -162,15 +163,14 @@
warp_deriv_mask_to_EPItemplate,
warp_deriv_mask_to_T1template,
warp_sbref_to_T1template,
warp_T1mask_to_template,
warp_timeseries_to_EPItemplate,
warp_timeseries_to_T1template,
warp_timeseries_to_T1template_abcd,
warp_timeseries_to_T1template_dcan_nhp,
warp_timeseries_to_T1template_deriv,
warp_tissuemask_to_EPItemplate,
warp_tissuemask_to_T1template,
warp_wholeheadT1_to_template,
warp_to_template,
)
from CPAC.reho.reho import reho, reho_space_template
from CPAC.sca.sca import dual_regression, multiple_regression, SCA_AVG
@@ -1061,25 +1061,30 @@ def build_anat_preproc_stack(rpool, cfg, pipeline_blocks=None):
return pipeline_blocks


def build_T1w_registration_stack(rpool, cfg, pipeline_blocks=None):
def build_T1w_registration_stack(
rpool: ResourcePool,
cfg: Configuration,
pipeline_blocks: Optional[list] = None,
space: Literal["longitudinal", "T1w"] = "T1w",
):
"""Build the T1w registration pipeline blocks."""
if not pipeline_blocks:
pipeline_blocks = []

reg_blocks = []
if not rpool.check_rpool("from-T1w_to-template_mode-image_xfm"):
if not rpool.check_rpool(f"from-{space}_to-template_mode-image_xfm"):
reg_blocks = [
[register_ANTs_anat_to_template, register_FSL_anat_to_template],
overwrite_transform_anat_to_template,
warp_wholeheadT1_to_template,
warp_T1mask_to_template,
warp_to_template("wholehead", space),
warp_to_template("mask", space),
]

if not rpool.check_rpool("desc-restore-brain_T1w"):
reg_blocks.append(correct_restore_brain_intensity_abcd)

if cfg.voxel_mirrored_homotopic_connectivity["run"]:
if not rpool.check_rpool("from-T1w_to-symtemplate_mode-image_xfm"):
if not rpool.check_rpool(f"from-{space}_to-symtemplate_mode-image_xfm"):
reg_blocks.append(
[
register_symmetric_ANTs_anat_to_template,
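The new space parameter changes only the BIDS-style transform key the stack checks and the space handed to the warp node blocks; warp_to_template("wholehead", space) and warp_to_template("mask", space) replace the single-purpose warp_wholeheadT1_to_template and warp_T1mask_to_template blocks removed from the imports above. A runnable toy sketch of the key templating (the transform_key helper is ours, not part of C-PAC):

def transform_key(space: str, symmetric: bool = False) -> str:
    """Mimic the xfm resource keys checked by build_T1w_registration_stack."""
    template = "symtemplate" if symmetric else "template"
    return f"from-{space}_to-{template}_mode-image_xfm"

# Default behavior is unchanged:
assert transform_key("T1w") == "from-T1w_to-template_mode-image_xfm"
# Longitudinal runs check longitudinal-space transforms instead:
assert transform_key("longitudinal") == "from-longitudinal_to-template_mode-image_xfm"
# The VMHC branch checks the symmetric-template key:
assert transform_key("T1w", symmetric=True) == "from-T1w_to-symtemplate_mode-image_xfm"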
6 changes: 3 additions & 3 deletions CPAC/pipeline/cpac_runner.py
@@ -236,7 +236,7 @@ def run_cpac_on_cluster(config_file, subject_list_file, cluster_files_dir):
f.write(pid)


def run_T1w_longitudinal(sublist, cfg):
def run_T1w_longitudinal(sublist, cfg: Configuration, dry_run: bool = False):
subject_id_dict = {}

for sub in sublist:
@@ -249,7 +249,7 @@ def run_T1w_longitudinal(sublist, cfg):
# sessions for each participant as value
for subject_id, sub_list in subject_id_dict.items():
if len(sub_list) > 1:
anat_longitudinal_wf(subject_id, sub_list, cfg)
anat_longitudinal_wf(subject_id, sub_list, cfg, dry_run=dry_run)
elif len(sub_list) == 1:
warnings.warn(
"\n\nThere is only one anatomical session "
@@ -495,7 +495,7 @@ def run(
hasattr(c, "longitudinal_template_generation")
and c.longitudinal_template_generation["run"]
):
run_T1w_longitudinal(sublist, c)
run_T1w_longitudinal(sublist, c, dry_run=test_config)
# TODO functional longitudinal pipeline

"""
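run_T1w_longitudinal only builds a longitudinal workflow for participants with more than one anatomical session, and the new dry_run flag (forwarded from test_config in run()) is now threaded through to anat_longitudinal_wf. A self-contained sketch of that grouping guard; the dict keys are illustrative, not the exact C-PAC sublist schema:

from collections import defaultdict

def sessions_by_subject(sublist: list[dict]) -> dict[str, list[dict]]:
    """Toy stand-in for the subject_id_dict built in run_T1w_longitudinal."""
    grouped: dict[str, list[dict]] = defaultdict(list)
    for sub in sublist:
        grouped[sub["subject_id"]].append(sub)
    return grouped

sublist = [
    {"subject_id": "0001", "unique_id": "ses-1"},
    {"subject_id": "0001", "unique_id": "ses-2"},
    {"subject_id": "0002", "unique_id": "ses-1"},
]
for subject_id, sessions in sessions_by_subject(sublist).items():
    if len(sessions) > 1:
        print(f"{subject_id}: build longitudinal workflow (dry_run honored)")
    else:
        print(f"{subject_id}: single session, longitudinal step skipped with a warning")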
17 changes: 14 additions & 3 deletions CPAC/pipeline/engine.py
@@ -1229,7 +1229,11 @@ def gather_pipes(self, wf, cfg, all=False, add_incl=None, add_excl=None):
unlabelled.remove(key)
# del all_forks
for pipe_idx in self.rpool[resource]:
pipe_x = self.get_pipe_number(pipe_idx)
try:
pipe_x = self.get_pipe_number(pipe_idx)
except ValueError:
# already gone
continue
json_info = self.rpool[resource][pipe_idx]["json"]
out_dct = self.rpool[resource][pipe_idx]["out"]

@@ -2623,7 +2627,14 @@ def _set_nested(attr, keys):
return wf, rpool


def initiate_rpool(wf, cfg, data_paths=None, part_id=None):
def initiate_rpool(
wf: pe.Workflow,
cfg: Configuration,
data_paths=None,
part_id=None,
*,
rpool: Optional[ResourcePool] = None,
):
"""
Initialize a new ResourcePool.
@@ -2662,7 +2673,7 @@ def initiate_rpool(wf, cfg, data_paths=None, part_id=None):
unique_id = part_id
creds_path = None

rpool = ResourcePool(name=unique_id, cfg=cfg)
rpool = ResourcePool(rpool=rpool.rpool if rpool else None, name=unique_id, cfg=cfg)

if data_paths:
# ingress outdir
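The keyword-only rpool argument lets a caller seed a fresh pool from an existing one: initiate_rpool passes the backing dict (rpool.rpool) into the new ResourcePool. A hypothetical call, assuming a loaded Configuration cfg and a longitudinal_rpool left over from an earlier template-generation run (both names are ours):

from CPAC.pipeline import nipype_pipeline_engine as pe
from CPAC.pipeline.engine import initiate_rpool

wf = pe.Workflow(name="session_pipeline")
# Resources already in longitudinal_rpool become available to this
# session-level pipeline without re-ingressing them.
wf, rpool = initiate_rpool(wf, cfg, part_id="sub-0001", rpool=longitudinal_rpool)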
57 changes: 56 additions & 1 deletion CPAC/pipeline/nipype_pipeline_engine/engine.py
@@ -8,6 +8,7 @@
# * Applies a random seed
# * Supports overriding memory estimates via a log file and a buffer
# * Adds quotation marks around strings in dotfiles
# * Adds methods for cross-graph connections

# ORIGINAL WORK'S ATTRIBUTION NOTICE:
# Copyright (c) 2009-2016, Nipype developers
@@ -50,16 +51,18 @@
for Nipype's documentation.
""" # pylint: disable=line-too-long

from collections.abc import Mapping
from copy import deepcopy
from inspect import Parameter, Signature, signature
import os
import re
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar, Optional, TYPE_CHECKING

from numpy import prod
from traits.trait_base import Undefined
from traits.trait_handlers import TraitListObject
from nibabel import load
from nipype.interfaces.base.support import InterfaceResult
from nipype.interfaces.utility import Function
from nipype.pipeline import engine as pe
from nipype.pipeline.engine.utils import (
@@ -76,6 +79,9 @@

from CPAC.utils.monitoring import getLogger, WFLOGGER

if TYPE_CHECKING:
pass

# set global default mem_gb
DEFAULT_MEM_GB = 2.0
UNDEFINED_SIZE = (42, 42, 42, 1200)
@@ -527,6 +533,25 @@ def __init__(self, name, base_dir=None, debug=False):
self._nodes_cache = set()
self._nested_workflows_cache = set()

def copy_input_connections(self, node1: pe.Node, node2: pe.Node) -> None:
"""Copy input connections from ``node1`` to ``node2``."""
new_connections: list[tuple[pe.Node, str, pe.Node, str]] = []
for connection in self._graph.edges:
_out: pe.Node
_in: pe.Node
_out, _in = connection
if _in == node1:
details = self._graph.get_edge_data(*connection)
if "connect" in details:
for connect in details["connect"]:
new_connections.append((_out, connect[0], node2, connect[1]))
for connection in new_connections:
try:
self.connect(*connection)
except Exception:
# connection already exists
continue

def _configure_exec_nodes(self, graph):
"""Ensure that each node knows where to get inputs from."""
for node in graph.nodes():
@@ -565,6 +590,20 @@ def _configure_exec_nodes(self, graph):
except (FileNotFoundError, KeyError, TypeError):
self._handle_just_in_time_exception(node)

def _connect_node_or_path(
self,
node: pe.Node,
strats_dct: Mapping[str, list[tuple[pe.Node, str] | str]],
key: str,
index: int,
) -> None:
"""Set input appropriately for either a Node or a path string."""
_input: str = f"in{index + 1}"
if isinstance(strats_dct[key][index], str):
setattr(node.inputs, _input, strats_dct[key][index])
else:
self.connect(*strats_dct[key][index], node, _input)

def _get_dot(
self, prefix=None, hierarchy=None, colored=False, simple_form=True, level=0
):
@@ -678,6 +717,22 @@ def _get_dot(
WFLOGGER.debug("cross connection: %s", dotlist[-1])
return ("\n" + prefix).join(dotlist)

def get_output_path(self, node: pe.Node, out: str) -> str:
"""Get an output path from an already-run Node."""
try:
_run_node: pe.Node = next(
iter(
_
for _ in self.run(updatehash=True).nodes
if _.fullname == node.fullname
)
)
except StopIteration as stop_iteration:  # next() raises StopIteration, not IndexError
msg = f"Could not find {node.fullname} in {self}'s run Nodes."
raise LookupError(msg) from stop_iteration
_res: InterfaceResult = _run_node.run()
return getattr(_res.outputs, out)

def _handle_just_in_time_exception(self, node):
# pylint: disable=protected-access
if hasattr(self, "_local_func_scans"):
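A toy demonstration of the new copy_input_connections helper, using stand-in Function nodes rather than C-PAC node blocks (the demo workflow, add_one, and the node names are ours):

from nipype.interfaces.utility import Function

from CPAC.pipeline import nipype_pipeline_engine as pe

def add_one(x):
    return x + 1

def make_node(name: str) -> pe.Node:
    return pe.Node(
        Function(input_names=["x"], output_names=["y"], function=add_one),
        name=name,
    )

wf = pe.Workflow(name="demo", base_dir=".")
src, node1, node2 = make_node("src"), make_node("node1"), make_node("node2")
wf.connect(src, "y", node1, "x")
# Replicate node1's incoming edges onto node2 (src.y -> node2.x);
# connections that already exist raise inside connect() and are skipped.
wf.copy_input_connections(node1, node2)

_connect_node_or_path applies the same idea per strategy: each strats_dct[key][index] entry is either a (node, output) pair, which gets a graph connection, or an already-materialized path string, which is assigned directly to the receiving node's in{index + 1} input.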
(Diffs for the remaining changed files did not load.)