Source code for cryoPARES.datamanager.particlesDataset

import functools
import warnings
from abc import ABC, abstractmethod
from functools import cached_property

import torch
import numpy as np

from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
from starstack.constants import RELION_IMAGE_FNAME
from starstack.particlesStar import ParticlesStarSet
from torch.utils.data import Dataset
from typing import Union, Literal, Optional, List, Tuple, Any, Dict

from cryoPARES.cacheManager import get_cache
from cryoPARES.configManager.inject_defaults import inject_defaults_from_config, CONFIG_PARAM
from cryoPARES.configs.datamanager_config.particlesDataset_config import CtfCorrectionType, ImgNormalizationType
from cryoPARES.configs.mainConfig import main_config
from cryoPARES.constants import RELION_ANGLES_NAMES, RELION_SHIFTS_NAMES, \
    RELION_PRED_POSE_CONFIDENCE_NAME, RELION_EULER_CONVENTION, RELION_ORI_POSE_CONFIDENCE_NAME, BATCH_PARTICLES_NAME, \
    BATCH_IDS_NAME, BATCH_POSE_NAME, BATCH_MD_NAME, BATCH_ORI_IMAGE_NAME, BATCH_ORI_CTF_NAME

warnings.filterwarnings("ignore", "Gimbal lock detected. Setting third angle to zero since it "
                                  "is not possible to uniquely determine all angles.")
warnings.filterwarnings("ignore", message="The torchvision.datapoints and torchvision.transforms.v2 namespaces")


from cryoPARES.datamanager.augmentations import AugmenterBase
from cryoPARES.datamanager.ctf.rfft_ctf import correct_ctf
from cryoPARES.utils.torchUtils import data_to_numpy

class ParticlesDataset(Dataset, ABC):  # TODO: This class still has several Relion-specific features
    """
    Abstract torch Dataset serving cryo-EM particle images together with their pose
    metadata (Euler angles, shifts and confidence). Subclasses must implement
    load_ParticlesStarSet to provide the underlying ParticlesStarSet.
    """
    @inject_defaults_from_config(main_config.datamanager.particlesdataset)
    def __init__(self, symmetry: str, halfset: Optional[int],
                 sampling_rate_angs_for_nnet: float = CONFIG_PARAM(),
                 image_size_px_for_nnet: int = CONFIG_PARAM(),
                 store_data_in_memory: bool = CONFIG_PARAM(),
                 mask_radius_angs: Optional[float] = CONFIG_PARAM(),
                 apply_mask_to_img: bool = CONFIG_PARAM(),
                 min_maxProb: Optional[float] = CONFIG_PARAM(),
                 perImg_normalization: Literal["none", "noiseStats", "subtractMean"] = CONFIG_PARAM(),
                 ctf_correction: Literal["none", "phase_flip", "ctf_multiply",
                                         "concat_phase_flip", "concat_ctf_multiply"] = CONFIG_PARAM(),
                 reduce_symmetry_in_label: bool = CONFIG_PARAM(),
                 return_ori_imagen: bool = False,
                 subset_idxs: Optional[List[int]] = None):
        super().__init__()
        self.sampling_rate_angs_for_nnet = sampling_rate_angs_for_nnet
        self.image_size_px_for_nnet = image_size_px_for_nnet
        self.store_data_in_memory = store_data_in_memory
        self.mask_radius_angs = mask_radius_angs
        self.apply_mask_to_img = apply_mask_to_img
        self.min_maxProb = min_maxProb
        self.reduce_symmetry_in_label = reduce_symmetry_in_label
        self.return_ori_imagen = return_ori_imagen
        self.symmetry = symmetry.upper()
        self.halfset = halfset
        self.subset_idxs = subset_idxs

        assert perImg_normalization in (item.value for item in ImgNormalizationType)
        if perImg_normalization == "none":
            self._normalize = self._normalizeNone
        elif perImg_normalization == "noiseStats":
            self._normalize = self._normalizeNoiseStats
        elif perImg_normalization == "subtractMean":
            self._normalize = self._normalizeSubtractMean
        else:
            raise ValueError(f"Error, perImg_normalization {perImg_normalization} is not a valid option")

        assert ctf_correction in (item.value for item in CtfCorrectionType)
        if ctf_correction == "none":
            self._correctCtf = self._correctCtfNone
        elif ctf_correction.endswith("phase_flip"):
            self._correctCtf = self._correctCtfPhase
        elif ctf_correction.endswith("ctf_multiply"):
            raise NotImplementedError("Error, ctf_multiply was not implemented")
        else:
            raise ValueError(f"Error, ctf_correction {ctf_correction} is not a valid option")
        self.ctf_correction_do_concat = ctf_correction.startswith("concat")
        self.ctf_correction = ctf_correction.removeprefix("concat_")

        if self.store_data_in_memory:
            self.memory = get_cache(cache_name=None, verbose=0)
            self._getIdx = self.memory.cache(self._getIdx, ignore=["self"], verbose=0)

        self._particles = None
        self._augmenter = None
        self._image_size = None
    @property
    def nnet_image_size_px(self) -> int:
        """The image size in pixels"""
        if self.image_size_px_for_nnet is None:
            return self.particles.particle_shape[-1]
        else:
            return self.image_size_px_for_nnet
    @abstractmethod
    def load_ParticlesStarSet(self) -> ParticlesStarSet:
        """Loads and returns the ParticlesStarSet to be served by this dataset"""
        raise NotImplementedError()
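    # Usage sketch (illustrative, not part of the module): a minimal concrete
    # subclass. Only load_ParticlesStarSet is abstract; the star-file path and
    # the ParticlesStarSet constructor argument are assumptions for this sketch.
    #
    #   class StarFileParticlesDataset(ParticlesDataset):
    #       def load_ParticlesStarSet(self):
    #           return ParticlesStarSet(starFname="/path/to/particles.star")
    #
    #   dataset = StarFileParticlesDataset(symmetry="C1", halfset=None)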
    def _load_ParticlesStarSet(self):
        part_set = self.load_ParticlesStarSet()
        self._particles = part_set
        assert len(part_set) > 0, "Error, no particles were found in the star file"
        if self.subset_idxs is not None:
            self._particles = self._particles.createSubset(idxs=self.subset_idxs)
        if self.halfset is not None:
            if "rlnRandomSubset" not in self._particles.particles_md:
                half1, half2 = train_test_split(self._particles.particles_md.index, test_size=0.5,
                                                random_state=11,  # Fixed seed so we always split the same way
                                                shuffle=True)
                self._particles.particles_md.loc[:, "rlnRandomSubset"] = 1
                self._particles.particles_md.loc[half2, "rlnRandomSubset"] = 2
            subsetNums = self._particles.particles_md["rlnRandomSubset"].values
            _subsetNums = set(subsetNums)
            assert min(_subsetNums) >= 1 and max(_subsetNums) <= 2
            idxs = np.where(subsetNums == self.halfset)[0]
            self._particles = self._particles.createSubset(idxs=idxs)
        if self.min_maxProb is not None:
            maxprob = self._particles.particles_md[RELION_ORI_POSE_CONFIDENCE_NAME]
            idxs = np.where(maxprob >= self.min_maxProb)[0]
            self._particles = self._particles.createSubset(idxs=idxs)
        return self._particles

    @property
    def particles(self) -> ParticlesStarSet:
        """A starstack.particlesStar.ParticlesStarSet representing the loaded particles"""
        if self._particles is None:
            self._particles = self._load_ParticlesStarSet()
        return self._particles

    @property
    def sampling_rate(self) -> float:
        """The particle image sampling rate in Å/pixel"""
        if self.image_size_px_for_nnet is None:
            return self.particles.sampling_rate
        else:
            return self.sampling_rate_angs_for_nnet
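    # Illustrative note: because the half-set split above uses a fixed
    # random_state, two independent instances built from the same star file get
    # disjoint, reproducible halves (StarFileParticlesDataset as sketched above):
    #
    #   half1 = StarFileParticlesDataset(symmetry="C1", halfset=1)
    #   half2 = StarFileParticlesDataset(symmetry="C1", halfset=2)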
    def original_sampling_rate(self) -> float:
        """The sampling rate of the original (non-resized) particle images in Å/pixel"""
        return self.particles.sampling_rate
    def original_image_size(self) -> int:
        """The box size of the original (non-resized) particle images in pixels"""
        return int(self.particles.optics_md["rlnImageSize"].values[0])
    @property
    def augmenter(self) -> AugmenterBase:
        """The data augmenter object to be applied"""
        return self._augmenter

    @augmenter.setter
    def augmenter(self, augmenterObj: AugmenterBase):
        """
        Args:
            augmenterObj: The data augmenter object to be applied
        """
        self._augmenter = augmenterObj

    @staticmethod
    @functools.lru_cache(2)
    def _getParticleMask(image_size_px, sampling_rate, mask_radius_angs,
                         device: Optional[Union[torch.device, str]] = None
                         ) -> Tuple[torch.Tensor, torch.Tensor]:
        radius = image_size_px / 2
        if mask_radius_angs is None:
            normalizationRadiusPixels = image_size_px / 2
        else:
            normalizationRadiusPixels = mask_radius_angs / sampling_rate
        ies, jes = torch.meshgrid(
            torch.linspace(-1 * radius, 1 * radius, image_size_px, dtype=torch.float32),
            torch.linspace(-1 * radius, 1 * radius, image_size_px, dtype=torch.float32),
            indexing="ij"
        )
        r = (ies ** 2 + jes ** 2) ** 0.5
        normalizationMask = (r > normalizationRadiusPixels)
        normalizationMask = normalizationMask.to(device)
        particleMask = ~normalizationMask
        return normalizationMask, particleMask

    def _normalizeNoiseStats(self, img):
        """
        Normalizes the image using the mean and std of the background (noise) region.

        Args:
            img: 1xSxS tensor

        Returns:
            The normalized 1xSxS tensor
        """
        backgroundMask = self._getParticleMask(self.nnet_image_size_px,
                                               sampling_rate=self.sampling_rate,
                                               mask_radius_angs=self.mask_radius_angs)[0]
        noiseRegion = img[:, backgroundMask]
        meanImg = noiseRegion.mean()
        stdImg = noiseRegion.std()
        return (img - meanImg) / stdImg

    def _normalizeSubtractMean(self, img):
        return img - img.mean()

    def _normalizeNone(self, img):
        return img

    def _correctCtfPhase(self, img, md_row):
        ctf, wimg = correct_ctf(img,
                                float(self.particles.optics_md["rlnImagePixelSize"].item()),
                                dfu=md_row["rlnDefocusU"],
                                dfv=md_row["rlnDefocusV"],
                                dfang=md_row["rlnDefocusAngle"],
                                volt=float(self.particles.optics_md["rlnVoltage"][0]),
                                cs=float(self.particles.optics_md["rlnSphericalAberration"][0]),
                                w=float(self.particles.optics_md["rlnAmplitudeContrast"][0]),
                                mode=self.ctf_correction,
                                fftshift=True)
        wimg = torch.clamp(wimg, img.min(), img.max())
        wimg = torch.nan_to_num(wimg, nan=img.mean())
        if self.ctf_correction_do_concat:
            img = torch.concat([img, wimg], dim=0)
        else:
            img = wimg
        ctf = ctf.real
        return img, ctf

    def _correctCtfNone(self, img, md_row):
        return img, None

    def _getIdx(self, item: int) -> Tuple[str, torch.Tensor,
                                          Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                                          Dict[str, Any],
                                          Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        try:
            img_ori, md_row = self.particles[item]
        except ValueError:
            print(f"Error retrieving item {item}")
            raise
        iid = md_row[RELION_IMAGE_FNAME]
        img_ori = torch.FloatTensor(img_ori)
        img, ctf_ori = self._correctCtf(img_ori.unsqueeze(0), md_row)
        if img.isnan().any():
            raise RuntimeError(f"Error, img with idx {item} has NaN values")
        img = self.resizeImage(img)
        img = self._normalize(img)  # Note: normalization happens after CTF correction here (in cesped it was applied before)
        degEuler = torch.FloatTensor([md_row.get(name, 0) for name in RELION_ANGLES_NAMES])
        xyShiftAngs = torch.FloatTensor([md_row.get(name, 0) for name in RELION_SHIFTS_NAMES])
        confidence = torch.FloatTensor([md_row.get(RELION_ORI_POSE_CONFIDENCE_NAME, 1)])
        return iid, img, (degEuler, xyShiftAngs, confidence), md_row.to_dict(), (img_ori, ctf_ori)

    @cached_property
    def symmetry_group(self):
        return R.create_group(self.symmetry.upper())
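    # Sketch of the noiseStats normalization in isolation (illustrative): pixels
    # outside the particle radius are treated as background, and their statistics
    # standardize the whole image.
    #
    #   img = torch.randn(1, 64, 64)
    #   bg_mask, particle_mask = ParticlesDataset._getParticleMask(
    #       64, sampling_rate=1.5, mask_radius_angs=40.0)
    #   noise = img[:, bg_mask]
    #   img_norm = (img - noise.mean()) / noise.std()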
    def resizeImage(self, img):
        ori_pixelSize = float(self.particles.optics_md["rlnImagePixelSize"].item())
        img, pad_info, crop_info = resize_and_padCrop_tensorBatch(img.unsqueeze(0), ori_pixelSize,
                                                                  self.sampling_rate_angs_for_nnet,
                                                                  self.nnet_image_size_px,
                                                                  padding_mode="constant")
        img = img.squeeze(0)
        return img
    def __getitem(self, item):
        iid, prepro_img, (degEuler, xyShiftAngs, confidence), md_dict, (img_ori, ctf_ori) = self._getIdx(item)
        if self.augmenter is not None:
            prepro_img, degEuler, shift, _ = self.augmenter(prepro_img,  # 1xSxS image expected
                                                            degEuler,
                                                            shiftFraction=xyShiftAngs / (self.nnet_image_size_px
                                                                                         * self.sampling_rate))
            xyShiftAngs = shift * (self.nnet_image_size_px * self.sampling_rate)
        r = R.from_euler(RELION_EULER_CONVENTION, degEuler, degrees=True)
        if self.symmetry != "C1" and self.reduce_symmetry_in_label:
            r = r.reduce(self.symmetry_group)
        rotMat = r.as_matrix()
        rotMat = torch.FloatTensor(rotMat)
        if self.apply_mask_to_img:
            mask = self._getParticleMask(self.nnet_image_size_px,
                                         sampling_rate=self.sampling_rate,
                                         mask_radius_angs=self.mask_radius_angs)[1]
            prepro_img *= mask
        batch = {BATCH_IDS_NAME: iid,
                 BATCH_PARTICLES_NAME: prepro_img,
                 BATCH_POSE_NAME: (rotMat, xyShiftAngs, confidence),
                 BATCH_MD_NAME: md_dict}
        if self.return_ori_imagen:
            batch[BATCH_ORI_IMAGE_NAME] = img_ori
            batch[BATCH_ORI_CTF_NAME] = ctf_ori
        return batch

    def __getitem__(self, item):
        return self.__getitem(item)

    def __len__(self):
        return len(self.particles)
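    # Usage sketch (illustrative): items are dictionaries keyed by the BATCH_*
    # constants, so the dataset composes directly with a torch DataLoader.
    #
    #   from torch.utils.data import DataLoader
    #   loader = DataLoader(dataset, batch_size=32, num_workers=4)
    #   for batch in loader:
    #       imgs = batch[BATCH_PARTICLES_NAME]  # Bx1xSxS (Bx2xSxS if CTF concat)
    #       rotMats, xyShiftAngs, confidence = batch[BATCH_POSE_NAME]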
    def updateMd(self, ids: List[str],
                 angles: Optional[Union[torch.Tensor, np.ndarray]] = None,
                 shifts: Optional[Union[torch.Tensor, np.ndarray]] = None,
                 confidence: Optional[Union[torch.Tensor, np.ndarray]] = None,
                 angles_format: Literal["rotmat", "ZYZEulerDegs"] = "rotmat",
                 shifts_format: Literal["Angst"] = "Angst"):
        """
        Updates the metadata of the particles with the selected ids

        Args:
            ids (List[str]): The ids of the entries to be updated, e.g. ["1@particles_0.mrcs", "2@particles_0.mrcs"]
            angles (Optional[Union[torch.Tensor, np.ndarray]]): The particle pose angles to update
            shifts (Optional[Union[torch.Tensor, np.ndarray]]): The particle shifts to update
            confidence (Optional[Union[torch.Tensor, np.ndarray]]): The prediction confidence
            angles_format (Literal["rotmat", "ZYZEulerDegs"]): The format of the angles argument
            shifts_format (Literal["Angst"]): The format of the shifts argument (Angstroms)
        """
        assert angles_format in ["rotmat", "ZYZEulerDegs"], \
            'Error, angles_format should be in ["rotmat", "ZYZEulerDegs"]'
        assert shifts_format in ["Angst"], \
            'Error, shifts_format should be in ["Angst"]'
        col2val = {}
        if angles is not None:
            angles = data_to_numpy(angles)
            if angles_format == "rotmat":
                r = R.from_matrix(angles)
                rots, tilts, psis = r.as_euler(RELION_EULER_CONVENTION, degrees=True).T
            else:
                rots, tilts, psis = [angles[:, i] for i in range(3)]
            col2val.update({
                RELION_ANGLES_NAMES[0]: rots,
                RELION_ANGLES_NAMES[1]: tilts,
                RELION_ANGLES_NAMES[2]: psis,
            })
        if shifts is not None:
            shifts = data_to_numpy(shifts)
            col2val.update({
                RELION_SHIFTS_NAMES[0]: shifts[:, 0],
                RELION_SHIFTS_NAMES[1]: shifts[:, 1],
            })
        if confidence is not None:
            confidence = data_to_numpy(confidence)
            col2val.update({
                RELION_PRED_POSE_CONFIDENCE_NAME: confidence,
            })
        assert col2val, "Error, no editing values were provided"
        self.particles.updateMd(ids=ids, colname2change=col2val)
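    # Usage sketch (illustrative): writing predicted poses back into the metadata.
    # pred_rotmats is a hypothetical Nx3x3 array of predicted rotation matrices
    # aligned with ids.
    #
    #   ids = ["1@particles_0.mrcs", "2@particles_0.mrcs"]
    #   dataset.updateMd(ids=ids, angles=pred_rotmats,
    #                    confidence=np.ones(len(ids)), angles_format="rotmat")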
def resize_and_padCrop_tensorBatch(array, current_sampling_rate, new_sampling_rate,
                                   new_n_pixels=None, padding_mode='reflect'):
    ndims = array.ndim - 2
    if isinstance(array, np.ndarray):
        wasNumpy = True
        array = torch.from_numpy(array)
    else:
        wasNumpy = False
    if isinstance(current_sampling_rate, tuple):
        current_sampling_rate = torch.tensor(current_sampling_rate)
    if isinstance(new_sampling_rate, tuple):
        new_sampling_rate = torch.tensor(new_sampling_rate)

    scaleFactor = current_sampling_rate / new_sampling_rate
    if isinstance(scaleFactor, (int, float)):
        scaleFactor = (scaleFactor,) * ndims
    else:
        scaleFactor = tuple(scaleFactor)

    # Resize the array
    if ndims == 2:
        mode = 'bilinear'
    elif ndims == 3:
        mode = 'trilinear'
    else:
        raise ValueError(f"Option not valid. ndims={ndims}")
    resampled_array = torch.nn.functional.interpolate(array, scale_factor=scaleFactor,
                                                      mode=mode, antialias=False)
    pad_width = []
    crop_positions = []
    if new_n_pixels is not None:
        if isinstance(new_n_pixels, int):
            new_n_pixels = [new_n_pixels] * ndims
        for i in range(ndims):
            new_n_pix = new_n_pixels[i]
            old_n_pix = resampled_array.shape[i + 2]
            if new_n_pix < old_n_pix:
                # Crop the tensor
                crop_start = (old_n_pix - new_n_pix) // 2
                resampled_array = resampled_array.narrow(i + 2, crop_start, new_n_pix)
                crop_positions.append((crop_start, crop_start + new_n_pix))
                pad_width.extend((0, 0))  # Placeholder to keep pad_width aligned per dimension
            elif new_n_pix > old_n_pix:
                # Pad the tensor
                pad_before = (new_n_pix - old_n_pix) // 2
                pad_after = new_n_pix - old_n_pix - pad_before
                pad_width.extend((pad_before, pad_after))
            else:
                pad_width.extend((0, 0))
        if any(pad_width):
            # torch.nn.functional.pad expects the (before, after) pairs for the
            # last dimension first, so the per-dimension pairs are reversed
            pairs = list(zip(pad_width[::2], pad_width[1::2]))
            flat_pad = [p for pair in reversed(pairs) for p in pair]
            resampled_array = torch.nn.functional.pad(resampled_array, flat_pad, mode=padding_mode)
    if wasNumpy:
        resampled_array = resampled_array.numpy()
    return resampled_array, pad_width, crop_positions
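# Usage sketch (illustrative): resampling a batch of 2-D images from 1.0 Å/px to
# 1.5 Å/px and then padding to a 128-pixel box.
#
#   imgs = torch.randn(8, 1, 180, 180)
#   out, pad_width, crop_positions = resize_and_padCrop_tensorBatch(
#       imgs, current_sampling_rate=1.0, new_sampling_rate=1.5, new_n_pixels=128)
#   # out.shape -> torch.Size([8, 1, 128, 128])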