Data Management API
Data Manager
- cryoPARES.datamanager.datamanager.get_example_random_batch(batch_size, n_channels=None, seed=None)
- class cryoPARES.datamanager.datamanager.DataManager(star_fnames, symmetry, particles_dir, halfset, batch_size, save_train_val_partition_dir, is_global_zero, num_augmented_copies_per_batch=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, train_validaton_split_seed=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, train_validation_split=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, num_dataworkers=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, augment_train=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, only_first_dataset_for_validation=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, return_ori_imagen=False, subset_idxs=None)
Bases: LightningDataModule
DataManager: A LightningDataModule that wraps a ParticlesDataset
- Parameters:
star_fnames (List[PathLike | str] | PathLike | str | List[Tuple[DataFrame, DataFrame]])
symmetry (str)
particles_dir (List[PathLike | str] | None | PathLike | str)
halfset (Literal[1, 2] | None)
batch_size (int)
is_global_zero (bool)
num_augmented_copies_per_batch (int)
train_validaton_split_seed (int)
num_dataworkers (int)
augment_train (bool)
only_first_dataset_for_validation (bool)
return_ori_imagen (bool)
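The arguments train_validation_split and train_validaton_split_seed suggest a reproducible random partition of the particles into training and validation sets. The NumPy sketch below illustrates that idea only; it is not the cryoPARES implementation, and the helper name split_train_val and its val_fraction argument are hypothetical (the exact meaning of train_validation_split is not documented here).

```python
import numpy as np


def split_train_val(n_particles: int, val_fraction: float, seed: int):
    """Hypothetical helper: return (train_idxs, val_idxs) for a seeded random split.

    Assumption: val_fraction is the fraction of particles held out for validation.
    """
    rng = np.random.default_rng(seed)    # the same seed always gives the same partition
    perm = rng.permutation(n_particles)  # shuffle all particle indices
    n_val = int(round(n_particles * val_fraction))
    val_idxs = np.sort(perm[:n_val])
    train_idxs = np.sort(perm[n_val:])
    return train_idxs, val_idxs


train_idxs, val_idxs = split_train_val(1000, val_fraction=0.1, seed=11)
```

Fixing the seed makes the partition reproducible across runs and across processes, so every rank sees the same split.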
- prepare_data_per_node
If True, prepare_data() is called on LOCAL_RANK=0 of every node. Otherwise, it is called only on NODE_RANK=0, LOCAL_RANK=0.
- allow_zero_length_dataloader_with_multiple_devices
If True, a dataloader of zero length on a local rank is allowed. Defaults to False.
- prepare_data()
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
Warning
DO NOT set state on the model here (use setup() instead), since this method is NOT called on every device.
Example:
    def prepare_data(self):
        # good
        download_data()
        tokenize()
        etc()
        # bad
        self.split = data_split
        self.some_state = some_other_state()
In a distributed environment, prepare_data can be called in two ways (controlled by prepare_data_per_node):
Once per node. This is the default and is only called on LOCAL_RANK=0.
Once in total. Only called on GLOBAL_RANK=0.
Example:
    from pytorch_lightning import LightningDataModule
    # DEFAULT
    # called once per node on LOCAL_RANK=0 of that node
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = True
    # call on GLOBAL_RANK=0 (great for shared file systems)
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = False
This is called before requesting the dataloaders:
    model.prepare_data()
    initialize_distributed()
    model.setup(stage)
    model.train_dataloader()
    model.val_dataloader()
    model.test_dataloader()
    model.predict_dataloader()
- Return type:
None
- train_dataloader()
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.
The dataloader you return will not be reloaded unless you set pytorch_lightning.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing, use the following pattern:
download in prepare_data()
process and split in setup()
However, the above is only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
This dataloader is requested by fit(), after prepare_data() and setup() have been called.
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.
The dataloader you return will not be reloaded unless you set pytorch_lightning.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
This dataloader is requested by fit() and validate(), after prepare_data() and setup() have been called.
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- test_dataloader()
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.
For data processing, use the following pattern:
download in prepare_data()
process and split in setup()
However, the above is only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
This dataloader is requested by test(), after prepare_data() and setup() have been called.
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- predict_dataloader()
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.
It’s recommended that all data downloads and preparation happen in prepare_data().
This dataloader is requested by predict(), after prepare_data() and setup() have been called.
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- class cryoPARES.datamanager.datamanager.MultiInstanceSampler(sampler, batch_size, drop_last, num_copies_to_sample=1)
Bases: BatchSampler
- Parameters:
sampler
batch_size
drop_last
num_copies_to_sample
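MultiInstanceSampler is documented only by its signature. Judging from its base class (BatchSampler) and the num_copies_to_sample argument, it appears to wrap an index sampler and emit batches in which each sampled index is repeated, so that several independently augmented copies of one particle land in the same batch (compare num_augmented_copies_per_batch in DataManager). The pure-Python sketch below shows that behavior under these assumptions. The class name MultiCopyBatchSampler is hypothetical, and whether batch_size counts unique particles or total items in cryoPARES is not stated; here it counts unique indices.

```python
from typing import Iterator, List, Sequence


class MultiCopyBatchSampler:
    """Hypothetical sketch, not the cryoPARES class.

    Groups indices from `sampler` into batches of `batch_size` unique indices
    and repeats each index `num_copies_to_sample` times within its batch.
    """

    def __init__(self, sampler: Sequence[int], batch_size: int,
                 drop_last: bool, num_copies_to_sample: int = 1):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_copies_to_sample = num_copies_to_sample

    def _expand(self, batch: List[int]) -> List[int]:
        # Repeat every index so each particle appears num_copies_to_sample times
        return [idx for idx in batch for _ in range(self.num_copies_to_sample)]

    def __iter__(self) -> Iterator[List[int]]:
        batch: List[int] = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield self._expand(batch)
                batch = []
        # Emit the last, smaller batch unless drop_last is set
        if batch and not self.drop_last:
            yield self._expand(batch)

    def __len__(self) -> int:
        n = len(self.sampler)
        if self.drop_last:
            return n // self.batch_size
        return (n + self.batch_size - 1) // self.batch_size


batches = list(MultiCopyBatchSampler(range(5), batch_size=2, drop_last=False,
                                     num_copies_to_sample=2))
# → [[0, 0, 1, 1], [2, 2, 3, 3], [4, 4]]
```

Repeating indices inside one batch (rather than across batches) is what lets a loss compare differently augmented views of the same particle in a single step, if that is indeed the intent.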
Particles Dataset
- class cryoPARES.datamanager.particlesDataset.ParticlesDataset(symmetry, halfset, sampling_rate_angs_for_nnet=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, image_size_px_for_nnet=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, store_data_in_memory=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, mask_radius_angs=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, apply_mask_to_img=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, min_maxProb=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, perImg_normalization=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, ctf_correction=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, reduce_symmetry_in_label=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, return_ori_imagen=False, subset_idxs=None)
- Parameters:
symmetry (str)
halfset (int | None)
sampling_rate_angs_for_nnet (float)
image_size_px_for_nnet (int)
store_data_in_memory (bool)
mask_radius_angs (float | None)
apply_mask_to_img (bool)
min_maxProb (float | None)
perImg_normalization (Literal['none', 'noiseStats', 'subtractMean'])
ctf_correction (Literal['none', 'phase_flip', 'ctf_multiply', 'concat_phase_flip', 'concat_ctf_multiply'])
reduce_symmetry_in_label (bool)
return_ori_imagen (bool)
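The perImg_normalization options 'subtractMean' and 'noiseStats' are not described further here. A common reading, sketched below in NumPy, is that 'subtractMean' removes the per-image mean, while 'noiseStats' standardizes each image using the mean and standard deviation of background pixels outside a centered circle. This is an assumption, not the cryoPARES implementation; the function name and its mask_radius_px argument are hypothetical (the dataset itself takes mask_radius_angs in Angstroms).

```python
import numpy as np


def normalize_image(img: np.ndarray, mode: str, mask_radius_px: float) -> np.ndarray:
    """Hypothetical sketch of per-image normalization modes for a square image."""
    if mode == "none":
        return img
    if mode == "subtractMean":
        return img - img.mean()
    if mode == "noiseStats":
        # Assumption: noise statistics come from pixels outside a centered circle
        n = img.shape[-1]
        yy, xx = np.mgrid[:n, :n] - (n - 1) / 2.0
        background = img[np.hypot(yy, xx) > mask_radius_px]
        return (img - background.mean()) / background.std()
    raise ValueError(f"Unknown normalization mode: {mode}")


rng = np.random.default_rng(0)
img = rng.normal(5.0, 2.0, size=(64, 64))
normalized = normalize_image(img, "noiseStats", mask_radius_px=20)
```

After 'noiseStats' normalization the background has zero mean and unit standard deviation, which keeps input intensities comparable across micrographs with different exposure.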
- property particles: ParticlesStarSet
A starstack.particlesStar.ParticlesStarSet representing the loaded particles.
- property augmenter: AugmenterBase
The data augmenter object to be applied.
- property symmetry_group
- updateMd(ids, angles=None, shifts=None, confidence=None, angles_format='rotmat', shifts_format='Angst')
Updates the metadata of the particles with the selected ids.
- Parameters:
ids (List[str]) – The ids of the entries to be updated, e.g. ["1@particles_0.mrcs", "2@particles_0.mrcs"]
angles (Optional[Union[torch.Tensor, np.ndarray]]) – The particle pose angles to update
shifts (Optional[Union[torch.Tensor, np.ndarray]]) – The particle shifts
confidence (Optional[Union[torch.Tensor, np.ndarray]]) – The prediction confidence
angles_format (Literal['rotmat', 'zyzEulerDegs']) – The format of the angles argument
shifts_format (str) – The format (units) of the shifts argument; defaults to 'Angst' (Angstroms)
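updateMd accepts angles either as rotation matrices ('rotmat') or as ZYZ Euler angles in degrees ('zyzEulerDegs'). The NumPy sketch below converts ZYZ Euler angles to rotation matrices by composing Rz(a) Ry(b) Rz(c). That composition order, and whether cryoPARES uses this matrix or its transpose, are assumptions, since the exact Euler convention is not documented here; the function name is hypothetical.

```python
import numpy as np


def zyz_euler_degs_to_rotmat(angles_deg) -> np.ndarray:
    """Hypothetical helper: (N, 3) ZYZ Euler angles in degrees -> (N, 3, 3) matrices.

    Assumption: intrinsic Z-Y-Z convention, R = Rz(a) @ Ry(b) @ Rz(c).
    """
    a, b, c = np.deg2rad(np.asarray(angles_deg, dtype=float)).T

    def rz(t):
        ct, st = np.cos(t), np.sin(t)
        z, o = np.zeros_like(t), np.ones_like(t)
        return np.stack([np.stack([ct, -st, z], -1),
                         np.stack([st, ct, z], -1),
                         np.stack([z, z, o], -1)], -2)

    def ry(t):
        ct, st = np.cos(t), np.sin(t)
        z, o = np.zeros_like(t), np.ones_like(t)
        return np.stack([np.stack([ct, z, st], -1),
                         np.stack([z, o, z], -1),
                         np.stack([-st, z, ct], -1)], -2)

    # Batched matrix products over the leading N axis
    return rz(a) @ ry(b) @ rz(c)


rotmats = zyz_euler_degs_to_rotmat([[0, 0, 0], [90, 0, 0], [30, 45, 60]])
```

Under the stated convention assumption, the resulting (N, 3, 3) array is the kind of input angles_format='rotmat' expects.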
Augmentations
- class cryoPARES.datamanager.augmentations.Augmenter(min_n_augm_per_img=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, max_n_augm_per_img=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>, prob_augment_each_image=<cryoPARES.configManager.inject_defaults.CONFIG_PARAM object>)
Bases: AugmenterBase
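Augmenter is configured by min_n_augm_per_img, max_n_augm_per_img and prob_augment_each_image. A plausible reading, sketched below, is that each image is augmented with probability prob_augment_each_image and, if so, receives between min_n_augm_per_img and max_n_augm_per_img randomly chosen operations. This is an assumption about the semantics, not the cryoPARES implementation; the function name and the toy operations (flips and 90-degree rotations) are hypothetical. Geometric augmentations of particle images would also require updating the pose labels, which this sketch omits.

```python
import random

import numpy as np

# Hypothetical toy augmentations acting on a 2D image
OPERATIONS = {
    "flip_lr": np.fliplr,
    "flip_ud": np.flipud,
    "rot90": np.rot90,
}


def augment(img, min_n_augm_per_img, max_n_augm_per_img,
            prob_augment_each_image, rng):
    """Return (augmented_image, names_of_applied_operations)."""
    if rng.random() >= prob_augment_each_image:
        return img, []  # leave this image untouched
    # random.Random.randint includes both bounds
    n_ops = rng.randint(min_n_augm_per_img, max_n_augm_per_img)
    names = [rng.choice(list(OPERATIONS)) for _ in range(n_ops)]
    for name in names:
        img = OPERATIONS[name](img)
    return img, names


rng = random.Random(0)
img = np.arange(16.0).reshape(4, 4)
out, applied = augment(img, 1, 3, 0.5, rng)
```

Passing an explicit random.Random instance keeps the augmentation sequence reproducible for a given seed.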
CTF Correction
- cryoPARES.datamanager.ctf.rfft_ctf.compute_ctf_rfft(image_size, sampling_rate, dfu, dfv, dfang, volt, cs, w, phase_shift, bfactor, fftshift, device)
Compute the CTF on an RFFT frequency grid.
- Input:
image_size: the side length of the image (pixels)
sampling_rate: Angstrom/pixel
dfu (float or Bx1 tensor): DefocusU (Angstrom)
dfv (float or Bx1 tensor): DefocusV (Angstrom)
dfang (float or Bx1 tensor): DefocusAngle (degrees)
volt (float or Bx1 tensor): accelerating voltage (kV)
cs (float or Bx1 tensor): spherical aberration (mm)
w (float or Bx1 tensor): amplitude contrast ratio
phase_shift (float or Bx1 tensor): phase shift (degrees)
bfactor (float or Bx1 tensor): envelope function B-factor (Angstrom^2)
fftshift: if True, fftshift the CTF
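The NumPy sketch below shows how a CTF can be evaluated on an RFFT grid with the parameters listed above. It follows the formulation used by cryoDRGN's compute_ctf, which takes the same parameter names, and is not necessarily identical to cryoPARES: the overall sign convention is an assumption, only scalar parameters are supported (no Bx1 batching, no torch device), and the function name is hypothetical.

```python
import numpy as np


def ctf_rfft_sketch(image_size, sampling_rate, dfu, dfv, dfang, volt, cs, w,
                    phase_shift=0.0, bfactor=None, fftshift=False):
    """Hypothetical NumPy sketch of a 2D CTF on an rfft grid (scalar parameters only)."""
    # Spatial frequencies in 1/Angstrom: full axis for rows, half axis for columns
    fy = np.fft.fftfreq(image_size, d=sampling_rate)
    fx = np.fft.rfftfreq(image_size, d=sampling_rate)
    y, x = np.meshgrid(fy, fx, indexing="ij")  # shape (image_size, image_size // 2 + 1)
    s2 = x**2 + y**2
    ang = np.arctan2(y, x)

    # Unit conversions: kV -> V, mm -> Angstrom, degrees -> radians
    volt_v = volt * 1e3
    cs_a = cs * 1e7
    dfang_rad = np.deg2rad(dfang)
    phase_rad = np.deg2rad(phase_shift)

    # Relativistic electron wavelength (Angstrom)
    lam = 12.2639 / np.sqrt(volt_v + 0.97845e-6 * volt_v**2)

    # Astigmatic defocus as a function of the frequency direction
    df = 0.5 * (dfu + dfv + (dfu - dfv) * np.cos(2 * (ang - dfang_rad)))
    gamma = 2 * np.pi * (-0.5 * df * lam * s2 + 0.25 * cs_a * lam**3 * s2**2) - phase_rad
    # Assumption: cryoDRGN-style sign convention (CTF = -w at zero frequency)
    ctf = np.sqrt(1 - w**2) * np.sin(gamma) - w * np.cos(gamma)

    if bfactor is not None:
        ctf = ctf * np.exp(-bfactor / 4 * s2)  # Gaussian B-factor envelope
    if fftshift:
        ctf = np.fft.fftshift(ctf, axes=0)  # only the full-length axis is shifted
    return ctf


ctf = ctf_rfft_sketch(64, 1.5, dfu=15000.0, dfv=14000.0, dfang=30.0,
                      volt=300.0, cs=2.7, w=0.1)
print(ctf.shape)  # (64, 33)
```

Because the RFFT keeps only half of the last axis, fftshift only needs to act on the first axis to move zero frequency to row image_size // 2. Sign conventions differ between packages, so check which one applies before comparing CTF values across tools.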
- cryoPARES.datamanager.ctf.rfft_ctf.correct_ctf(image, sampling_rate, dfu, dfv, dfang, volt, cs, w, phase_shift=0, bfactor=None, mode='phase_flip', fftshift=True, wiener_parameter=0.15)
Correct an image for the 2D CTF by phase flipping or Wiener filtering, using the RFFT.
- Input:
image: a real-space image
sampling_rate: Angstrom/pixel
dfu (float or Bx1 tensor): DefocusU (Angstrom)
dfv (float or Bx1 tensor): DefocusV (Angstrom)
dfang (float or Bx1 tensor): DefocusAngle (degrees)
volt (float or Bx1 tensor): accelerating voltage (kV)
cs (float or Bx1 tensor): spherical aberration (mm)
w (float or Bx1 tensor): amplitude contrast ratio
phase_shift (float or Bx1 tensor): phase shift (degrees)
bfactor (float or Bx1 tensor): envelope function B-factor (Angstrom^2)
mode (Choice["phase_flip", "wiener"]): how to correct the CTF
fftshift (bool): if True, fftshift is applied (and the returned CTF is also fftshifted)
wiener_parameter (float): Wiener filter parameter
- Output:
ctf, corrected_image: ctf is in RFFT layout and is fftshifted only if fftshift is True; corrected_image is a real-space image.
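Both correction modes act on the image's RFFT. The NumPy sketch below assumes an unshifted RFFT-layout CTF and uses the common Wiener form CTF / (CTF² + wiener_parameter); whether cryoPARES uses exactly this regularization is an assumption, and the function name is hypothetical.

```python
import numpy as np


def correct_ctf_sketch(image, ctf, mode="phase_flip", wiener_parameter=0.15):
    """Hypothetical sketch; `ctf` is an unshifted rfft-layout array of shape (N, N//2 + 1)."""
    f = np.fft.rfft2(image)
    if mode == "phase_flip":
        # Flip the sign of components where the CTF is negative
        # (components where the CTF is exactly 0 become 0)
        f = f * np.sign(ctf)
    elif mode == "wiener":
        # Assumed Wiener form: divide by the CTF, regularized near its zeros
        f = f * ctf / (ctf**2 + wiener_parameter)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    return np.fft.irfft2(f, s=image.shape)


rng = np.random.default_rng(0)
img = rng.normal(size=(32, 32))
neg_ctf = -np.ones((32, 17))  # toy CTF that inverts every Fourier component
corrupted = np.fft.irfft2(np.fft.rfft2(img) * neg_ctf, s=img.shape)
restored = correct_ctf_sketch(corrupted, neg_ctf)
```

Phase flipping only fixes the sign and leaves the CTF amplitude modulation in place. The Wiener form also restores amplitudes, with wiener_parameter limiting noise amplification where the CTF is close to zero.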
- cryoPARES.datamanager.ctf.rfft_ctf.corrupt_with_ctf(image, sampling_rate, dfu, dfv, dfang, volt, cs, w, phase_shift=0, bfactor=None, fftshift=True)
Corrupt an image by applying the CTF, using the RFFT.
- Output:
ctf, image_corrupted: ctf is in RFFT layout and is fftshifted only if fftshift is True; image_corrupted is a real-space image.
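corrupt_with_ctf is the forward model that correct_ctf undoes: the image's Fourier transform is multiplied by the CTF, which is equivalent to convolving the image with the microscope's point spread function. A minimal NumPy sketch, assuming an unshifted RFFT-layout CTF (as produced with fftshift=False); the function name is hypothetical, and the CTF computation itself is left out.

```python
import numpy as np


def corrupt_with_ctf_sketch(image, ctf):
    """Hypothetical sketch: multiply the image's rfft by an unshifted rfft-layout CTF."""
    f = np.fft.rfft2(image) * ctf           # apply the CTF in Fourier space
    return np.fft.irfft2(f, s=image.shape)  # back to a real-space image


rng = np.random.default_rng(1)
img = rng.normal(3.0, 1.0, size=(32, 32))
ctf = np.ones((32, 17))
ctf[0, 0] = 0.0  # toy CTF that only removes the zero-frequency term
corrupted = corrupt_with_ctf_sketch(img, ctf)
```

Zeroing the zero-frequency coefficient removes the image mean, a simple way to see that the CTF acts frequency by frequency.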