Training Guide

This guide provides detailed instructions for training CryoPARES models, including best practices for avoiding overfitting/underfitting and monitoring training progress.

Overview

Training in CryoPARES takes a set of pre-aligned particles (from RELION, cryoSPARC, or other alignment tools) and trains a deep learning model that predicts particle orientations. The training process produces a checkpoint directory (named version_0, version_1, etc.) containing the trained model weights. Training an accurate model is essential for good inference results: the quality of the trained model directly determines the accuracy of pose predictions on new data.


Quick Start

Basic training command:

python -m cryopares_train \
    --symmetry C1 \
    --particles_star_fname /path/to/aligned_particles.star \
    --particles_dir /path/to/particles \
    --train_save_dir /path/to/output \
    --n_epochs 10 \
    --batch_size 32

Important: Before training, increase the file descriptor limit:

ulimit -n 65536

If you cannot increase it, group your particle images into a small number of stack files (.mrcs) so that fewer files need to be open at the same time.
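If you launch training from a Python wrapper script, you can inspect (and, within the hard limit, raise) the same limit before starting. This is a minimal sketch using the standard-library resource module, which is available on Linux and macOS only; like ulimit -n, the change applies to the current process and the processes it starts. The helper names are hypothetical and not part of CryoPARES.

```python
import resource  # POSIX-only standard-library module (Linux/macOS)

RECOMMENDED_NOFILE = 65536  # value suggested by `ulimit -n 65536`


def check_open_file_limit(recommended=RECOMMENDED_NOFILE):
    """Return (soft, hard, sufficient) for the per-process open-file limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # RLIM_INFINITY means the limit is unbounded
    sufficient = soft == resource.RLIM_INFINITY or soft >= recommended
    return soft, hard, sufficient


def raise_soft_limit(target=RECOMMENDED_NOFILE):
    """Raise the soft limit towards `target` without exceeding the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if soft != resource.RLIM_INFINITY and soft < new_soft:
        try:
            # Unprivileged processes may raise the soft limit up to the hard limit
            resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
        except (ValueError, OSError):
            pass  # some systems cap the soft limit below the reported hard limit
    return resource.getrlimit(resource.RLIMIT_NOFILE)[0]


if __name__ == "__main__":
    raise_soft_limit()
    soft, hard, sufficient = check_open_file_limit()
    print(f"open-file limit: soft={soft}, hard={hard}, sufficient={sufficient}")
```

If sufficient is still False after raising, the hard limit itself is too low and must be changed by an administrator, or the particles should be regrouped into fewer stacks.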

Training Parameters

Essential Parameters

Parameter               Type  Default   Description
--symmetry              str   Required  Point group symmetry (C1, D7, T, O, I, etc.)
--particles_star_fname  str   Required  Path to pre-aligned RELION .star file
--train_save_dir        str   Required  Output directory for checkpoints and logs
--n_epochs              int   100       Number of training epochs
--batch_size            int   32        Number of particles per batch

Data Configuration Parameters

Override via --config flag with dot notation:

--config datamanager.particlesDataset.sampling_rate_angs_for_nnet=2.0 \
         datamanager.particlesDataset.image_size_px_for_nnet=128
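When scripting several runs, it can be convenient to build the command line from a dictionary of overrides. The sketch below assumes, as in the example above, that --config accepts several space-separated key=value pairs; config_overrides is a hypothetical helper, and the paths are placeholders.

```python
import shlex


def config_overrides(overrides):
    """Turn {"dotted.key": value} into ["--config", "dotted.key=value", ...]."""
    if not overrides:
        return []
    return ["--config"] + [f"{key}={value}" for key, value in overrides.items()]


cmd = [
    "python", "-m", "cryopares_train",
    "--symmetry", "C1",
    "--particles_star_fname", "/path/to/aligned_particles.star",
    "--particles_dir", "/path/to/particles",
    "--train_save_dir", "/path/to/output",
] + config_overrides({
    "datamanager.particlesDataset.sampling_rate_angs_for_nnet": 2.0,
    "datamanager.particlesDataset.image_size_px_for_nnet": 128,
})

# Print a copy-pasteable shell command; pass `cmd` to subprocess.run(cmd, check=True) to launch it
print(shlex.join(cmd))
```

Keeping overrides in a dictionary makes it easy to record exactly which values each run used.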

Key data parameters:

  • datamanager.particlesDataset.sampling_rate_angs_for_nnet (float)

    • Target sampling rate in Ångströms/pixel

    • Images are rescaled to this value before training

    • Lower values = higher resolution but more memory/compute

    • Downsampling helps to reduce noise

  • datamanager.particlesDataset.image_size_px_for_nnet (int)

    • Final image size in pixels after rescaling

    • Images are cropped/padded to this size

    • Must be large enough to contain the particle after rescaling

      • But we recommend using tight boxes

  • datamanager.particlesDataset.mask_radius_angs (float, optional)

    • Radius of circular mask in Ångströms

    • If not set, uses half the box size

    • Helps the network focus on the particle

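    • Note that image_size_px_for_nnet (a box size in pixels) and mask_radius_angs (a radius in Ångströms) are different quantities, but they should be chosen together: convert both to the same units and make sure the mask fits inside the box.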
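Converting everything to the same units makes the relationship concrete. This sketch assumes that "half the box size" means half of the box edge length (image_size_px_for_nnet × sampling_rate_angs_for_nnet); the function name is hypothetical.

```python
def box_and_mask_geometry(image_size_px, sampling_rate_angs, mask_radius_angs=None):
    """Express the network box and the mask radius in both Å and pixels."""
    box_angs = image_size_px * sampling_rate_angs
    if mask_radius_angs is None:
        mask_radius_angs = box_angs / 2  # assumed default: half the box edge
    mask_radius_px = mask_radius_angs / sampling_rate_angs
    return {
        "box_angs": box_angs,
        "mask_radius_px": mask_radius_px,
        "mask_fits_in_box": 2 * mask_radius_px <= image_size_px,
    }


# 128 px at 2.0 Å/px gives a 256 Å box; a 100 Å mask radius is 50 px
print(box_and_mask_geometry(128, 2.0, mask_radius_angs=100))
# A 150 Å radius would need a box of at least 150 px at this sampling rate
print(box_and_mask_geometry(128, 2.0, mask_radius_angs=150))
```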

Optimizer Configuration

--config train.learning_rate=1e-3 \
         train.weight_decay=1e-5 \
         train.accumulate_grad_batches=16

Key optimizer parameters:

  • train.learning_rate (float, default: 1e-3)

    • Step size used for each gradient-descent update of the model weights

    • Higher learning rate = faster training but risk of overshooting optimal weights

    • Lower learning rate = more stable training but slower convergence

    • Good values: 1e-4 to 1e-3 depending on your data and model complexity

    • Start with default (1e-3) and adjust based on training behavior

    • Automatically multiplied by 0.5 when validation loss plateaus (see train.patient_reduce_lr_plateau_n_epochs below)

  • train.weight_decay (float, default: 1e-5)

    • L2 regularization coefficient

    • Increase it to reduce overfitting

  • train.accumulate_grad_batches (int)

    • Number of batches to accumulate gradients over

    • Simulates larger batch sizes: effective_batch_size = batch_size × accumulate_grad_batches

    • Useful when GPU memory is limited

  • train.patient_reduce_lr_plateau_n_epochs (int, default: 3)

    • Patience for ReduceLROnPlateau scheduler

    • LR is multiplied by 0.5 if val_loss doesn’t improve for this many epochs
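The schedule can be pictured with a simplified, pure-Python version of the plateau logic. The sketch follows the semantics of PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau, in which the learning rate is cut once the number of epochs without improvement exceeds the patience, but it ignores that scheduler's improvement threshold and cooldown options. That CryoPARES uses exactly these semantics is an assumption.

```python
def plateau_schedule(val_losses, lr=1e-3, factor=0.5, patience=3):
    """Learning rate in effect after each epoch (simplified ReduceLROnPlateau)."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:
            # New best validation loss: reset the counter
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                # Too many epochs without improvement: shrink the LR and restart counting
                lr *= factor
                bad_epochs = 0
        lrs.append(lr)
    return lrs


print(plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]))
```

With patience=3, the learning rate is halved on the fourth consecutive epoch without a new best val_loss.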

Model Architecture Parameters

Example:

--config models.image2sphere.lmax=8 models.image2sphere.so3components.i2sprojector.sphere_fdim=128 

Key architecture parameters:

  • models.image2sphere.lmax (int, default: 12)

    • Maximum spherical harmonic degree for SO(3) representation

    • Higher values = more expressive model but slower training and more memory usage

    • Typical values: 8, 10, 12

  • models.image2sphere.label_smoothing (float, default: 0.05)

    • Label smoothing factor for loss function to prevent overconfidence

    • Range: 0.0 (no smoothing) to 0.2 (strong smoothing)

    • Helps with generalization

  • models.image2sphere.so3components.i2sprojector.sphere_fdim (int, default: 512)

    • Feature dimension for spherical representation in the image-to-sphere projector

    • Higher values = more capacity but slower training

    • Typical values: 256, 512, 1024

  • models.image2sphere.so3components.i2sprojector.rand_fraction_points_to_project (float, default: 0.5)

    • Fraction of points to randomly sample for projection (reduces computation)

    • Acts like a form of dropout

    • Range: 0.1 to 1.0 (1.0 = use all points)

    • Lower values = faster and more robust to overfitting, but potentially less accurate

  • models.image2sphere.so3components.s2conv.f_out (int, default: 64)

    • Number of output features from S2 (sphere) convolution

    • Higher values = more capacity but slower training and potentially more overfitting

    • Typical values: 32, 64, 128

  • models.image2sphere.so3components.i2sprojector.hp_order (int, default: 3)

    • HEALPix order for image-to-sphere projector grid resolution

    • Higher values = finer resolution but more computation and potentially more overfitting

    • Each increment roughly doubles resolution. Going beyond 4 is not advisable

  • models.image2sphere.so3components.s2conv.hp_order (int, default: 4)

    • HEALPix order for S2 convolution grid resolution

    • Controls the resolution of spherical convolution

    • Higher values = finer resolution but more computation and potentially more overfitting

    • Each increment roughly doubles resolution. Going beyond 4 is not advisable

  • models.image2sphere.so3components.so3outputgrid.hp_order (int, default: 4)

    • HEALPix order for SO(3) output grid resolution

    • Affects the final orientation prediction granularity

    • Higher values = finer resolution but more computation and potentially more overfitting

    • Each increment roughly doubles resolution. Going beyond 4 is not advisable
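The "roughly doubles resolution" rule follows from how HEALPix grids are built: order k uses nside = 2^k and 12·nside² equal-area pixels on the sphere, so the typical pixel spacing halves with every increment while the pixel count, and with it the cost, grows fourfold. The sketch below only covers the sphere (S²) grid; an SO(3) grid also samples the in-plane angle, so its point count grows even faster. Taking the square root of the pixel area is a common rough estimate of spacing, not an exact resolution.

```python
import math


def healpix_npix(order):
    """Number of HEALPix pixels on the sphere: 12 * nside**2 with nside = 2**order."""
    nside = 2 ** order
    return 12 * nside ** 2


def approx_spacing_degs(order):
    """Approximate pixel spacing: square root of the solid angle of one pixel."""
    pixel_area_sr = 4 * math.pi / healpix_npix(order)
    return math.degrees(math.sqrt(pixel_area_sr))


for order in range(2, 6):
    print(f"order {order}: {healpix_npix(order):6d} pixels, ~{approx_spacing_degs(order):.2f} deg spacing")
```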

Data Augmentation Parameters

Data augmentation is enabled by default and helps the model generalize by creating variations of training images. CryoPARES randomly applies multiple augmentation operations to each particle image. By default, each particle appears in a batch as several copies, each augmented differently (see datamanager.num_augmented_copies_per_batch below).

View current augmentation settings:

cryopares_train --show-config | grep -A 20 "augmenter:"

Key augmentation parameters:

  • datamanager.augment_train (bool, default: True)

    • Enable/disable data augmentation for training

    • Keep enabled for better generalization

  • datamanager.num_augmented_copies_per_batch (int, default: 4)

    • Number of augmented copies per particle in each batch

    • Each copy undergoes different random augmentations

    • Batch size must be divisible by this value

    • Higher values improve robustness but increase computation

  • datamanager.augmenter.prob_augment_each_image (float, default: 0.95)

    • Probability of applying augmentation to each image

    • Range: 0.0 (no augmentation) to 1.0 (always augment)

  • datamanager.augmenter.min_n_augm_per_img (int, default: 1)

    • Minimum number of augmentation operations to apply per image

  • datamanager.augmenter.max_n_augm_per_img (int, default: 8)

    • Maximum number of augmentation operations to apply per image
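Because each particle appears num_augmented_copies_per_batch times, a batch contains fewer distinct particles than --batch_size suggests. Assuming all copies of a particle share the same batch, as described above, the arithmetic is a single division; the helper is hypothetical.

```python
def distinct_particles_per_batch(batch_size=32, num_augmented_copies_per_batch=4):
    """Distinct particles in one batch when each particle appears as several augmented copies."""
    if batch_size % num_augmented_copies_per_batch != 0:
        raise ValueError("batch_size must be divisible by num_augmented_copies_per_batch")
    return batch_size // num_augmented_copies_per_batch


print(distinct_particles_per_batch())       # defaults: 32 images = 8 particles x 4 copies
print(distinct_particles_per_batch(32, 8))  # 8 copies: only 4 distinct particles per batch
```

To keep the same number of distinct particles per batch when raising the copies from 4 to 8, you would also need to double --batch_size.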

Available augmentation operations (with default probabilities):

  • Gaussian noise (operations.randomGaussNoise.p=0.1)

    • Adds random Gaussian noise to simulate imaging variations

  • Uniform noise (operations.randomUnifNoise.p=0.2)

    • Adds uniform random noise

  • Gaussian blur (operations.gaussianBlur.p=0.2)

    • Blurs the image to simulate defocus variations

  • Size perturbation (operations.sizePerturbation.p=0.2)

    • Slightly scales the particle (simulates magnification errors)

  • Random erasing (operations.erasing.p=0.1)

    • Randomly erases rectangular regions (simulates occlusions)

  • Elastic deformation (operations.randomElastic.p=0.1)

    • Applies elastic distortions to the image

  • In-plane rotations (90°) (operations.inPlaneRotations90.p=1.0)

    • Rotates by 90°, 180°, or 270° (always applied for SO(3) symmetry)

  • In-plane rotations (small) (operations.inPlaneRotations.p=0.5)

    • Random rotations up to ±20° (simulates alignment errors)

  • In-plane shifts (operations.inPlaneShifts.p=0.5)

    • Random shifts up to 5% of image size (simulates centering errors)
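The parameters above can be read as a two-stage random choice: first whether an image is augmented at all, then which operations to apply. The sketch below shows one plausible way these settings could combine, using the default probabilities listed above; it is a hypothetical illustration, and the real CryoPARES augmenter may choose, count, and order operations differently.

```python
import random

# Default per-operation probabilities from the list above
OPERATION_PROBS = {
    "randomGaussNoise": 0.1,
    "randomUnifNoise": 0.2,
    "gaussianBlur": 0.2,
    "sizePerturbation": 0.2,
    "erasing": 0.1,
    "randomElastic": 0.1,
    "inPlaneRotations90": 1.0,
    "inPlaneRotations": 0.5,
    "inPlaneShifts": 0.5,
}


def sample_augmentations(rng, prob_augment_each_image=0.95, min_n_augm_per_img=1,
                         max_n_augm_per_img=8, operation_probs=OPERATION_PROBS):
    """Hypothetical sketch: pick the names of the operations applied to one image."""
    if rng.random() >= prob_augment_each_image:
        return []  # this image is left unaugmented
    # Stage 2: each operation fires independently with its own probability
    chosen = [name for name, p in operation_probs.items() if rng.random() < p]
    # Top up with random extra operations if fewer than the minimum fired
    remaining = [name for name in operation_probs if name not in chosen]
    rng.shuffle(remaining)
    while len(chosen) < min_n_augm_per_img and remaining:
        chosen.append(remaining.pop())
    # Never exceed the maximum number of operations
    return chosen[:max_n_augm_per_img]


rng = random.Random(0)
for _ in range(3):
    print(sample_augmentations(rng))
```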

Example: Adjusting augmentation strength

To reduce augmentation (if underfitting is an issue):

--config datamanager.augmenter.prob_augment_each_image=0.5 \
         datamanager.augmenter.max_n_augm_per_img=4

To increase augmentation (to combat overfitting):

--config datamanager.num_augmented_copies_per_batch=8 \
         datamanager.augmenter.prob_augment_each_image=0.98 \
         datamanager.augmenter.operations.gaussianBlur.p=0.3

To disable specific augmentations:

--config datamanager.augmenter.operations.randomElastic.p=0.0 \
         datamanager.augmenter.operations.erasing.p=0.0

Monitoring Training with TensorBoard

CryoPARES uses PyTorch Lightning, which automatically logs metrics to TensorBoard.

Launching TensorBoard

tensorboard --logdir /path/to/train_save_dir/version_0

Then open your browser to http://localhost:6006

Key Metrics to Monitor

1. Training Loss (loss)

  • Should decrease steadily during training

  • Measures how well the model predicts rotations on training data

  • If it plateaus early, try:

    • Increasing learning rate

    • Checking if data augmentation is too aggressive

    • Reducing weight decay

2. Validation Loss (val_loss)

  • Most important metric for model quality

  • Should track training loss but slightly higher

  • Warning signs:

    • Val loss much higher than train loss → overfitting

    • Val loss not decreasing → underfitting or learning rate too low

3. Angular Error (geo_degs, val_geo_degs)

  • Average angular error in degrees

  • Goal: As low as possible (typically < 15° for good models)

  • val_geo_degs is the key metric for final model quality

  • Comparing geo_degs with val_geo_degs is often the easiest way to spot overfitting

4. Median Angular Error (val_median_geo_degs)

  • More robust than mean to outliers

  • Should also decrease during training

  • Goal: As low as possible (typically < 8° for good models)

  • val_geo_degs and val_median_geo_degs represent the same property, but aggregated in a different manner. They should follow the same trends.

5. Learning Rate (lr)

  • Displayed in optimizer logs

  • Watch for automatic reductions via ReduceLROnPlateau

  • If LR drops too early, increase train.patient_reduce_lr_plateau_n_epochs

6. Visualization: Rotation Matrices

  • TensorBoard shows predicted vs. ground truth rotation matrices

  • Available under “Images” tab

  • Visual confirmation that the model is learning meaningful rotations
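The angular error is the angle of the rotation that takes the predicted orientation onto the reference one, θ = arccos((trace(R_predᵀ R_true) - 1) / 2). A minimal sketch with plain Python matrices, ignoring symmetry (for symmetric particles the error would typically be the minimum over symmetry-equivalent orientations); it illustrates the metric and is not CryoPARES's own code. It also shows why the median is more robust: a single 90° outlier pulls the mean far above the median.

```python
import math
import statistics


def rot_z(deg):
    """Rotation matrix about the z axis, as 3x3 nested lists."""
    c, s = math.cos(math.radians(deg)), math.sin(math.radians(deg))
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def geodesic_deg(r_pred, r_true):
    """Angle in degrees of the relative rotation r_pred^T r_true."""
    # trace(A^T B) equals the sum of element-wise products of A and B
    trace = sum(r_pred[i][j] * r_true[i][j] for i in range(3) for j in range(3))
    cos_angle = max(-1.0, min(1.0, (trace - 1.0) / 2.0))  # clamp rounding errors
    return math.degrees(math.acos(cos_angle))


preds = [rot_z(a) for a in (2.0, 5.0, 4.0, 90.0)]  # last prediction is an outlier
truth = [rot_z(0.0)] * len(preds)
errors = [geodesic_deg(p, t) for p, t in zip(preds, truth)]
print(f"mean = {statistics.mean(errors):.2f} deg, median = {statistics.median(errors):.2f} deg")
```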

Example TensorBoard Monitoring Session

# Start training
python -m cryopares_train \
    --symmetry C1 \
    --particles_star_fname data/particles.star \
    --train_save_dir experiments/run_001 \
    --n_epochs 20

# In another terminal, launch TensorBoard
tensorboard --logdir experiments/run_001/version_0

# Monitor these curves:
# 1. loss (should decrease smoothly)
# 2. val_loss (should track loss)
# 3. val_geo_degs (as low as possible; typically < 15 degrees for good models)
# 4. Learning rate (should stay constant, then drop on plateau)

Overfitting and Underfitting

What is Overfitting?

Overfitting occurs when the model learns to memorize training data instead of generalizing to new data.

Symptoms:

  • Training loss continues to decrease while validation loss increases or plateaus

  • Large gap between loss and val_loss

  • val_geo_degs stops improving or gets worse

Solutions:

  1. Increase regularization:

    --config train.weight_decay=1e-4 models.image2sphere.label_smoothing=0.1  # Increase from defaults 1e-5 and 0.05, respectively
    
  2. Reduce model complexity:

    --config models.image2sphere.lmax=10  # Decrease from default 12
    
  3. Add more training data:

    • Use more particles in your training .star file

  4. Increase data augmentation: Data augmentation is enabled by default. Check current settings:

    python -m cryopares_train --show-config | grep -A 20 "augmenter:"
    

What is Underfitting?

Underfitting occurs when the model is too simple to capture patterns in the data.

Symptoms:

  • Both training and validation loss remain high

  • val_geo_degs > 10 degrees even after many epochs

  • Loss curves plateau early

Solutions:

  1. Increase model complexity:

    --config models.image2sphere.lmax=14 models.image2sphere.so3components.i2sprojector.sphere_fdim=756  # Increase from defaults 12 and 512
    
  2. Reduce regularization:

    --config train.weight_decay=1e-6  # Decrease from default 1e-5
    
  3. Check data preprocessing:

    • Ensure sampling_rate_angs_for_nnet matches your data resolution

    • Verify particle images are properly centered and normalized
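lmax changes model size quickly because a band-limited SO(3) signal up to degree lmax is described by the sum over l of (2l+1)² Wigner-D coefficients, which grows roughly with the cube of lmax. Assuming CryoPARES's SO(3) features follow this standard band-limited representation, the count gives a rough sense of the relative size of the SO(3) layers: lowering lmax from 12 to 10 cuts it by about 40%, while raising it to 14 increases it by about 54%. Actual memory use also depends on the other architecture parameters.

```python
def so3_coefficients(lmax):
    """Number of SO(3) Fourier (Wigner-D) coefficients for degrees 0..lmax."""
    # Degree l contributes a (2l+1) x (2l+1) block of coefficients
    return sum((2 * l + 1) ** 2 for l in range(lmax + 1))


baseline = so3_coefficients(12)  # default lmax
for lmax in (8, 10, 12, 14):
    n = so3_coefficients(lmax)
    print(f"lmax={lmax:2d}: {n:5d} coefficients ({n / baseline:.0%} of the default)")
```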

The Sweet Spot

Ideal training behavior:

  • Both loss and val_loss decrease together

  • Small gap between train and validation metrics

  • val_geo_degs reaches < 15 degrees

  • Validation metrics improve for at least 50 epochs
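The symptoms above can be turned into a rough automatic check on per-epoch loss values (for example, exported from TensorBoard). The thresholds below are arbitrary, hypothetical defaults, losses are assumed positive, and a real decision should still rest on inspecting the curves and val_geo_degs.

```python
def relative_drop(values):
    """Relative decrease from the first to the last value (negative if the value rose)."""
    return (values[0] - values[-1]) / values[0]


def diagnose(train_losses, val_losses, window=3, plateau_tol=0.01, gap_ratio=1.5):
    """Hypothetical heuristic that labels a run from its recent per-epoch losses."""
    train_drop = relative_drop(train_losses[-window:])
    val_drop = relative_drop(val_losses[-window:])
    gap = val_losses[-1] / train_losses[-1]
    # Training still improving while validation stalls or rises, or a large val/train gap
    if train_drop > plateau_tol and (val_drop < plateau_tol or gap > gap_ratio):
        return "possible overfitting"
    # Neither curve is moving: converged or underfitting (check val_geo_degs)
    if train_drop <= plateau_tol and val_drop <= plateau_tol:
        return "plateau"
    return "healthy"


print(diagnose([1.0, 0.8, 0.6, 0.5, 0.4], [1.1, 0.9, 0.8, 0.85, 0.9]))
print(diagnose([1.0, 0.8, 0.6], [1.05, 0.85, 0.65]))
```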


Data Preprocessing

Image Size and Sampling Rate

The neural network operates on rescaled particle images. Understanding this is crucial:

  1. Original images: Read from the .star file at their original box size and sampling rate

  2. Rescaling: Images are rescaled to sampling_rate_angs_for_nnet

  3. Crop/Pad: Images are cropped or padded to image_size_px_for_nnet

Example:

Original: 256×256 px at 1.0 Å/px (256 Å box)
Target: 128×128 px at 2.0 Å/px (256 Å box)

Result: Image is downsampled by 2× to 128×128 px; because the rescaled box already matches image_size_px_for_nnet, no cropping or padding is needed

Guidelines:

  • sampling_rate_angs_for_nnet: 1.5-3.0 Å/px works well for most proteins

  • image_size_px_for_nnet: Should contain entire particle after rescaling, but we prefer tight boxes.

  • Rule of thumb: image_size_px_for_nnet × sampling_rate_angs_for_nnet ≥ particle_diameter + padding (all in Å)
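The three steps and the rule of thumb can be checked with a little arithmetic before launching a run. This sketch assumes the rescaled size is rounded to the nearest pixel and that padding_angs is a margin you choose; both helpers are hypothetical.

```python
def preprocessing_plan(orig_size_px, orig_sampling_angs, target_sampling_angs, target_size_px):
    """Predict the rescaled size and the crop/pad step (rounding is an assumption)."""
    box_angs = orig_size_px * orig_sampling_angs  # physical box size is preserved by rescaling
    rescaled_px = round(box_angs / target_sampling_angs)
    if rescaled_px > target_size_px:
        action = f"center-crop {rescaled_px} -> {target_size_px} px"
    elif rescaled_px < target_size_px:
        action = f"pad {rescaled_px} -> {target_size_px} px"
    else:
        action = "no crop/pad needed"
    return rescaled_px, action


def box_is_large_enough(target_size_px, target_sampling_angs, particle_diameter_angs, padding_angs):
    """Rule of thumb: box edge in Å must cover the particle diameter plus some padding."""
    return target_size_px * target_sampling_angs >= particle_diameter_angs + padding_angs


print(preprocessing_plan(256, 1.0, 2.0, 128))  # the example above
print(preprocessing_plan(256, 1.0, 2.0, 96))   # tighter box: crop after rescaling
print(box_is_large_enough(96, 2.0, particle_diameter_angs=150, padding_angs=20))
```

For the example above, preprocessing_plan(256, 1.0, 2.0, 128) returns (128, 'no crop/pad needed').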

Masking

--config datamanager.particlesDataset.mask_radius_angs=100

  • Applies a soft circular mask to focus on the particle

  • Set to slightly larger than particle radius

  • Too small → cuts off particle features

  • Too large → includes too much noise
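"Soft" means the mask falls off smoothly at its edge instead of cutting abruptly. The exact edge profile CryoPARES uses is not specified here; this sketch assumes a raised-cosine edge of a hypothetical width (edge_width_px) and converts mask_radius_angs to pixels by dividing by sampling_rate_angs_for_nnet.

```python
import math


def soft_mask_value(r_px, radius_px, edge_width_px=5.0):
    """1 inside the radius, raised-cosine fall-off to 0 across edge_width_px (assumed profile)."""
    if r_px <= radius_px:
        return 1.0
    if r_px >= radius_px + edge_width_px:
        return 0.0
    t = (r_px - radius_px) / edge_width_px  # 0..1 across the soft edge
    return 0.5 * (1.0 + math.cos(math.pi * t))


def soft_circular_mask(size_px, radius_px, edge_width_px=5.0):
    """2D mask as nested lists, centred at pixel size_px // 2."""
    c = size_px // 2
    return [[soft_mask_value(math.hypot(x - c, y - c), radius_px, edge_width_px)
             for x in range(size_px)] for y in range(size_px)]


# mask_radius_angs=100 at 2.0 Å/px corresponds to a 50 px radius in a 128 px box
mask = soft_circular_mask(128, radius_px=100 / 2.0)
```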

CTF Correction

CTF correction is applied automatically during training:

  • Phase flip by default

  • Ensures the model sees properly corrected images
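Phase flipping corrects only the sign of the CTF: each Fourier coefficient is multiplied by sign(CTF), which undoes contrast reversals while the amplitudes remain CTF-modulated. The 1D sketch below uses a simplified radial CTF (one common sign convention; no astigmatism, envelope, or phase shift) with hypothetical example parameters. A real implementation works on 2D images with per-particle CTF parameters, which RELION .star files store in columns such as rlnDefocusU and rlnVoltage.

```python
import math


def electron_wavelength_angs(voltage_kv):
    """Relativistic electron wavelength in Å (about 0.0197 Å at 300 kV)."""
    v = voltage_kv * 1e3
    return 12.2643 / math.sqrt(v * (1.0 + 0.978466e-6 * v))


def ctf_1d(freq_inv_angs, defocus_angs, voltage_kv=300.0, cs_mm=2.7, amp_contrast=0.1):
    """Radial CTF value; one common sign convention, no astigmatism, envelope or phase plate."""
    lam = electron_wavelength_angs(voltage_kv)
    cs_angs = cs_mm * 1e7  # 1 mm = 1e7 Å
    chi = (math.pi * lam * defocus_angs * freq_inv_angs ** 2
           - 0.5 * math.pi * cs_angs * lam ** 3 * freq_inv_angs ** 4)
    return -(math.sqrt(1.0 - amp_contrast ** 2) * math.sin(chi) + amp_contrast * math.cos(chi))


def phase_flip(fourier_coeffs, ctf_values):
    """Multiply each Fourier coefficient by the sign of the CTF; amplitudes stay modulated."""
    return [f * (1.0 if c >= 0 else -1.0) for f, c in zip(fourier_coeffs, ctf_values)]


freqs = [i / 400 for i in range(1, 101)]              # 0.0025 to 0.25 1/Å (Nyquist at 2 Å/px)
ctf = [ctf_1d(k, defocus_angs=15000) for k in freqs]  # 15000 Å defocus (example value)
true_signal = [1.0] * len(freqs)                      # toy, all-positive signal
observed = [s * c for s, c in zip(true_signal, ctf)]  # CTF flips the sign in some bands
flipped = phase_flip(observed, ctf)                   # equals s * |CTF|: signs restored
```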

Advanced Training Options

Continue Training

Resume from a previous checkpoint:

python -m cryopares_train \
    --continue_checkpoint_dir /path/to/train_save_dir/version_0 \
    --n_epochs 30  # Train for 30 total epochs

Fine-tuning

Start from a pre-trained model and adapt to new data:

python -m cryopares_train \
    --symmetry C1 \
    --particles_star_fname /path/to/new_particles.star \
    --train_save_dir /path/to/finetuned_model \
    --finetune_checkpoint_dir /path/to/pretrained_model/version_0 \
    --n_epochs 5 \
    --config train.learning_rate=1e-4  # Lower LR for fine-tuning

Half-Set Training

By default, CryoPARES trains two models (half1 and half2) for cross-validation.

This creates:

  • version_0/half1/ - Model trained on particles with RandomSubset=1

  • version_0/half2/ - Model trained on particles with RandomSubset=2

Benefits:

  • Enables gold-standard FSC calculations

  • Recommended for production use

To train on all data (single model):

--NOT_split_halves
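The split follows the RandomSubset label that RELION stores in the _rlnRandomSubset column of the .star file. This sketch assumes the particle rows have already been parsed into dictionaries by a STAR-file reader of your choice; it only illustrates the grouping, not how CryoPARES reads files, and split_half_sets is a hypothetical helper.

```python
from collections import defaultdict


def split_half_sets(particles, column="rlnRandomSubset"):
    """Group particle rows (dicts) into half1/half2 by their random-subset label."""
    halves = defaultdict(list)
    for row in particles:
        subset = int(row[column])
        if subset not in (1, 2):
            raise ValueError(f"unexpected {column} value: {subset}")
        halves[f"half{subset}"].append(row)
    return dict(halves)


rows = [
    {"rlnImageName": "000001@stack.mrcs", "rlnRandomSubset": "1"},
    {"rlnImageName": "000002@stack.mrcs", "rlnRandomSubset": "2"},
    {"rlnImageName": "000003@stack.mrcs", "rlnRandomSubset": "1"},
]
halves = split_half_sets(rows)
print({name: len(members) for name, members in halves.items()})
```

Because the two halves never share particles, predictions from half1 and half2 stay independent, which is what gold-standard FSC requires.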

Simulated Pre-training

Pre-train on simulated data before training on real particles:

python -m cryopares_train \
    --symmetry C1 \
    --particles_star_fname /path/to/real_particles.star \
    --train_save_dir /path/to/output \
    --map_fname_for_simulated_pretraining /path/to/reference_map.mrc \
    --n_epochs 40 \
    --config train.n_epochs_simulation=5  # Train for 5 epochs on simulated data first

This first trains on simulated projections of the reference map, then fine-tunes on real data.

Model Compilation

Speed up training with PyTorch compilation (requires PyTorch 2.0+):

--compile_model

Note: Compilation adds overhead at startup but can speed up training by 10-30%.

Debugging: Overfitting on Small Batches

Test your setup quickly by overfitting on a few batches:

python -m cryopares_train \
    --symmetry C1 \
    --particles_star_fname /path/to/particles.star \
    --train_save_dir /tmp/overfit_test \
    --n_epochs 100 \
    --overfit_batches 10

Troubleshooting Training Issues

Training is very slow

Solutions:

  • Enable model compilation: --compile_model

  • Reduce image size: --config datamanager.particlesDataset.image_size_px_for_nnet=96

  • Increase batch size: --batch_size 64

  • Use multiple GPUs (automatically detected)

  • Reduce model complexity, e.g. lmax: --config models.image2sphere.lmax=10 (note that this also affects model accuracy and the balance between overfitting and underfitting)

Out of memory errors

Solutions:

  • Reduce batch size: --batch_size 16

    • You can increase gradient accumulation to partially compensate for the reduction: --config train.accumulate_grad_batches=32

  • Reduce image size: --config datamanager.particlesDataset.image_size_px_for_nnet=96
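Using effective_batch_size = batch_size × accumulate_grad_batches from the optimizer section, you can pick an accumulation value that preserves the effective batch size. For example, halving --batch_size from 32 to 16 while doubling train.accumulate_grad_batches from 16 (the value in the optimizer example) to 32 keeps 512 particles per optimizer step. The helpers are hypothetical.

```python
def effective_batch_size(batch_size, accumulate_grad_batches=1):
    """Particles contributing to each optimizer step."""
    return batch_size * accumulate_grad_batches


def accumulation_to_match(target_effective, new_batch_size):
    """Accumulation steps that keep the same effective batch size with a smaller batch."""
    if target_effective % new_batch_size:
        raise ValueError("target effective batch size is not a multiple of the new batch size")
    return target_effective // new_batch_size


target = effective_batch_size(32, 16)       # 512
print(accumulation_to_match(target, 16))    # 32
```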

“Too many open files” error

ulimit -n 65536

Run this before every training session, or add to .bashrc:

echo "ulimit -n 65536" >> ~/.bashrc

Loss is NaN

Causes:

  • Learning rate too high

  • Numerical instability

Solutions:

--config train.learning_rate=1e-4  # Reduce LR

Model not improving

Checklist:

  1. Verify data is properly aligned in the input .star file

  2. Check that particles are centered

  3. Ensure sufficient training data (>100000 particles recommended)

  4. Verify symmetry is correct

  5. Try increasing learning rate: --config train.learning_rate=5e-3

  6. Increase model capacity: --config models.image2sphere.lmax=14 models.image2sphere.so3components.s2conv.f_out=128

Validation loss jumps around

Causes:

  • Validation set too small

  • Batch size too small

Solutions:

  • Ensure >10000 particles in validation set

  • Increase batch size: --batch_size 64

See Also