Models API
PyTorch Lightning Model
- class cryoPARES.models.model.PlModel(lr, symmetry, num_augmented_copies_per_batch, top_k_poses_nnet, so3model=None)[source]
Bases: RotationPredictionMixin, LightningModule
- Parameters:
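The constructor arguments follow the signature above. A minimal instantiation sketch; the values below are illustrative assumptions, not documented defaults:

from cryoPARES.models.model import PlModel

model = PlModel(
    lr=1e-4,                           # optimizer learning rate
    symmetry="C1",                     # point-group symmetry of the particle
    num_augmented_copies_per_batch=2,  # augmented views of each particle per batch
    top_k_poses_nnet=1,                # candidate poses reported by the network
)                                      # so3model=None is assumed to build a default SO(3) head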
- training_step(batch, batch_idx)[source]
Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()

    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch, batch_idx)[source]
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate, you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- forward(imgs, batch_idx, dataloader_idx=0, top_k=None)[source]
Same as torch.nn.Module.forward().
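A hedged inference sketch: the (B, c, L, L) image shape is taken from the Image2Sphere section below, and top_k is assumed to cap the number of candidate poses returned.

import torch

model.eval()
with torch.no_grad():
    imgs = torch.randn(8, 1, 128, 128)       # (B, c, L, L); size is illustrative
    out = model(imgs, batch_idx=0, top_k=1)  # return structure is not documented here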
- optimizer_step_v1(epoch, batch_idx, optimizer, optimizer_idx=0, optimizer_closure=None, on_tpu=False, using_lbfgs=False)[source]
- configure_optimizers()[source]
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these options:
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning.
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name.
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
Lightning calls .backward() and .step() automatically in case of automatic optimization.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
If you need to control how often the optimizer steps, override the optimizer_step() hook.
Image2Sphere Network
- class cryoPARES.models.image2sphere.image2sphere.Image2Sphere(symmetry, lmax=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, label_smoothing=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, num_augmented_copies_per_batch=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, enforce_symmetry=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, encoder=None, use_simCLR=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_temperature=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_loss_weight=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, average_neigs_for_pred=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, example_batch=None)[source]
Bases: Module
Instantiate an Image2Sphere-style network for predicting distributions over SO(3) from a single image.
- Parameters:
- cache = Memory(location=/tmp/cryoPARES_cache/Image2Sphere.joblib/joblib)
- __init__(symmetry, lmax=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, label_smoothing=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, num_augmented_copies_per_batch=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, enforce_symmetry=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, encoder=None, use_simCLR=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_temperature=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, simCLR_loss_weight=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, average_neigs_for_pred=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, example_batch=None)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
- predict_wignerDs(x)[source]
- Parameters:
x – image, tensor of shape (B, c, L, L)
- Returns:
flattened SO(3) irreps
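A small usage sketch, assuming net is an Image2Sphere instance; under a standard flattening, the irrep dimension is the sum over l <= lmax of (2l+1)^2, though the exact layout is an assumption:

import torch

imgs = torch.randn(8, 1, 128, 128)  # (B, c, L, L)
wD = net.predict_wignerDs(imgs)     # flattened SO(3) irreps, one row per image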
- from_wignerD_to_topKMats(wD, k)[source]
- Parameters:
wD – The Wigner-D matrices
k (int) – The number of top-K matrices to report
- Returns:
rotMat_logits: (BxP) The logits obtained from the Wigner-D matrices by projecting them onto the SO(3) grid
pred_rotmat_id: (BxK) The top-K rotation matrix idxs. They refer to the original idxs, not the subset selected according to symmetry reduction
pred_rotmat: (BxKx3x3) The top-K rotation matrices. They refer to the original matrices, not the subset selected according to symmetry reduction
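A usage sketch built directly on the return description above, assuming the three values come back as a tuple in the order listed:

logits, pred_ids, pred_rotmats = net.from_wignerD_to_topKMats(wD, k=5)
# logits:       (B, P)       logits over the SO(3) grid positions
# pred_ids:     (B, K)       indices into the original (pre-symmetry-reduction) grid
# pred_rotmats: (B, K, 3, 3) the corresponding rotation matrices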
- simCLR_like_loss(wD, temperature=0.5)[source]
Compute SimCLR-like contrastive loss using in-plane rotation invariant features.
The loss encourages different augmented views of the same particle to have similar representations in the spherical harmonic feature space (which is invariant to in-plane rotations).
- Parameters:
wD – Wigner-D coefficients of shape (B, 1, D) where B = num_particles * num_augmented_copies_per_batch
temperature – Temperature parameter for NT-Xent loss (controls concentration)
- Returns:
Scalar contrastive loss value
- Implementation details:
Extract the spherical harmonic coefficients (the m'=0 column), which are invariant to in-plane rotations
Reshape to group augmented copies: (num_particles, num_augmented_copies, feature_dim)
Compute the NT-Xent (Normalized Temperature-scaled Cross Entropy) loss; see the sketch below
Positive pairs: different augmented views of the same particle
Negative pairs: views from different particles
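As a rough illustration of the NT-Xent step, here is a generic two-view sketch, not the package's exact implementation; z1 and z2 stand for the in-plane-invariant features of two augmented copies of the same particles:

import torch
import torch.nn.functional as F

def nt_xent_sketch(z1, z2, temperature=0.5):
    """Minimal NT-Xent over two views z1, z2, each of shape (N, D)."""
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D), unit-norm rows
    sim = z @ z.t() / temperature                       # scaled cosine similarities
    sim.fill_diagonal_(float("-inf"))                   # a view is never its own positive
    n = z1.shape[0]
    # row i pairs with row i + n (and vice versa): the other view of the same particle
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)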
- forward_and_loss(img, gt_rotmat, per_img_weight=None, top_k=1)[source]
Compute the cross-entropy loss using the ground truth rotation; the correct label is the rotation in the spatial grid nearest to the ground truth rotation.
- img:
float tensor of shape (B, c, L, L)
- gt_rotmat:
float tensor of valid rotation matrices, tensor of shape (B, 3, 3)
- per_img_weight:
float tensor of shape (B,) with per-image weights for the loss calculation
- top_k:
int, number of top-K elements to return
- Parameters:
top_k (int)
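A hedged call sketch; the identity matrices are dummy ground truth, and the return structure is not documented on this page, so it is left unpacked:

import torch

img = torch.randn(8, 1, 128, 128)         # (B, c, L, L)
gt_rotmat = torch.eye(3).expand(8, 3, 3)  # (B, 3, 3) dummy rotations
out = net.forward_and_loss(img, gt_rotmat, per_img_weight=None, top_k=1)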
- cryoPARES.models.image2sphere.image2sphere.create_extraction_mask(lmax, device)[source]
Create a boolean mask to extract the middle columns (m'=0) from flattened Wigner-D matrices. This mask is created once and can be reused for all extractions. Used to get the spherical harmonics.
- Parameters:
lmax – Maximum degree l
device – String indicating the device type ('cuda' or 'cpu')
- cryoPARES.models.image2sphere.image2sphere.extract_sh_coeffs_fast(flat_wigner_d, lmax)[source]
Efficiently extract spherical harmonic coefficients from flattened Wigner-D matrices using cached mask.
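A minimal sketch of the m'=0 extraction idea, assuming the flattened layout concatenates row-major (2l+1) x (2l+1) blocks for l = 0..lmax; the real implementation may differ:

import torch

def extraction_mask_sketch(lmax, device):
    total = sum((2 * l + 1) ** 2 for l in range(lmax + 1))  # flattened Wigner-D length
    mask = torch.zeros(total, dtype=torch.bool, device=device)
    offset = 0
    for l in range(lmax + 1):
        size = 2 * l + 1
        rows = torch.arange(size, device=device)
        mask[offset + rows * size + l] = True  # middle column (m' = 0) of each block
        offset += size * size
    return mask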
- cryoPARES.models.image2sphere.image2sphere.plot_so3_distribution(probs, rots, gt_rotation=None, fig=None, ax=None, display_threshold_probability=5e-06, show_color_wheel=True, canonical_rotation=tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))[source]
Taken from https://github.com/google-research/google-research/blob/master/implicit_pdf/evaluation.py
Image Encoders
ResNet Encoder
- class cryoPARES.models.image2sphere.imageEncoder.resNet.ResNet(in_channels, resnetName=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, load_imagenetweights=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, **kwargs)[source]
Bases: Module
- __init__(in_channels, resnetName=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, load_imagenetweights=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, **kwargs)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
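A hedged construction sketch; the arguments shown as CONFIG_PARAM above fall back to config-injected defaults, so only in_channels is passed explicitly:

import torch
from cryoPARES.models.image2sphere.imageEncoder.resNet import ResNet

encoder = ResNet(in_channels=1)               # remaining arguments use config defaults
feats = encoder(torch.randn(4, 1, 128, 128))  # input size is illustrative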
U-Net Encoder
- class cryoPARES.models.image2sphere.imageEncoder.unet.ConvolutionalBlock(dimensions, in_channels, out_channels, normalization, kernel_size, activation, preactivation, padding, dilation, dropout, padding_mode='zeros')[source]
Bases: Module
- Parameters:
- __init__(dimensions, in_channels, out_channels, normalization, kernel_size, activation, preactivation, padding, dilation, dropout, padding_mode='zeros')[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cryoPARES.models.image2sphere.imageEncoder.unet.UnetEncoder(in_channels, out_channels_first, dimensions, pooling_type, num_encoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]
Bases: Module
- Parameters:
- __init__(in_channels, out_channels_first, dimensions, pooling_type, num_encoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property out_channels
- class cryoPARES.models.image2sphere.imageEncoder.unet.EncodingBlock(in_channels, out_channels_first, dimensions, normalization, pooling_type, preactivation, is_first_block, residual, kernel_size, padding, padding_mode, activation, dilation, dropout)[source]
Bases: Module
- Parameters:
- __init__(in_channels, out_channels_first, dimensions, normalization, pooling_type, preactivation, is_first_block, residual, kernel_size, padding, padding_mode, activation, dilation, dropout)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property out_channels
- class cryoPARES.models.image2sphere.imageEncoder.unet.Decoder(in_channels_skip_connection, dimensions, upsampling_type, num_decoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]
Bases: Module
- Parameters:
- __init__(in_channels_skip_connection, dimensions, upsampling_type, num_decoding_blocks, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, initial_dilation, dropout)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
- forward(skip_connections, x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cryoPARES.models.image2sphere.imageEncoder.unet.DecodingBlock(in_channels_skip_connection, dimensions, upsampling_type, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, dilation, dropout)[source]
Bases: Module
- Parameters:
- __init__(in_channels_skip_connection, dimensions, upsampling_type, normalization, kernel_size, preactivation, residual, padding, padding_mode, activation, dilation, dropout)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(skip_connection, x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cryoPARES.models.image2sphere.imageEncoder.unet.DecodingStage(decoding_block, skip_index)[source]
Bases: Module
- __init__(decoding_block, skip_index)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(skip_connections, x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cryoPARES.models.image2sphere.imageEncoder.unet.MultiInputSequential(*modules)[source]
Bases: Module
- Parameters:
modules (Module)
- __init__(*modules)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
modules (Module)
- forward(skip_connections, x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- cryoPARES.models.image2sphere.imageEncoder.unet.get_conv_transpose_layer(dimensions, in_channels, out_channels)[source]
- cryoPARES.models.image2sphere.imageEncoder.unet.fix_upsampling_type(upsampling_type, dimensions)[source]
- cryoPARES.models.image2sphere.imageEncoder.unet.get_downsampling_layer(dimensions, pooling_type, kernel_size=2)[source]
- class cryoPARES.models.image2sphere.imageEncoder.unet.Unet(in_channels, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels_first=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_decoder_blocks_removed=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, pooling=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, padding=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, activation=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, upsampling_type=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, keep_2d=True, **kwargs)[source]
Bases: Module
- Parameters:
- __init__(in_channels, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels_first=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_decoder_blocks_removed=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, pooling=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, padding=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, activation=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, upsampling_type=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, keep_2d=True, **kwargs)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
ConvMixer Encoder
- class cryoPARES.models.image2sphere.imageEncoder.convMixer.ResidualForConvMixer(fn)[source]
Bases: Module
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cryoPARES.models.image2sphere.imageEncoder.convMixer.ConvMixer(in_channels, hidden_dim=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, patch_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, add_stem=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout_rate=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, global_pooling=False, flatten_if_no_global_pooling=False, flatten_start_dim=1, **kwargs)[source]
Bases: Module
- Parameters:
- __init__(in_channels, hidden_dim=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, n_blocks=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, kernel_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, patch_size=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, out_channels=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, add_stem=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, dropout_rate=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, global_pooling=False, flatten_if_no_global_pooling=False, flatten_start_dim=1, **kwargs)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Directional Normalizer
- class cryoPARES.models.directionalNormalizer.directionalNormalizer.DirectionalPercentileNormalizer(symmetry, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>)[source]
Bases: Module
Neural network module for computing directional percentiles on S2 space.
This module normalizes prediction scores based on their orientation in S2 space, addressing the issue where prediction quality can vary by viewing direction. It can be attached to an existing neural network that predicts SO(3) indices.
The normalization is based on computing per-cone statistics (median and MAD) and converting raw scores to Z-scores, making scores comparable across different orientations regardless of inherent direction-specific biases.
Important assumptions:
1. SO(3) indices are organized as consecutive in-plane rotations for each cone
2. The formula cone_index = so3_index // n_psi is valid for the grid structure
3. The in-plane rotation dimension has a consistent size (n_psi) across all cones
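A hedged illustration of the per-cone Z-scoring described above; the names and the MAD-based denominator are assumptions consistent with the median/MAD statistics mentioned:

import torch

def zscore_per_cone_sketch(scores, cone_ids, medians, mads, eps=1e-8):
    # medians / mads: per-cone statistics estimated from a reference dataset
    return (scores - medians[cone_ids]) / (mads[cone_ids] + eps)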
- __init__(symmetry, hp_order=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- so3_to_cone_ids(so3_indices)[source]
Convert SO(3) indices to cone indices using integer division.
This mapping assumes the SO(3) grid structure from so3_healpix_grid_equiangular where the full orientation space is organized as:
n_cones cone directions (alpha, beta pairs)
For each cone, n_psi in-plane rotations (gamma angles)
The SO(3) index increases sequentially, with all in-plane rotations for a cone stored consecutively before moving to the next cone
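The documented mapping, written out (n_psi is the number of in-plane rotations per cone):

cone_ids = so3_indices // n_psi  # integer division collapses the in-plane dimension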
- fit(pred_rotmats, scores, gt_rotmats=None, good_particles_percentile=95.0, min_particles_per_cone=10)[source]
Estimate normalization parameters for each cone from a reference dataset.
This method analyzes scores grouped by orientation (cone) to compute robust statistics that will be used for normalization during inference.
When ground truth is available, it uses particles with correct orientations. When ground truth is unavailable, it uses top-scoring particles, assuming they are more likely to be correct.
- Parameters:
pred_rotmats (Tensor) – Predicted SO(3) rotmats for particles. Shape Bx3x3
scores (Tensor) – Prediction scores for particles
gt_rotmats (Optional[Tensor]) – Ground truth SO(3) rotmats (if available for training)
good_particles_percentile (float) – Percentile of particles to use when no ground truth is available. Higher values mean only considering top-scored particles
min_particles_per_cone (int) – Minimum number of particles required for reliable statistics. Cones with fewer particles will use global statistics
- Return type:
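A hedged end-to-end sketch; only the constructor and fit() appear on this page, so the argument values and surrounding code are assumptions:

from cryoPARES.models.directionalNormalizer.directionalNormalizer import (
    DirectionalPercentileNormalizer,
)

normalizer = DirectionalPercentileNormalizer(symmetry="C1")
normalizer.fit(pred_rotmats, scores)  # estimates per-cone median/MAD statistics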