Design document
Standard, high-precision reconstruction takes O(10 s) per event to run, meaning that it is limited to the final steps of the reconstruction chain, after simpler, less precise methods have been used to bring down the event rate. This is because standard IceCube reconstruction is optimisation-based: each event is iteratively tested against several particle hypotheses using a detailed likelihood-based approach. Alternatively, one could use machine learning, which separates optimisation (training) from inference. This way, one could train an ML model in advance, possibly using GPUs to speed up the process, and then only need to run a single forward/inference pass for each event to be reconstructed. This has the potential to bring down reconstruction time by several orders of magnitude compared to the current state of the art (RETRO, etc.).
The challenge with IceCube is its complexity, heterogeneity, and high dimensionality. This affects the choice of which ML paradigm to employ: standard, tabular methods (e.g. BDTs) require collapsing the event information into a tabular format, e.g. by using manually engineered features. However, from previous work (Baldi, Sadowski, and Whiteson, 1402.4735) it is known that deep learning models in particular provide the largest marginal improvement when applied to data at the lowest level rather than to engineered features only. In IceCube, this would imply using the DOM-level data (DOM coordinates, pulse charge, timing, etc.) directly in neural networks.
One option would be to use convolutional neural networks (CNNs), but these require the input data to be Euclidean; something not even the IceCube-86 detector geometry is, and with each additional detector component (DeepCore, Upgrade, etc.) it becomes less and less straightforward to format raw IceCube data in a way that is suitable for CNN processing. Another option is to use graph neural networks (GNNs). These models can accommodate any spatial geometry of the input data through the notion of adjacency; in the language of geometric deep learning (Bronstein et al., 2104.13478), CNNs can be considered special cases of GNNs. Within this paradigm, DOM data is represented as nodes on a graph corresponding to the IceCube detector (naturally also accommodating any detector extensions), and information is passed through the network along the edges connecting these nodes, corresponding to a generalisation of the convolutions in CNNs. This makes GNNs a well-suited paradigm for naturally handling low-level IceCube data.
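As a toy illustration of this representation (the feature layout and values below are made up for the example, and pytorch_geometric with the torch-cluster extension is assumed to be installed), a handful of DOM hits could be turned into a graph as follows:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph

# Five toy DOM hits with features (x, y, z, time, charge). Values are illustrative only.
x = torch.tensor([
    [ 10.0, -20.0, -400.0, 10200.0, 1.2],
    [ 12.0, -18.0, -417.0, 10250.0, 0.8],
    [ 55.0,  30.0, -350.0, 10400.0, 2.1],
    [ 57.0,  28.0, -367.0, 10420.0, 0.5],
    [-80.0,  60.0, -500.0, 10800.0, 1.0],
])

# Connect each hit to its two nearest neighbours in (x, y, z).
edge_index = knn_graph(x[:, :3], k=2)

# The resulting graph object is what a GNN would operate on.
data = Data(x=x, edge_index=edge_index)
print(data)  # e.g. Data(x=[5, 5], edge_index=[2, 10])
```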
To fully exploit the potential of GNNs in IceCube, leverage collaboration across institutes and analysis groups, and break down unproductive silos, it is proposed to coordinate and align all GNN development efforts in IceCube through the gnn-reco project. This project would constitute an internal "centre of excellence" within this branch of machine learning, where all IceCube members can contribute new models, applications/use cases, etc., and get support for e.g. using GNNs in their analyses.
The collaboration itself will take place through weekly developer meetings, discussions in a dedicated Slack channel, etc. (TBC). The work will be coordinated and guided by the repository maintainers.
The technical solution (i.e. the project code base) will be the foundation on which GNN applications are built in IceCube, and the quality and appropriateness of the base code structure are therefore crucial to the success of the project. To this end, the project should aim to be sufficiently general and extensible; it should facilitate collaboration; and it should generally follow best practices for ML development.
Generally, the project should provide effective means for:
- Ingesting I3 data and converting them to a format that is suitable for ML.
- Building GNN models specific to each detector and application with minimal need for duplicate work.
- Training and optimising GNN models on well-defined tasks using standard datasets.
- Benchmarking GNN models against each other as well as against other reconstruction/classification/etc. methodologies to assess the added value of this approach for each potential application.
- Validating GNN models, including considerations regarding systematic uncertainties, calibration, etc.
- Deploying GNN models in a way that allows for easy use in official reconstruction chains and analyses.
The project should include the necessary functionality to:
1. Extract raw, DOM-level data from I3Frames into a standard Python format (e.g. numpy).
2. Build graphs (i.e., adjacency matrices) from the raw, DOM-level data.
3. Quickly ingest large numbers of events for training. Events should be indexable (i.e. not only available in sequential format) to allow for shuffling and batching.
The primary user stories/journeys are:
- Model user: These work from pre-configured IceTray environments and are mainly concerned with using pre-trained models as part of official reconstruction and analysis chains. These chains are configured as I3Trays, with each reconstruction step an I3Module sequentially processing the I3Frame objects passed to it. Therefore, these users are mainly concerned with pt. (1.) above: extracting the relevant DOM-level data from the I3Frame object and passing this directly to an inference model.
- Model developer: These work from the GitHub code, checked out on Cobalt or local computing clusters, and in principle work from I3 files. However, due to the sequential file layout and non-standard data structure, processing these directly is a massive throughput bottleneck when training ML models. Therefore, model developers will want to convert I3 files into an intermediate data format (e.g. .npy files or SQLite databases) before processing these for the training itself.
Ad 1.) This suggests the need for a class or function with syntax (pseudocode):
```python
import numpy as np
from icecube import icetray


class I3Extractor:
    def __call__(self, frame: icetray.I3Frame) -> np.ndarray:
        # Extracts raw, DOM-level data from an I3Frame and returns it as an
        # np.ndarray with a standardised format, applicable to all/most GNN models.
        ...
```
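For illustration, a minimal sketch of what such an extractor might do internally is given below; the pulse-map key ("SRTInIcePulses"), the default-argument layout, and the choice of per-pulse features are assumptions for the example, not a settled design:

```python
import numpy as np
from icecube import dataclasses, icetray


class I3Extractor:
    def __init__(self, pulse_key: str = "SRTInIcePulses"):  # Hypothetical default pulse-map key
        self._pulse_key = pulse_key

    def __call__(self, frame: icetray.I3Frame) -> np.ndarray:
        # Look up the detector geometry and the reconstructed pulses in the frame.
        geometry = frame["I3Geometry"]
        pulse_map = dataclasses.I3RecoPulseSeriesMap.from_frame(frame, self._pulse_key)

        # One row per pulse: (x, y, z, time, charge).
        rows = []
        for om_key, pulses in pulse_map.items():
            position = geometry.omgeo[om_key].position
            for pulse in pulses:
                rows.append([position.x, position.y, position.z, pulse.time, pulse.charge])

        return np.asarray(rows, dtype=np.float32)
```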
Ad 2.) Using the raw, DOM-level data, graphs need to be constructed for model learning and inference. In pytorch_geometric, this is done using the `edge_index` and optionally `edge_attr` properties of `torch_geometric.data.Data` objects, e.g. using utilities such as `knn_graph` in `torch_geometric.nn`. Each GNN model will presuppose a certain connectivity or notion of adjacency, or multiple of these. This requires a modular approach to graph building. In the interest of portability, and to abstract the question of adjacency/-ies away from the user, it might be a good idea to include graph building as part of the model forward pass (provided this does not constitute an impractical overhead). In the interest of modularity and simplicity for the developer, it might make sense to provide a set of standard, IceCube-relevant "graph builders" that implement relevant adjacencies in a plug-and-play fashion. One possible syntax could be (pseudocode):
```python
from abc import ABC, abstractmethod
from typing import List

import torch_geometric
from torch_geometric.data import Data


class GraphBuilder(ABC):
    @abstractmethod
    def __call__(self, data: Data) -> Data:
        pass


class KNNGraphBuilder(GraphBuilder):
    def __init__(self, nearest_neighbours: int, columns: List[int]):
        self._nearest_neighbours = nearest_neighbours
        self._columns = columns

    def __call__(self, data: Data) -> Data:
        # Constructs the adjacency matrix (edge_index) from the raw, DOM-level
        # data and attaches it to the graph object.
        assert data.edge_index is None
        x, batch = data.x, data.batch
        edge_index = torch_geometric.nn.knn_graph(
            x[:, self._columns], self._nearest_neighbours, batch
        )
        data.edge_index = edge_index
        return data


class EuclideanGraphBuilder(GraphBuilder):
    # ...


class MinkowskiGraphBuilder(GraphBuilder):
    # ...
```
with use (pseudocode):
```python
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    # ...
    def forward(self, data):
        # Build the graph (adjacency) as part of the forward pass.
        data = self._graph_builder(data)
        x, edge_index = data.x, data.edge_index
        # ...
        x = F.leaky_relu(self.conv(x, edge_index))
        # ...
```
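As one illustrative example of an alternative adjacency, the `EuclideanGraphBuilder` stub above could connect all DOM hits within a fixed Euclidean distance of each other, e.g. using `torch_geometric.nn.radius_graph`. A minimal sketch, building on the `GraphBuilder` base class above and assuming the same column convention as the KNN builder:

```python
from typing import List

import torch_geometric
from torch_geometric.data import Data


class EuclideanGraphBuilder(GraphBuilder):  # GraphBuilder as defined above
    def __init__(self, radius: float, columns: List[int]):
        # `radius`: maximum DOM-to-DOM distance (in metres) for an edge.
        # `columns`: feature columns holding the (x, y, z) coordinates.
        self._radius = radius
        self._columns = columns

    def __call__(self, data: Data) -> Data:
        assert data.edge_index is None
        x, batch = data.x, data.batch
        # Connect every pair of nodes closer than `radius` to each other.
        data.edge_index = torch_geometric.nn.radius_graph(
            x[:, self._columns], self._radius, batch
        )
        return data
```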
Ad 3.) This suggests the need for dedicated data converter and dataset classes for each intermediate file format, e.g. through inheritance-based polymorphism. The data converter classes would be tasked with processing I3 files, extracting and converting the I3Frame data (see pt. 1.), and saving it to the chosen intermediate file format. The dataset classes would be tasked with quickly loading, shuffling, batching, etc. this data for use in training pipelines. For an easy interface with the default machine learning library (pytorch/-geometric), the latter classes might inherit from `torch_geometric.data.{Data,Dataset}`. One possible syntax could then be (pseudocode):
```python
from abc import ABC, abstractmethod
from typing import List

import numpy as np
from icecube import dataio


class DataConverter(ABC):
    def __init__(self, path: str):
        self._path = path
        self._extractor = I3Extractor()
        self._initialise()  # Set up the output target (file, database connection, etc.)

    def process(self, i3_file_paths: List[str]):
        for i3_file_path in i3_file_paths:
            i3_file = dataio.I3File(i3_file_path, "r")
            while i3_file.more():
                try:
                    frame = i3_file.pop_physics()
                except Exception:
                    continue
                array = self._extractor(frame)
                self._save(array)

    @abstractmethod
    def _save(self, array: np.ndarray):
        pass

    @abstractmethod
    def _initialise(self):
        pass


class NumpyDataConverter(DataConverter):
    def _initialise(self):
        pass  # Nothing to set up for plain file output

    def _save(self, array: np.ndarray):
        with open(self._path, "ab") as f:
            # Append to file
            ...


class SQLiteDataConverter(DataConverter):
    # ...
```
with use (pseudocode):
```python
# my_conversion_script.py
def main():
    converter = NumpyDataConverter("path/to/my_file.npy")
    # or SQLiteDataConverter("path/to/my_database.db"), or similar
    converter.process(["path/to/some/files.i3", ...])
```
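For the SQLite case, one illustrative option would be a single table with one serialised feature array per event; the table name and schema below are choices made for this sketch only, not a settled design, and the sketch builds on the `DataConverter` base class above (which calls `_initialise()` from its constructor):

```python
import sqlite3

import numpy as np


class SQLiteDataConverter(DataConverter):  # DataConverter as defined above
    def _initialise(self):
        # Create the (hypothetical) `events` table if it does not exist yet.
        self._connection = sqlite3.connect(self._path)
        self._connection.execute(
            "CREATE TABLE IF NOT EXISTS events (id INTEGER PRIMARY KEY, array BLOB)"
        )
        self._connection.commit()

    def _save(self, array: np.ndarray):
        # Serialise the per-event feature array and append it as one row.
        self._connection.execute(
            "INSERT INTO events (array) VALUES (?)",
            (array.astype(np.float32).tobytes(),),
        )
        self._connection.commit()
```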
Similarly, for the dataset classes (pseudocode):
```python
from abc import ABC, abstractmethod

import numpy as np
import torch_geometric
from torch_geometric.data import Data


class Dataset(torch_geometric.data.Dataset, ABC):
    def __init__(self, path: str):
        self._path = path
        self._initialise()

    @abstractmethod
    def _initialise(self):
        pass

    @abstractmethod
    def __len__(self) -> int:
        pass

    def __getitem__(self, ix: int) -> Data:
        array = self._get_array(ix)
        data = Data(x=array, edge_index=None, edge_attr=None)
        return data

    @abstractmethod
    def _get_array(self, ix: int) -> np.ndarray:
        pass


class NumpyDataset(Dataset):
    def _initialise(self):
        self._source = np.load(self._path, mmap_mode='r')

    def _get_array(self, ix: int) -> np.ndarray:
        return self._source[ix]

    # ...


class SQLiteDataset(Dataset):
    # ...
```
with use (pseudocode):
```python
# my_training_script.py
import torch_geometric


def main():
    dataset = NumpyDataset("path/to/my_file.npy")
    # or SQLiteDataset("path/to/my_database.db"), or similar
    dataloader = torch_geometric.loader.DataLoader(dataset, batch_size=128, shuffle=True)
    for batch in dataloader:
        model.fit(batch)
```
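Since `model.fit(batch)` above is pseudocode, the inner loop of an actual training script would in practice look roughly like the following; the choice of optimiser and loss, and the assumption that training targets are stored as `y` on each `Data` object, are placeholders for the example:

```python
import torch
import torch.nn.functional as F

optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for batch in dataloader:
    optimiser.zero_grad()
    prediction = model(batch)               # Forward pass, incl. graph building
    loss = F.mse_loss(prediction, batch.y)  # Assumes targets are stored as `y`
    loss.backward()
    optimiser.step()
```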
The way to deploy new modules in IceCube reconstruction is as I3Modules that can be used in I3Tray-chains.
This means that we can deploy the trained model in any format we like (simply specifying the path to the model artifacts, e.g. in CVMFS, as a keyword argument) and that the role of the I3Module will primarily be to translate between the I3Frame and the data format ingested by the model (i.e. the same one produced by the Dataset class). This in turn means that all models should ingest data in the same format (see Datasets) and that the data provided to the model as part of the reconstruction should be a single (non-batched) sample in this same format.
The deployment class could then look like (pseudocode):
```python
import os

import numpy as np
import torch_geometric
from icecube import dataclasses, icetray

import gnn_reco


class GNNModule(icetray.I3Module):
    def __init__(self, context):
        # Base class constructor
        icetray.I3Module.__init__(self, context)

        # Parameters to `I3Tray.Add(..., param=...)`
        self.AddParameter("key", "doc_string", None)
        self.AddParameter("model_path", "doc_string", None)

        # Standard member variables
        self.i3extractor = gnn_reco.I3Extractor()

    def Configure(self):
        self.key = self.GetParameter("key")
        model_path = self.GetParameter("model_path")
        assert self.key is not None, "..."
        assert model_path is not None, "..."
        assert os.path.exists(model_path), "..."
        self.model = gnn_reco.Model.load(model_path)

    def Physics(self, frame: icetray.I3Frame):
        array = self._extract_feature_array_from_frame(frame)
        data = torch_geometric.data.Data(x=array, edge_index=None)
        prediction = self.model.predict(data).numpy()
        frame = self._write_prediction_to_frame(frame, prediction)
        self.PushFrame(frame)

    def _extract_feature_array_from_frame(self, frame: icetray.I3Frame) -> np.ndarray:
        return self.i3extractor(frame)

    def _write_prediction_to_frame(self, frame: icetray.I3Frame, prediction: np.ndarray) -> icetray.I3Frame:
        frame[self.key] = dataclasses.I3Double(prediction[0])  # Or similar
        return frame
```
with use (pseudocode):
```python
tray = I3Tray()
tray.AddModule("I3Reader", ...)
tray.Add(GNNModule, key="my_prediction", model_path="/cvmfs/path/to/my_model.pth")
...
```