Models API

PyTorch Lightning Model

class cryoPARES.models.model.RotationPredictionMixin[source]

Bases: object

__init_mixin__()[source]

Initialize mixin-specific attributes. Call this in the main class __init__.
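
A minimal sketch of the calling pattern this docstring describes. The class and attribute names (`CounterMixin`, `Base`, `Model`, `_calls`) are hypothetical stand-ins, not cryoPARES code. The assumption is that the mixin defines no `__init__` of its own, so the host class must call `__init_mixin__()` explicitly after initializing its other bases.

```python
class CounterMixin:
    """Hypothetical mixin; plays the role of RotationPredictionMixin."""

    def __init_mixin__(self):
        # Mixin-specific state is created here, not in __init__.
        self._calls = 0

    def count_call(self):
        self._calls += 1
        return self._calls


class Base:
    """Hypothetical second base class; plays the role of LightningModule."""

    def __init__(self):
        self.ready = True


class Model(CounterMixin, Base):
    def __init__(self):
        super().__init__()     # CounterMixin has no __init__, so Base.__init__ runs
        self.__init_mixin__()  # explicit call, as the docstring asks


model = Model()
```

If the explicit call is skipped in this sketch, `count_call()` fails with an `AttributeError` because `_calls` was never created.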

class cryoPARES.models.model.PlModel(lr, symmetry, num_augmented_copies_per_batch, top_k_poses_nnet, so3model=None)[source]

Bases: RotationPredictionMixin, LightningModule

Parameters:
  • lr (float)

  • symmetry (str)

  • num_augmented_copies_per_batch (int)

  • top_k_poses_nnet (int)

  • so3model (Module | ScriptModule | None)

__init__(lr, symmetry, num_augmented_copies_per_batch, top_k_poses_nnet, so3model=None)[source]
Parameters:
  • lr (float)

  • symmetry (str)

  • num_augmented_copies_per_batch (int)

  • top_k_poses_nnet (int)

  • so3model (Module | ScriptModule | None)

static build_components(symmetry, num_augmented_copies_per_batch)[source]

training_step(batch, batch_idx)[source]

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
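
The effect of this normalization can be checked with plain arithmetic. The sketch below assumes a hypothetical scalar model with loss (w - target)**2 and is not Lightning's code. It shows that scaling each micro-batch loss by 1 / accumulate_grad_batches makes the accumulated (summed) gradient equal to the gradient of the mean loss.

```python
def grad(w, target):
    # d/dw of the hypothetical loss (w - target)**2
    return 2.0 * (w - target)


w = 1.0
targets = [0.0, 2.0, 4.0, 6.0]  # one hypothetical target per micro-batch
accumulate_grad_batches = len(targets)

# Gradients are summed across micro-batches; each loss is scaled first.
accumulated = sum(grad(w, t) / accumulate_grad_batches for t in targets)

# Gradient of the mean loss computed over all micro-batches at once.
mean_grad = 2.0 * (w - sum(targets) / len(targets))
```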

validation_step(batch, batch_idx)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

forward(imgs, batch_idx, dataloader_idx=0, top_k=None)[source]

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

  • imgs (Tensor)

  • batch_idx (int)

  • dataloader_idx (int)

  • top_k (int | None)

Return type:

Any

Returns:

Your model’s output

optimizer_step_v1(epoch, batch_idx, optimizer, optimizer_idx=0, optimizer_closure=None, on_tpu=False, using_lbfgs=False)[source]

optimizer_step_v2(epoch, batch_idx, optimizer, optimizer_closure=None)[source]

configure_optimizers()[source]

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 5 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
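
A simplified model of how "interval" and "frequency" interact, written as a hypothetical helper (`scheduler_should_step`). Lightning's own logic covers more cases, such as monitored metrics, so this only illustrates the comments in the configuration above.

```python
def scheduler_should_step(config, event, count):
    """Return True if scheduler.step() would be due (simplified assumption).

    event: "epoch" or "step", the kind of boundary just reached.
    count: number of completed epochs (or optimizer steps), starting at 1.
    """
    interval = config.get("interval", "epoch")
    frequency = config.get("frequency", 1)
    return event == interval and count % frequency == 0


# Hypothetical config: step the scheduler every second epoch.
cfg = {"scheduler": None, "interval": "epoch", "frequency": 2}
due = [epoch for epoch in range(1, 7) if scheduler_should_step(cfg, "epoch", epoch)]
```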

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by logging them with self.log('metric_to_track', metric_val) in your LightningModule.

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

on_train_end()[source]

Called at the end of training before logger experiment is closed.

Return type:

None

Image2Sphere Network

class cryoPARES.models.image2sphere.image2sphere.Image2Sphere(symmetry, lmax=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, label_smoothing=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, num_augmented_copies_per_batch=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, enforce_symmetry=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, encoder=None, use_simCLR=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_temperature=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_loss_weight=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, average_neigs_for_pred=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, example_batch=None)[source]

Bases: Module

Instantiate an Image2Sphere-style network for predicting distributions over SO(3) from a single image.

Parameters:
  • symmetry (str)

  • lmax (int)

  • hp_order (int)

  • label_smoothing (float)

  • num_augmented_copies_per_batch (int | None)

  • enforce_symmetry (bool)

  • encoder (Module | None)

  • use_simCLR (bool)

  • simCLR_temperature (float)

  • simCLR_loss_weight (float)

  • average_neigs_for_pred (bool)

  • example_batch (Dict[str, Any] | None)

cache = Memory(location=/tmp/cryoPARES_cache/Image2Sphere.joblib/joblib)
__init__(symmetry, lmax=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, label_smoothing=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, num_augmented_copies_per_batch=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, enforce_symmetry=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, encoder=None, use_simCLR=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_temperature=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_loss_weight=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, average_neigs_for_pred=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, example_batch=None)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • symmetry (str)

  • lmax (int)

  • hp_order (int)

  • label_smoothing (float)

  • num_augmented_copies_per_batch (int | None)

  • enforce_symmetry (bool)

  • encoder (Module | None)

  • use_simCLR (bool)

  • simCLR_temperature (float)

  • simCLR_loss_weight (float)

  • average_neigs_for_pred (bool)

  • example_batch (Dict[str, Any] | None)

predict_wignerDs(x)[source]
Parameters:

x – image, tensor of shape (B, c, L, L)

Returns:

Flattened SO(3) irreps

from_wignerD_to_topKMats(wD, k)[source]
Parameters:
  • wD – The wignerD matrices

  • k (int) – The number of top-K matrices to report

Returns:

  • rotMat_logits – (BxP) The logits obtained from the wignerD matrices by projecting them to the SO(3) grid.

  • pred_rotmat_id – (BxK) The top-K rotation matrix indices. They refer to the original indices, not the subset selected according to symmetry reduction.

  • pred_rotmat – (BxKx3x3) The top-K rotation matrices. They refer to the original matrices, not the subset selected according to symmetry reduction.
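
A dependency-free sketch of the top-K step only, assuming the logits have already been projected onto a grid of P rotations. The helper `top_k_rotations` and the three-element toy grid are hypothetical. The real method works on batched tensors and also handles symmetry reduction, which is omitted here.

```python
def top_k_rotations(logits, grid, k):
    """Return (ids, mats) for the k highest-scoring grid rotations of one image."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    ids = order[:k]
    return ids, [grid[i] for i in ids]


# Toy grid of three rotation matrices (hypothetical, not a real SO(3) grid).
identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
rot_z_180 = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]]
rot_x_180 = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
grid = [identity, rot_z_180, rot_x_180]

ids, mats = top_k_rotations([0.1, 2.5, 1.3], grid, k=2)
```

The returned ids index into the full grid, which mirrors the statement that indices refer to the original matrices.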

forward_standard(img, top_k)[source]
Parameters:
  • img – float tensor of shape (B, c, L, L)

  • top_k – int, number of top-K elements to return

forward_with_neigs(img, top_k)[source]

compute_probabilities(img, hp_order=None)[source]

simCLR_like_loss(wD, temperature=0.5)[source]

Compute SimCLR-like contrastive loss using in-plane rotation invariant features.

The loss encourages different augmented views of the same particle to have similar representations in the spherical harmonic feature space (which is invariant to in-plane rotations).

Parameters:
  • wD – Wigner-D coefficients of shape (B, 1, D) where B = num_particles * num_augmented_copies_per_batch

  • temperature – Temperature parameter for NT-Xent loss (controls concentration)

Returns:

Scalar contrastive loss value

Implementation details:
  1. Extract spherical harmonic coefficients (m’=0 column) which are invariant to in-plane rotations

  2. Reshape to group augmented copies: (num_particles, num_augmented_copies, feature_dim)

  3. Compute NT-Xent (Normalized Temperature-scaled Cross Entropy) loss

  4. Positive pairs: different augmented views of same particle

  5. Negative pairs: views from different particles
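
Steps 2 to 5 can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the cryoPARES implementation: the function name `nt_xent` is hypothetical, the features are assumed to be already extracted (step 1 is omitted), augmented copies of one particle are assumed to be contiguous, and the loss is averaged over all positive pairs.

```python
import math


def nt_xent(features, num_copies, temperature=0.5):
    """NT-Xent loss over a flat list of feature vectors.

    features[i] belongs to particle i // num_copies (copies are contiguous).
    """
    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    z = [normalize(v) for v in features]
    n = len(z)
    # Cosine similarity scaled by the temperature.
    sim = [[sum(a * b for a, b in zip(z[i], z[j])) / temperature
            for j in range(n)] for i in range(n)]

    losses = []
    for i in range(n):
        # Denominator: every other view, positives and negatives alike.
        log_denom = math.log(sum(math.exp(sim[i][k]) for k in range(n) if k != i))
        for j in range(n):
            if j != i and j // num_copies == i // num_copies:
                losses.append(log_denom - sim[i][j])  # -log softmax of the positive
    return sum(losses) / len(losses)


# Two particles, two augmented copies each (hypothetical 2-D features).
aligned = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]]
mixed = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [0.1, 1.0]]
loss_aligned = nt_xent(aligned, num_copies=2)
loss_mixed = nt_xent(mixed, num_copies=2)
```

In this toy data the loss is lower when the copies of a particle have similar features (`aligned`) than when they do not (`mixed`).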

forward_and_loss(img, gt_rotmat, per_img_weight=None, top_k=1)[source]

Compute the cross-entropy loss against the ground-truth rotation. The correct label is the rotation in the spatial grid nearest to the ground-truth rotation.

Parameters:
  • img – float tensor of shape (B, c, L, L)

  • gt_rotmat – float tensor of valid rotation matrices, of shape (B, 3, 3)

  • per_img_weight – float tensor of shape (B,) with per-image weights for the loss calculation

  • top_k (int) – number of top-K elements to return
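
A minimal sketch of the labelling idea, assuming a single image and a toy grid. All helper names are hypothetical, and per-image weights and label smoothing are left out. It relies on the fact that the rotation angle between R1 and R2 is arccos((trace(R1^T R2) - 1) / 2), so the nearest grid rotation is the one with the largest trace.

```python
import math


def trace_of_product(a, b):
    # trace(a^T b) equals the sum of elementwise products; larger means closer.
    return sum(a[r][c] * b[r][c] for r in range(3) for c in range(3))


def nearest_grid_index(gt_rotmat, grid):
    return max(range(len(grid)), key=lambda i: trace_of_product(gt_rotmat, grid[i]))


def cross_entropy(logits, label):
    # Numerically stable log-sum-exp minus the logit of the correct label.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_norm - logits[label]


def rot_z(deg):
    c, s = math.cos(math.radians(deg)), math.sin(math.radians(deg))
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


# Toy grid of four rotations about z (hypothetical, not a real SO(3) grid).
grid = [rot_z(0), rot_z(90), rot_z(180), rot_z(270)]
label = nearest_grid_index(rot_z(100), grid)
loss = cross_entropy([0.2, 3.0, 0.1, -1.0], label)
```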

cryoPARES.models.image2sphere.image2sphere.create_extraction_mask(lmax, device)[source]

Create a boolean mask to extract the middle columns (m’=0) from flattened Wigner-D matrices. The mask is created once and can be reused for all extractions. It is used to obtain the spherical harmonic coefficients.

Parameters:
  • lmax – Maximum degree l

  • device – String indicating device type (‘cuda’ or ‘cpu’)
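
A plain-Python sketch of such a mask. It assumes a layout that the source does not state: the blocks for l = 0..lmax are concatenated in order, and each (2l+1) x (2l+1) block is flattened row by row. The function name `create_mask` is hypothetical, and the real function returns a tensor on the requested device.

```python
def create_mask(lmax):
    """Boolean mask selecting the m'=0 column of each flattened Wigner-D block."""
    mask = []
    for l in range(lmax + 1):
        size = 2 * l + 1  # the block for degree l is size x size
        for row in range(size):
            for col in range(size):
                mask.append(col == l)  # the middle column corresponds to m' = 0
    return mask


mask = create_mask(lmax=2)
selected = [i for i, keep in enumerate(mask) if keep]
```

Under this layout the mask keeps (lmax + 1)**2 entries, one per spherical harmonic coefficient up to degree lmax.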

cryoPARES.models.image2sphere.image2sphere.extract_sh_coeffs_fast(flat_wigner_d, lmax)[source]

Efficiently extract spherical harmonic coefficients from flattened Wigner-D matrices using cached mask.

cryoPARES.models.image2sphere.image2sphere.plot_so3_distribution(probs, rots, gt_rotation=None, fig=None, ax=None, display_threshold_probability=5e-06, show_color_wheel=True, canonical_rotation=tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))[source]

Taken from https://github.com/google-research/google-research/blob/master/implicit_pdf/evaluation.py

Image Encoders

ResNet Encoder

class cryoPARES.models.image2sphere.imageEncoder.resNet.ResNet(in_channels, resnetName=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, load_imagenetweights=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, **kwargs)[source]

Bases: Module

Parameters:
  • in_channels (int)

  • resnetName (str)

  • load_imagenetweights (bool)

  • out_channels (int)

__init__(in_channels, resnetName=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, load_imagenetweights=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, **kwargs)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • in_channels (int)

  • resnetName (str)

  • load_imagenetweights (bool)

  • out_channels (int)

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

U-Net Encoder

class cryoPARES.models.image2sphere.imageEncoder.unet.ConvolutionalBlock(dimensions, in_channels, out_channels, normalization, kernel_size, activation, preactivation, padding, dilation, dropout, padding_mode='zeros')[source]

Bases: Module

Parameters:
  • dimensions (int)

  • in_channels (int)

  • out_channels (int)

  • normalization (str | None)

  • kernel_size (int)

  • activation (str | None)

  • preactivation (bool)

  • padding (int)

  • dilation (int | None)

  • dropout (float)

  • padding_mode (str)

__init__(dimensions, in_channels, out_channels, normalization, kernel_size, activation, preactivation, padding, dilation, dropout, padding_mode='zeros')[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • dimensions (int)

  • in_channels (int)

  • out_channels (int)

  • normalization (str | None)

  • kernel_size (int)

  • activation (str | None)

  • preactivation (bool)

  • padding (int)

  • dilation (int | None)

  • dropout (float)

  • padding_mode (str)

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static add_if_not_none(module_list, module)[source]

class cryoPARES.models.image2sphere.imageEncoder.unet.UnetEncoder(in_channels, out_channels_first, dimensions, pooling_type, num_encoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]

Bases: Module

Parameters:
  • in_channels (int)

  • out_channels_first (int)

  • dimensions (int)

  • pooling_type (str)

  • num_encoding_blocks (int)

  • normalization (str | None)

  • residual (bool)

  • padding (str | int)

  • padding_mode (str)

  • activation (str | None)

  • initial_dilation (int | None)

  • dropout (float)

__init__(in_channels, out_channels_first, dimensions, pooling_type, num_encoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • in_channels (int)

  • out_channels_first (int)

  • dimensions (int)

  • pooling_type (str)

  • num_encoding_blocks (int)

  • normalization (str | None)

  • residual (bool)

  • padding (str | int)

  • padding_mode (str)

  • activation (str | None)

  • initial_dilation (int | None)

  • dropout (float)

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

property out_channels

class cryoPARES.models.image2sphere.imageEncoder.unet.EncodingBlock(in_channels, out_channels_first, dimensions, normalization, pooling_type, preactivation, is_first_block, residual, kernel_size, padding, padding_mode, activation, dilation, dropout)[source]

Bases: Module

Parameters:
  • in_channels (int)

  • out_channels_first (int)

  • dimensions (int)

  • normalization (str | None)

  • pooling_type (str | None)

  • preactivation (bool)

  • is_first_block (bool)

  • residual (bool)

  • kernel_size (int)

  • padding (int)

  • activation (str | None)

  • dilation (int | None)

  • dropout (float)

__init__(in_channels, out_channels_first, dimensions, normalization, pooling_type, preactivation, is_first_block, residual, kernel_size, padding, padding_mode, activation, dilation, dropout)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • in_channels (int)

  • out_channels_first (int)

  • dimensions (int)

  • normalization (str | None)

  • pooling_type (str | None)

  • preactivation (bool)

  • is_first_block (bool)

  • residual (bool)

  • kernel_size (int)

  • padding (int)

  • activation (str | None)

  • dilation (int | None)

  • dropout (float)

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

property out_channels

class cryoPARES.models.image2sphere.imageEncoder.unet.Decoder(in_channels_skip_connection, dimensions, upsampling_type, num_decoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]

Bases: Module

Parameters:
  • in_channels_skip_connection (int)

  • dimensions (int)

  • upsampling_type (str)

  • num_decoding_blocks (int)

  • normalization (str | None)

  • kernel_size (int)

  • preactivation (bool)

  • residual (bool)

  • padding (str | int)

  • padding_mode (str)

  • activation (str | None)

  • initial_dilation (int | None)

  • dropout (float)

__init__(in_channels_skip_connection, dimensions, upsampling_type, num_decoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • in_channels_skip_connection (int)

  • dimensions (int)

  • upsampling_type (str)

  • num_decoding_blocks (int)

  • normalization (str | None)

  • kernel_size (int)

  • preactivation (bool)

  • residual (bool)

  • padding (str | int)

  • padding_mode (str)

  • activation (str | None)

  • initial_dilation (int | None)

  • dropout (float)

forward(skip_connections, x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class cryoPARES.models.image2sphere.imageEncoder.unet.DecodingBlock(in_channels_skip_connection, dimensions, upsampling_type, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, dilation, dropout)[source]

Bases: Module

Parameters:
  • in_channels_skip_connection (int)

  • dimensions (int)

  • upsampling_type (str)

  • normalization (str | None)

  • kernel_size (int)

  • preactivation (bool)

  • residual (bool)

  • padding (int)

  • padding_mode (str)

  • activation (str | None)

  • dilation (int | None)

__init__(in_channels_skip_connection, dimensions, upsampling_type, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, dilation, dropout)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • in_channels_skip_connection (int)

  • dimensions (int)

  • upsampling_type (str)

  • normalization (str | None)

  • kernel_size (int)

  • preactivation (bool)

  • residual (bool)

  • padding (int)

  • padding_mode (str)

  • activation (str | None)

  • dilation (int | None)

forward(skip_connection, x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

center_crop(skip_connection, x)[source]

Center-crop the skip connection tensor to match the spatial dimensions of x. This version is fully TorchScript compatible by using explicit int types.
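As a rough illustration of center-cropping with explicit integer arithmetic, the sketch below computes the slices that would trim a larger skip connection down to the spatial size of x. The helper name center_crop_slices is hypothetical and not part of cryoPARES; the sketch assumes every spatial dimension of the skip connection is at least as large as that of x.

```python
from typing import Sequence, Tuple


def center_crop_slices(skip_shape: Sequence[int],
                       target_shape: Sequence[int]) -> Tuple[slice, ...]:
    """Return one slice per spatial dimension that center-crops
    skip_shape down to target_shape (hypothetical helper)."""
    slices = []
    for skip_size, target_size in zip(skip_shape, target_shape):
        # Integer offset that leaves (roughly) equal margins on both sides.
        start = (skip_size - target_size) // 2
        slices.append(slice(start, start + target_size))
    return tuple(slices)


# Spatial sizes only: a 68x68 skip connection cropped to match a 64x64 x.
crop = center_crop_slices((68, 68), (64, 64))
print(crop)
```

For a tensor laid out as (batch, channels, H, W), these slices could be applied as skip_connection[(slice(None), slice(None)) + crop], leaving the batch and channel dimensions untouched.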

class cryoPARES.models.image2sphere.imageEncoder.unet.DecodingStage(decoding_block, skip_index)[source]

Bases: Module

__init__(decoding_block, skip_index)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(skip_connections, x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

class cryoPARES.models.image2sphere.imageEncoder.unet.MultiInputSequential(*modules)[source]

Bases: Module

Parameters:

modules (Module)

__init__(*modules)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:

modules (Module)

forward(skip_connections, x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

append(module)[source]
Parameters:

module (Module)

cryoPARES.models.image2sphere.imageEncoder.unet.get_upsampling_layer(upsampling_type)[source]
Return type:

Upsample

Parameters:

upsampling_type (str)

cryoPARES.models.image2sphere.imageEncoder.unet.get_conv_transpose_layer(dimensions, in_channels, out_channels)[source]

cryoPARES.models.image2sphere.imageEncoder.unet.fix_upsampling_type(upsampling_type, dimensions)[source]
Parameters:
  • upsampling_type (str)

  • dimensions (int)

cryoPARES.models.image2sphere.imageEncoder.unet.get_downsampling_layer(dimensions, pooling_type, kernel_size=2)[source]
Return type:

Module

Parameters:
  • dimensions (int)

  • pooling_type (str)

  • kernel_size (int)

class cryoPARES.models.image2sphere.imageEncoder.unet.Unet(in_channels, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels_first=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_decoder_blocks_removed=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, pooling=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, padding=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, activation=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, upsampling_type=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, keep_2d=True, **kwargs)[source]

Bases: Module

Parameters:
  • in_channels (int)

  • n_blocks (int)

  • out_channels (int | None)

  • out_channels_first (int)

  • n_decoder_blocks_removed (int)

  • kernel_size (int)

  • pooling (str)

  • padding (str)

  • activation (str)

  • normalization (str)

  • upsampling_type (str)

  • dropout (float)

__init__(in_channels, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels_first=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_decoder_blocks_removed=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, pooling=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, padding=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, activation=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, upsampling_type=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, keep_2d=True, **kwargs)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • in_channels (int)

  • n_blocks (int)

  • out_channels (int | None)

  • out_channels_first (int)

  • n_decoder_blocks_removed (int)

  • kernel_size (int)

  • pooling (str)

  • padding (str)

  • activation (str)

  • normalization (str)

  • upsampling_type (str)

  • dropout (float)

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

ConvMixer Encoder

class cryoPARES.models.image2sphere.imageEncoder.convMixer.ResidualForConvMixer(fn)[source]

Bases: Module

__init__(fn)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class cryoPARES.models.image2sphere.imageEncoder.convMixer.ConvMixer(in_channels, hidden_dim=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, patch_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, add_stem=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout_rate=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, global_pooling=False, flatten_if_no_global_pooling=False, flatten_start_dim=1, **kwargs)[source]

Bases: Module

Parameters:
  • hidden_dim (int)

  • n_blocks (int)

  • kernel_size (int)

  • patch_size (int)

  • out_channels (int)

  • add_stem (bool)

  • dropout_rate (float)

  • normalization (Literal['Batch'])

__init__(in_channels, hidden_dim=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, patch_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, add_stem=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout_rate=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, global_pooling=False, flatten_if_no_global_pooling=False, flatten_start_dim=1, **kwargs)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • hidden_dim (int)

  • n_blocks (int)

  • kernel_size (int)

  • patch_size (int)

  • out_channels (int)

  • add_stem (bool)

  • dropout_rate (float)

  • normalization (Literal['Batch'])

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Directional Normalizer

class cryoPARES.models.directionalNormalizer.directionalNormalizer.DirectionalPercentileNormalizer(symmetry, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>)[source]

Bases: Module

Neural network module for computing directional percentiles on S2 space.

This module normalizes prediction scores based on their orientation in S2 space, addressing the issue where prediction quality can vary by viewing direction. It can be attached to an existing neural network that predicts SO(3) indices.

The normalization is based on computing per-cone statistics (median and MAD) and converting raw scores to Z-scores, making scores comparable across different orientations regardless of inherent direction-specific biases.

Important assumptions:

  1. SO(3) indices are organized as consecutive in-plane rotations for each cone

  2. The formula cone_index = so3_index // n_psi is valid for the grid structure

  3. The in-plane rotation dimension has consistent size (n_psi) across all cones
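The per-cone normalization described above can be sketched in plain Python. This is a minimal illustration rather than the module's implementation: the function names, the eps guard against a zero MAD, and the absence of any MAD scaling constant are all assumptions.

```python
from collections import defaultdict
from statistics import median


def fit_cone_stats(cone_ids, scores):
    """Return {cone_id: (median, MAD)} computed from the scores of each cone."""
    by_cone = defaultdict(list)
    for cone, score in zip(cone_ids, scores):
        by_cone[cone].append(score)
    stats = {}
    for cone, values in by_cone.items():
        med = median(values)
        # Median absolute deviation: a robust estimate of spread.
        mad = median(abs(v - med) for v in values)
        stats[cone] = (med, mad)
    return stats


def normalize(cone_ids, scores, stats, eps=1e-6):
    """Convert raw scores to robust Z-scores using each cone's median and MAD."""
    return [(s - stats[c][0]) / (stats[c][1] + eps)
            for c, s in zip(cone_ids, scores)]


# Cone 1 yields systematically higher raw scores than cone 0.
cone_ids = [0, 0, 0, 1, 1, 1]
scores = [0.2, 0.4, 0.6, 0.7, 0.8, 0.9]
stats = fit_cone_stats(cone_ids, scores)
z = normalize(cone_ids, scores, stats)
```

In this toy data the best particle of each cone (raw scores 0.6 and 0.9) ends up with a Z-score close to 1, which is the sense in which normalized scores become comparable across viewing directions.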

Parameters:
  • symmetry (str)

  • hp_order (int)

__init__(symmetry, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • symmetry (str)

  • hp_order (int)

so3_to_cone_ids(so3_indices)[source]

Convert SO(3) indices to cone indices using integer division.

This mapping assumes the SO(3) grid structure from so3_healpix_grid_equiangular where the full orientation space is organized as:

  • n_cones cone directions (alpha, beta pairs)

  • For each cone, n_psi in-plane rotations (gamma angles)

  • The SO(3) index increases sequentially, with all in-plane rotations for a cone stored consecutively before moving to the next cone
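Under the layout described above, the mapping reduces to an integer division. The sketch below uses plain Python lists instead of tensors, and the function name and the value of n_psi are chosen only for illustration.

```python
def cone_ids_from_so3(so3_indices, n_psi):
    """Map flat SO(3) grid indices to cone indices, assuming the n_psi
    in-plane rotations of each cone are stored consecutively."""
    return [index // n_psi for index in so3_indices]


n_psi = 4  # illustrative number of in-plane rotations per cone
cone_ids = cone_ids_from_so3([0, 3, 4, 7, 8], n_psi)
print(cone_ids)  # [0, 0, 1, 1, 2]
```

Indices 0 to 3 fall in cone 0, indices 4 to 7 in cone 1, and so on. Under the same assumption, the in-plane rotation index within a cone would be index % n_psi.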

Parameters:

so3_indices (Tensor) – Tensor of SO(3) indices

Return type:

Tensor

Returns:

Tensor of cone indices

rotmats_to_cone_id(rotmats)[source]

Convert rotation matrices to cone indices.

Parameters:

rotmats (Tensor) – Tensor of rotation matrices

Return type:

Tensor

Returns:

Tensor of cone indices

fit(pred_rotmats, scores, gt_rotmats=None, good_particles_percentile=95.0, min_particles_per_cone=10)[source]

Estimate normalization parameters for each cone from a reference dataset.

This method analyzes scores grouped by orientation (cone) to compute robust statistics that will be used for normalization during inference.

When ground truth is available, it uses particles with correct orientations. When ground truth is unavailable, it uses top-scoring particles, assuming they are more likely to be correct.

Parameters:
  • pred_rotmats (Tensor) – Predicted SO(3) rotmats for particles. Shape Bx3x3

  • scores (Tensor) – Prediction scores for particles

  • gt_rotmats (Optional[Tensor]) – Ground truth SO(3) rotmats (if available for training)

  • good_particles_percentile (float) – Percentile of particles to use when no ground truth is available. Higher values mean that only the top-scoring particles are considered.

  • min_particles_per_cone (int) – Minimum number of particles required for reliable statistics. Cones with fewer particles fall back to global statistics.
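The roles of good_particles_percentile and min_particles_per_cone can be illustrated with the following sketch. It is not the library's code: the helper names, the nearest-rank percentile rule, and the use of plain lists are assumptions made for the example.

```python
from collections import defaultdict
from statistics import median


def median_and_mad(values):
    """Return the median and the median absolute deviation of values."""
    med = median(values)
    return med, median(abs(v - med) for v in values)


def select_top_scoring(scores, percentile=95.0):
    """Return indices of particles scoring at or above the given percentile
    (nearest-rank rule; the real method may interpolate differently)."""
    ranked = sorted(scores)
    k = min(len(ranked) - 1, int(len(ranked) * percentile / 100.0))
    threshold = ranked[k]
    return [i for i, s in enumerate(scores) if s >= threshold]


def fit_with_fallback(cone_ids, scores, min_particles_per_cone=10):
    """Per-cone (median, MAD); sparsely populated cones get global statistics."""
    global_stats = median_and_mad(scores)
    by_cone = defaultdict(list)
    for cone, score in zip(cone_ids, scores):
        by_cone[cone].append(score)
    return {
        cone: median_and_mad(vals) if len(vals) >= min_particles_per_cone
        else global_stats
        for cone, vals in by_cone.items()
    }


# Without ground truth, keep only the top half of 20 particles.
kept = select_top_scoring([float(i) for i in range(20)], percentile=50.0)

# Cone 1 has a single particle, so it falls back to the global statistics.
stats = fit_with_fallback([0, 0, 0, 1], [1.0, 2.0, 3.0, 10.0],
                          min_particles_per_cone=3)
```

Here cone 0 keeps its own statistics (median 2.0, MAD 1.0), while cone 1 receives the global ones (median 2.5, MAD 1.0).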

Return type:

None

forward(pred_rotmats, scores)[source]

Apply directional normalization to scores.

Parameters:
  • pred_rotmats (Tensor) – SO(3) rotmats

  • scores (Tensor) – Raw prediction scores

Return type:

Tensor

Returns:

Normalized scores (Z-scores)

save(path)[source]

Save normalization parameters to a file.

Parameters:

path (str) – File path to save parameters

Return type:

None

classmethod load(path, device='cpu')[source]

Load normalization parameters from a file.

Parameters:
  • path (str) – File path to load parameters from

  • device (str) – Device to load parameters to

Return type:

DirectionalPercentileNormalizer

Returns:

Loaded DirectionalPercentileNormalizer