Data Management API

Data Manager

cryoPARES.datamanager.datamanager.get_number_image_channels()
cryoPARES.datamanager.datamanager.get_example_random_batch(batch_size, n_channels=None, seed=None)
class cryoPARES.datamanager.datamanager.DataManager(star_fnames, symmetry, particles_dir, halfset, batch_size, save_train_val_partition_dir, is_global_zero, num_augmented_copies_per_batch=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, train_validaton_split_seed=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, train_validation_split=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, num_dataworkers=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, augment_train=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, only_first_dataset_for_validation=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, return_ori_imagen=False, subset_idxs=None)

Bases: LightningDataModule

DataManager: A LightningDataModule that wraps a ParticlesDataset
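DataManager's signature exposes train_validation_split, train_validaton_split_seed and save_train_val_partition_dir, but this page does not describe how the split is made. The following is a minimal sketch of one common way to obtain a reproducible train/validation split of particle indices with NumPy. It assumes the split is given as a validation fraction, and split_train_val is a hypothetical helper, not part of cryoPARES.

```python
import numpy as np

def split_train_val(n_particles: int, val_fraction: float, seed: int):
    """Split particle indices into disjoint, sorted train/validation arrays.

    Hypothetical helper: the same (n_particles, val_fraction, seed) always
    produces the same split.
    """
    rng = np.random.default_rng(seed)                 # seeded generator -> reproducible
    perm = rng.permutation(n_particles)               # shuffled indices 0..n_particles-1
    n_val = int(round(n_particles * val_fraction))
    val_idxs = np.sort(perm[:n_val])
    train_idxs = np.sort(perm[n_val:])
    return train_idxs, val_idxs

train_idxs, val_idxs = split_train_val(1000, val_fraction=0.1, seed=113)
print(len(train_idxs), len(val_idxs))  # 900 100
```

Because the generator is seeded, every process and every rerun computes the same partition. Writing the indices to disk (presumably the purpose of save_train_val_partition_dir) would make the split reusable even if the seeding scheme changed.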

prepare_data_per_node

If True, prepare_data() is called on LOCAL_RANK=0 of every node. Otherwise, it is called only on the process with NODE_RANK=0 and LOCAL_RANK=0.

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader of zero length on a local rank is allowed. Default: False.

prepare_data()

Use this method to download and prepare data. Downloading and saving data from multiple processes (in distributed settings) can corrupt the data. Lightning ensures this method is called only within a single process, so download logic can safely go here.

Warning

Do not assign state here (use setup() instead), since this method is NOT called on every device.

Example:

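def prepare_data(self):
    # good: download and preprocess data to disk (helpers are placeholders)
    download_data()
    tokenize()

    # bad: attributes assigned here are not visible on the other processes
    self.split = data_split
    self.some_state = some_other_state()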

In a distributed environment, prepare_data can be called in one of two ways, controlled by prepare_data_per_node:

  1. Once per node, on LOCAL_RANK=0 only. This is the default.

  2. Once in total, on GLOBAL_RANK=0 only.

Example:

from pytorch_lightning import LightningDataModule

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
Return type:

None

create_dataset(partitionName)
train_dataloader()

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the corresponding section of the PyTorch Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

For data processing, use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, this is only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the corresponding section of the PyTorch Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

test_dataloader()

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the corresponding section of the PyTorch Lightning documentation.

For data processing, use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, this is only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

predict_dataloader()

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the corresponding section of the PyTorch Lightning documentation.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

class cryoPARES.datamanager.datamanager.MultiInstanceSampler(sampler, batch_size, drop_last, num_copies_to_sample=1)

Bases: BatchSampler

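MultiInstanceSampler is a BatchSampler with a num_copies_to_sample argument, and DataManager has num_augmented_copies_per_batch. Together these suggest that each sampled particle index is repeated within a batch, so that several differently augmented copies of the same particle are drawn. The sketch below illustrates that idea as a plain generator without torch. Both the semantics and the assumption that batch_size counts all entries (copies included) are guesses, and multi_instance_batches is a hypothetical name.

```python
from typing import Iterable, Iterator, List

def multi_instance_batches(sampler: Iterable[int], batch_size: int,
                           drop_last: bool,
                           num_copies_to_sample: int = 1) -> Iterator[List[int]]:
    """Yield batches in which every sampled index appears num_copies_to_sample times.

    Assumption: batch_size counts the total entries per batch, so each batch
    holds batch_size // num_copies_to_sample distinct indices.
    """
    n_unique = batch_size // num_copies_to_sample
    if n_unique < 1:
        raise ValueError("batch_size must be >= num_copies_to_sample")
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == n_unique:
            # repeat each index so the dataset can return several augmented copies
            yield [i for i in batch for _ in range(num_copies_to_sample)]
            batch = []
    if batch and not drop_last:
        # incomplete final batch, kept only when drop_last is False
        yield [i for i in batch for _ in range(num_copies_to_sample)]

print(list(multi_instance_batches(range(5), batch_size=4, drop_last=False,
                                  num_copies_to_sample=2)))
# [[0, 0, 1, 1], [2, 2, 3, 3], [4, 4]]
```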

Particles Dataset

class cryoPARES.datamanager.particlesDataset.ParticlesDataset(symmetry, halfset, sampling_rate_angs_for_nnet=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, image_size_px_for_nnet=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, store_data_in_memory=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, mask_radius_angs=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, apply_mask_to_img=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, min_maxProb=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, perImg_normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, ctf_correction=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, reduce_symmetry_in_label=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, return_ori_imagen=False, subset_idxs=None)

Bases: Dataset, ABC

Parameters:
  • symmetry (str)

  • halfset (int | None)

  • sampling_rate_angs_for_nnet (float)

  • image_size_px_for_nnet (int)

  • store_data_in_memory (bool)

  • mask_radius_angs (float | None)

  • apply_mask_to_img (bool)

  • min_maxProb (float | None)

  • perImg_normalization (Literal['none', 'noiseStats', 'subtractMean'])

  • ctf_correction (Literal['none', 'phase_flip', 'ctf_multiply', 'concat_phase_flip', 'concat_ctf_multiply'])

  • reduce_symmetry_in_label (bool)

  • return_ori_imagen (bool)

  • subset_idxs (List[int] | None)
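perImg_normalization accepts 'none', 'noiseStats' and 'subtractMean', but the modes are not defined here. The sketch below shows one plausible reading, under assumptions: 'subtractMean' removes the image mean, and 'noiseStats' standardizes the image with the mean and standard deviation of the background outside a circular mask, similar to RELION's particle normalization. normalize_image and its radius_px argument are hypothetical; ParticlesDataset takes the radius in Å (mask_radius_angs), so radius_px would correspond to mask_radius_angs / sampling_rate.

```python
import numpy as np

def normalize_image(img: np.ndarray, mode: str, radius_px: float) -> np.ndarray:
    """Normalize a square 2D image (hypothetical helper).

    mode="none": return an unchanged copy.
    mode="subtractMean": subtract the mean of the whole image.
    mode="noiseStats": subtract the background mean and divide by the
        background standard deviation, where the background is every pixel
        farther than radius_px from the image center.
    """
    img = np.asarray(img, dtype=np.float64)
    if mode == "none":
        return img.copy()
    if mode == "subtractMean":
        return img - img.mean()
    if mode == "noiseStats":
        n = img.shape[-1]
        yy, xx = np.mgrid[:n, :n] - (n - 1) / 2.0   # coordinates relative to the center
        background = img[(yy**2 + xx**2) > radius_px**2]
        return (img - background.mean()) / background.std()
    raise ValueError(f"unknown mode: {mode}")

rng = np.random.default_rng(0)
normalized = normalize_image(rng.normal(5.0, 2.0, size=(32, 32)), "noiseStats", radius_px=10)
```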

property nnet_image_size_px: int

The side length, in pixels, of the images fed to the neural network

abstractmethod load_ParticlesStarSet()
property particles: ParticlesStarSet

A starstack.particlesStar.ParticlesStarSet representing the loaded particles

property sampling_rate: float

The particle image sampling rate, in Å/pixel

original_sampling_rate()
Return type:

float

original_image_size()
Return type:

int

property augmenter: AugmenterBase

The data augmenter object applied to the images

property symmetry_group
resizeImage(img)
updateMd(ids, angles=None, shifts=None, confidence=None, angles_format='rotmat', shifts_format='Angst')

Update the metadata of the particles with the selected ids.

Parameters:
  • ids (List[str]) – The ids of the entries to be updated, e.g. ["1@particles_0.mrcs", "2@particles_0.mrcs"]

  • angles (Optional[Union[torch.Tensor, np.ndarray]]) – The particle pose angles to update

  • shifts (Optional[Union[torch.Tensor, np.ndarray]]) – The particle shifts

  • confidence (Optional[Union[torch.Tensor, np.ndarray]]) – The prediction confidence

  • angles_format (Literal['rotmat', 'zyzEulerDegs']) – The format of the argument angles

  • shifts_format (str) – The units of the argument shifts; the default, 'Angst', means Ångström
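updateMd accepts angles either as rotation matrices ('rotmat') or as ZYZ Euler angles in degrees ('zyzEulerDegs'). For reference, here is a minimal sketch of converting ZYZ Euler angles to a rotation matrix, assuming the active, intrinsic composition R = Rz(rot)·Ry(tilt)·Rz(psi). Packages differ in this convention (some use the transpose, for example), so check it against cryoPARES before relying on it; the function names are hypothetical.

```python
import numpy as np

def _rz(deg: float) -> np.ndarray:
    """Active rotation about the z axis by deg degrees."""
    a = np.deg2rad(deg)
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def _ry(deg: float) -> np.ndarray:
    """Active rotation about the y axis by deg degrees."""
    a = np.deg2rad(deg)
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def zyz_euler_degs_to_rotmat(rot: float, tilt: float, psi: float) -> np.ndarray:
    """Assumed convention: R = Rz(rot) @ Ry(tilt) @ Rz(psi)."""
    return _rz(rot) @ _ry(tilt) @ _rz(psi)

R = zyz_euler_degs_to_rotmat(30.0, 45.0, -60.0)
```

Whatever the convention, the result should be orthonormal with determinant 1, which is a cheap sanity check when converting between the two formats.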

cryoPARES.datamanager.particlesDataset.resize_and_padCrop_tensorBatch(array, current_sampling_rate, new_sampling_rate, new_n_pixels=None, padding_mode='reflect')
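resize_and_padCrop_tensorBatch has no description. Judging from its arguments, it resamples a batch from current_sampling_rate to new_sampling_rate and then pads or crops it to new_n_pixels. The NumPy sketch below illustrates those two steps under that assumption. It resamples by cropping or zero-padding the centered Fourier spectrum, which may differ from what cryoPARES actually does, and it works on (B, N, N) arrays rather than torch tensors. All names are hypothetical.

```python
import numpy as np

def fourier_resize(batch: np.ndarray, new_size: int) -> np.ndarray:
    """Resample a real (B, N, N) batch to (B, new_size, new_size) in Fourier space."""
    n = batch.shape[-1]
    spec = np.fft.fftshift(np.fft.fft2(batch), axes=(-2, -1))   # DC at index n // 2
    if new_size <= n:
        start = n // 2 - new_size // 2                  # keeps DC at new_size // 2
        spec = spec[..., start:start + new_size, start:start + new_size]
    else:
        before = new_size // 2 - n // 2                  # zero-pad high frequencies
        after = new_size - n - before
        spec = np.pad(spec, ((0, 0), (before, after), (before, after)))
    out = np.fft.ifft2(np.fft.ifftshift(spec, axes=(-2, -1))).real
    return out * (new_size / n) ** 2                     # preserve mean intensity

def pad_or_crop(batch: np.ndarray, n_pixels: int, padding_mode: str = "reflect") -> np.ndarray:
    """Center-crop or pad the last two axes of a (B, N, N) batch to n_pixels."""
    n = batch.shape[-1]
    if n >= n_pixels:
        start = (n - n_pixels) // 2
        return batch[..., start:start + n_pixels, start:start + n_pixels]
    before = (n_pixels - n) // 2
    after = n_pixels - n - before
    return np.pad(batch, ((0, 0), (before, after), (before, after)), mode=padding_mode)

def resize_and_pad_crop(batch, current_sampling_rate, new_sampling_rate, new_n_pixels=None):
    n = batch.shape[-1]
    # same field of view: e.g. 64 px at 1 Å/px becomes 32 px at 2 Å/px
    resized_n = int(round(n * current_sampling_rate / new_sampling_rate))
    out = fourier_resize(batch, resized_n)
    return out if new_n_pixels is None else pad_or_crop(out, new_n_pixels)

small = resize_and_pad_crop(np.ones((2, 64, 64)), 1.0, 2.0)   # shape (2, 32, 32)
```

Separating the two steps keeps the physical pixel size (set by the resampling) independent of the final box size (set by the pad/crop).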

Augmentations

class cryoPARES.datamanager.augmentations.AugmenterBase

Bases: ABC

abstractmethod applyAugmentation(imgs, degEulerList, shiftFractionList)
class cryoPARES.datamanager.augmentations.Augmenter(min_n_augm_per_img=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, max_n_augm_per_img=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, prob_augment_each_image=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>)

Bases: AugmenterBase

Parameters:
  • min_n_augm_per_img (int)

  • max_n_augm_per_img (int)

  • prob_augment_each_image (float)
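Augmenter is configured by min_n_augm_per_img, max_n_augm_per_img and prob_augment_each_image. One natural reading, assumed here and not confirmed by this page, is that each image is augmented with probability prob_augment_each_image and, if so, receives between min_n_augm_per_img and max_n_augm_per_img augmentations (inclusive). sample_n_augmentations is a hypothetical helper that draws those per-image counts with the standard library.

```python
import random
from typing import List

def sample_n_augmentations(n_images: int, min_n_augm_per_img: int,
                           max_n_augm_per_img: int, prob_augment_each_image: float,
                           rng: random.Random) -> List[int]:
    """Return how many augmentations to apply to each image (0 = untouched)."""
    counts = []
    for _ in range(n_images):
        # rng.random() is in [0, 1), so p=0 never augments and p=1 always does
        if rng.random() < prob_augment_each_image:
            # randint includes both endpoints
            counts.append(rng.randint(min_n_augm_per_img, max_n_augm_per_img))
        else:
            counts.append(0)
    return counts

counts = sample_n_augmentations(8, 1, 3, 0.5, random.Random(0))
```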

applyAugmentation(imgs, degEulerList, shiftFractionList)
cryoPARES.datamanager.augmentations.rotTransImage(image, degrees, translationFract, scaling=1.0, padding_mode='reflection', interpolation_mode='bilinear', rotation_first=True)
Parameters:
  • image – BxCxNxN

  • degrees

  • translationFract – The translation to be applied as a fraction of the total size in pixels

  • scaling

  • padding_mode

  • interpolation_mode

  • rotation_first – Set to True when the result is used to compute RELION alignment parameters

Return type:

Tuple[Tensor, Tensor]

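rotTransImage takes translationFract as a fraction of the image size, and rotation_first controls the order of the two operations. The sketch below shows, for a single 2D point rather than a whole image, how a fraction maps to pixels and why the order matters. The rotation direction, the centering and the hypothetical name transform_point are assumptions; scaling, padding and interpolation are ignored.

```python
import numpy as np

def transform_point(xy, degrees, translation_fract, image_size_px, rotation_first=True):
    """Map a 2D point (pixels, relative to the image center) through a rotation and a
    translation expressed as a fraction of the image size."""
    shift_px = np.asarray(translation_fract, dtype=float) * image_size_px  # fraction -> pixels
    a = np.deg2rad(degrees)
    rot = np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])
    xy = np.asarray(xy, dtype=float)
    if rotation_first:
        return rot @ xy + shift_px      # rotate about the center, then shift
    return rot @ (xy + shift_px)        # shift first, so the shift gets rotated too

p = transform_point([0.0, 0.0], 90, [0.1, 0.0], 64, rotation_first=False)
```

With rotation_first=True the shift (0.1 × 64 = 6.4 px here) is applied in the unrotated frame; otherwise the shift vector itself is rotated, which is why any conversion to RELION-style alignments must agree on the order.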

class cryoPARES.datamanager.augmentations.Scheduler(schedulerInfo)

Bases: object

identity(x, current_step)
linear_up(p, current_step)
linear_down(p, current_step)
generate()
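The Scheduler methods identity, linear_up and linear_down take a value and current_step. The page does not say how the ramp length is set, so the sketch below adds a hypothetical n_steps argument. It shows the assumed behavior: a constant value, a linear ramp from 0 up to p, and a linear ramp from p down to 0.

```python
def identity(x, current_step):
    """Constant schedule: the value does not depend on the step."""
    return x

def linear_up(p, current_step, n_steps):
    """Ramp linearly from 0 to p over n_steps (hypothetical argument), then hold at p."""
    return p * min(current_step, n_steps) / n_steps

def linear_down(p, current_step, n_steps):
    """Ramp linearly from p down to 0 over n_steps (hypothetical argument), then hold at 0."""
    return p * (1 - min(current_step, n_steps) / n_steps)

schedule = [linear_up(0.8, step, n_steps=4) for step in range(6)]
```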

CTF Correction

cryoPARES.datamanager.ctf.rfft_ctf.compute_ctf_rfft(image_size, sampling_rate, dfu, dfv, dfang, volt, cs, w, phase_shift, bfactor, fftshift, device)

Compute the CTF on an RFFT frequency grid.

Input:

  • image_size: the side length of the image

  • sampling_rate: in A/pixel

  • dfu (float or Bx1 tensor): DefocusU (Angstrom)

  • dfv (float or Bx1 tensor): DefocusV (Angstrom)

  • dfang (float or Bx1 tensor): DefocusAngle (degrees)

  • volt (float or Bx1 tensor): accelerating voltage (kV)

  • cs (float or Bx1 tensor): spherical aberration (mm)

  • w (float or Bx1 tensor): amplitude contrast ratio

  • phase_shift (float or Bx1 tensor): phase shift (degrees)

  • bfactor (float or Bx1 tensor): B-factor of the envelope function (Angstrom^2)

  • fftshift: if True, fftshift the CTF

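The following NumPy sketch evaluates a standard weak-phase CTF model on the rfft2 frequency grid, using the units of the Input list: astigmatic defocus, the relativistic electron wavelength, spherical aberration, amplitude contrast, an optional phase shift and a B-factor envelope. It is not cryoPARES's implementation (which works on torch tensors and Bx1 batches). Sign conventions differ between packages, and compute_ctf_rfft_np is a hypothetical name.

```python
import numpy as np

def compute_ctf_rfft_np(image_size, sampling_rate, dfu, dfv, dfang, volt, cs, w,
                        phase_shift=0.0, bfactor=None, fftshift=False):
    """CTF sampled on the np.fft.rfft2 grid of an image_size x image_size image.

    Units: Å, degrees, kV, mm and Å^2, as in the Input list.
    """
    fy = np.fft.fftfreq(image_size, d=sampling_rate)[:, None]   # rows, cycles/Å
    fx = np.fft.rfftfreq(image_size, d=sampling_rate)[None, :]  # non-negative cols, cycles/Å
    s2 = fx**2 + fy**2
    ang = np.arctan2(fy, fx)

    volt_v = volt * 1e3                                          # kV -> V
    cs_a = cs * 1e7                                              # mm -> Å
    lam = 12.2639 / np.sqrt(volt_v + 0.97845e-6 * volt_v**2)    # relativistic wavelength, Å
    # astigmatic defocus along each frequency direction
    df = 0.5 * (dfu + dfv + (dfu - dfv) * np.cos(2 * (ang - np.deg2rad(dfang))))
    gamma = (2 * np.pi * (-0.5 * df * lam * s2 + 0.25 * cs_a * lam**3 * s2**2)
             - np.deg2rad(phase_shift))
    ctf = np.sqrt(1 - w**2) * np.sin(gamma) - w * np.cos(gamma)
    if bfactor is not None:
        ctf = ctf * np.exp(-bfactor / 4 * s2)                    # Gaussian envelope
    if fftshift:
        ctf = np.fft.fftshift(ctf, axes=0)                       # only the full-length axis
    return ctf

ctf = compute_ctf_rfft_np(64, 1.5, 15000.0, 14000.0, 30.0, 300.0, 2.7, 0.1, bfactor=50.0)
```

The RFFT grid stores only non-negative frequencies along the last axis, so the array has shape (N, N//2 + 1), and an fftshift moves the zero frequency only along the first axis.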
cryoPARES.datamanager.ctf.rfft_ctf.correct_ctf(image, sampling_rate, dfu, dfv, dfang, volt, cs, w, phase_shift=0, bfactor=None, mode='phase_flip', fftshift=True, wiener_parameter=0.15)

Correct an image for the 2D CTF by phase flipping or Wiener filtering, using the RFFT.

Input:

  • image: a real-space image

  • sampling_rate: Angstrom/pixel

  • dfu (float or Bx1 tensor): DefocusU (Angstrom)

  • dfv (float or Bx1 tensor): DefocusV (Angstrom)

  • dfang (float or Bx1 tensor): DefocusAngle (degrees)

  • volt (float or Bx1 tensor): accelerating voltage (kV)

  • cs (float or Bx1 tensor): spherical aberration (mm)

  • w (float or Bx1 tensor): amplitude contrast ratio

  • phase_shift (float or Bx1 tensor): phase shift (degrees)

  • bfactor (float or Bx1 tensor): B-factor of the envelope function (Angstrom^2)

  • mode (Choice["phase_flip", "wiener"]): how to correct the CTF

  • fftshift (bool): if True, fftshift is applied (and the returned CTF is also fftshifted)

  • wiener_parameter (float): parameter of the Wiener filter; only used when mode is "wiener"

Output:

A tuple (ctf, corrected_image): ctf is in RFFT layout and is fftshifted if fftshift=True; corrected_image is a real-space image.
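Given a CTF already laid out on the unshifted rfft2 grid, the two correction modes take a few lines of NumPy. Phase flipping multiplies each Fourier coefficient by the sign of the CTF. The Wiener mode here uses the simple form CTF / (CTF² + k), assuming k plays the role of wiener_parameter; cryoPARES's exact filter may differ, and correct_ctf_np is a hypothetical name.

```python
import numpy as np

def correct_ctf_np(image, ctf, mode="phase_flip", wiener_parameter=0.15):
    """Correct a real (N, N) image given a CTF in unshifted rfft2 layout (N, N//2 + 1)."""
    spec = np.fft.rfft2(image)
    if mode == "phase_flip":
        # flip the phase where the CTF is negative (coefficients where CTF == 0 become 0)
        spec = spec * np.sign(ctf)
    elif mode == "wiener":
        # regularized inverse filter; k > 0 avoids dividing by ~0 near CTF zeros
        spec = spec * ctf / (ctf**2 + wiener_parameter)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return np.fft.irfft2(spec, s=image.shape)

img = np.random.default_rng(1).normal(size=(32, 32))
flipped = correct_ctf_np(img, -np.ones((32, 17)))   # a CTF of -1 everywhere negates the image
```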

cryoPARES.datamanager.ctf.rfft_ctf.corrupt_with_ctf(image, sampling_rate, dfu, dfv, dfang, volt, cs, w, phase_shift=0, bfactor=None, fftshift=True)

Corrupt an image by applying the CTF, using the RFFT.

Output:

A tuple (ctf, image_corrupted): ctf is in RFFT layout and is fftshifted if fftshift=True; image_corrupted is a real-space image.
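Corruption is the forward model: the image spectrum is multiplied by the CTF. The sketch below (hypothetical names, NumPy instead of torch, CTF in the unshifted rfft2 layout) also shows what phase flipping recovers afterwards: the sign changes are undone, but the image remains filtered by |CTF|.

```python
import numpy as np

def corrupt_with_ctf_np(image, ctf):
    """Multiply the image spectrum by a CTF given in unshifted rfft2 layout."""
    return np.fft.irfft2(np.fft.rfft2(image) * ctf, s=image.shape)

def phase_flip_np(image, ctf):
    """Multiply the image spectrum by sign(CTF)."""
    return np.fft.irfft2(np.fft.rfft2(image) * np.sign(ctf), s=image.shape)

n = 32
fy = np.fft.fftfreq(n)[:, None]
fx = np.fft.rfftfreq(n)[None, :]
toy_ctf = np.cos(60.0 * (fx**2 + fy**2))        # toy radially symmetric, oscillating "CTF"
img = np.random.default_rng(2).normal(size=(n, n))
restored = phase_flip_np(corrupt_with_ctf_np(img, toy_ctf), toy_ctf)
# restored equals img filtered by |toy_ctf|, not img itself
```

The toy CTF depends only on the squared frequency, so it takes the same value at each frequency and its negative; a real CTF has this symmetry too, which keeps the product a valid rfft2 spectrum.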