Source code for ihm_validation.futures

# -*- coding: utf-8 -*-
#
# futures.py - This is a beta-version of a new validation framework
#
# Copyright (C) 2023-2024 Arthur Zalevsky <aozalevsky@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
This is a beta-version of a new validation framework
"""
import logging
from typing import Union
from collections import defaultdict

import ihm
import numpy as np
import pandas as pd

class Validator(object):
    '''Abstract base for validators.

    Every method below is a stub that raises ``NotImplementedError``;
    concrete validators are expected to override the ones they support.
    '''
    # Class-level placeholders; nothing in this base class assigns them
    dataset = None
    restraint = None

    def __init__(self):
        '''Create a validator; no state is initialised at this level'''

    def load_restraint(self, restraint: ihm.restraint.Restraint):
        '''Read the restraint data (must be provided by a subclass)'''
        message = 'Method not implemented for that class'
        raise NotImplementedError(message)

    def validate_model(self, model: ihm.model.Model) -> dict:
        '''Check a single model against the data (subclass hook)'''
        message = 'Model validation is not implemented for that class'
        raise NotImplementedError(message)

    def validate_ensemble(self, ensemble) -> dict:
        '''Check an ensemble against the data (subclass hook)'''
        message = 'Ensemble validation is not implemented for that class'
        raise NotImplementedError(message)
class CXMSValidator(Validator):
    """Validate models against crosslinking-MS restraints.

    :meth:`load_restraint` converts a crosslink restraint into a pandas
    table (``raw_restraints``, one row per crosslink).
    :meth:`validate_model` and :meth:`validate_ensemble` measure the
    crosslinked distances in the given model(s) and report, per
    restraint-group category, the percentage of satisfied groups.
    """

    def __init__(self):
        super().__init__()
        # Only particles of these types are used to measure distances
        self.supported_particles = (ihm.model.Atom, ihm.model.Sphere)

    def load_restraint(
            self, restraint: ihm.restraint.CrossLinkRestraint) -> None:
        """Extract crosslinking-MS data from mmcif file"""
        raw_restraints = self._load_restraint_raw(restraint)

        # An empty table has no columns, so there is nothing to check
        if len(raw_restraints) > 0:
            self._quality_check(raw_restraints)

        self.raw_restraints = raw_restraints
        self._assign_ertypes()
        self._assign_rtdtypes()

    def _load_restraint_raw(
            self, restraint: ihm.restraint.CrossLinkRestraint
    ) -> pd.DataFrame:
        """Get raw restraint data

        Returns a DataFrame with one row per supported crosslink.
        Crosslinks whose granularity is neither 'by-atom' nor
        'by-residue' are skipped (and reported at debug level).
        """
        allr = []

        # Iterate over all crosslinks in the dataset
        for xl in restraint.cross_links:
            exl = xl.experimental_cross_link

            # Get residues
            r1 = xl.asym1.residue(exl.residue1.seq_id)
            r2 = xl.asym2.residue(exl.residue2.seq_id)

            # Get residue names
            r1n = r1.comp.id
            r2n = r2.comp.id

            # Select atoms
            if xl.granularity == 'by-atom':
                a1n = xl.atom1
                a2n = xl.atom2
            elif xl.granularity == 'by-residue':
                # Crosslink applied to the specific residue
                # represented by the alpha carbon atom
                a1n = 'CA'
                a2n = 'CA'
            else:
                logging.debug(Exception('Unsupported xl granularity'))
                continue

            entity_name1 = r1.asym.entity.description
            entity_name2 = r2.asym.entity.description

            allr.append({
                'chemistry': restraint.linker.auth_name,
                'restraint_id': xl.id,
                'group_id': xl.group_id,
                'chain1': xl.asym1.id,
                'resnum1': r1.seq_id,
                'resnum1_auth': r1.auth_seq_id,
                'resname1': r1n,
                'name1': a1n,
                'chain2': xl.asym2.id,
                'resnum2': r2.seq_id,
                'resnum2_auth': r2.auth_seq_id,
                'resname2': r2n,
                'name2': a2n,
                'distance_limit': xl.distance.distance,
                'distance_lower_limit': xl.distance.distance_lower_limit,
                'distance_upper_limit': xl.distance.distance_upper_limit,
                'restraint_type': xl.distance.restraint_type,
                'psi': xl.psi,
                'sigma1': xl.sigma1,
                'sigma2': xl.sigma2,
                'group_restraint_all': xl.restrain_all,
                'entity_name1': entity_name1,
                'entity_name2': entity_name2,
                # Both ends are in the same asym unit
                'intra_chain': xl.asym1.id == xl.asym2.id,
                # Both ends belong to entities with the same description
                'intra_entity': entity_name1 == entity_name2,
                # New property to distinguish restraint types
                'restraint_enum': None,
                # New property to select restraint group/threshold types
                'restraint_rtd': None,
                # Geom properties
                'state_group': None,
                'state': None,
                'model_group': None,
                'model_number': None,
                'distance_euclidean': None,
            })

        # Convert data to pandas dataframe
        return pd.DataFrame(allr)

    def _quality_check(self, data: pd.DataFrame) -> None:
        """Check consistency of crosslink restraints

        Raises ValueError if any restraint group is inconsistent.
        """
        for gid in set(data['group_id']):
            self._check_conditional_flag(data[data['group_id'] == gid])

    def _check_conditional_flag(self, data: pd.DataFrame) -> None:
        """Check consistency of conditional flags in a restraint group

        Raises ValueError unless all rows share a single
        'group_restraint_all' value.
        """
        gid = list(set(data['group_id']))[0]
        conditional_flags = set(data['group_restraint_all'])
        if len(conditional_flags) != 1:
            raise ValueError(
                f'Mixed conditional flags in crosslink restraint group {gid}'
            )

    @property
    def number_of_restraints(self) -> int:
        '''Number of rows in the restraint table'''
        return len(self.raw_restraints)

    @property
    def number_of_restraint_groups(self) -> int:
        '''Number of distinct restraint group ids'''
        # A table built from zero crosslinks has no columns at all
        if 'group_id' not in self.raw_restraints.columns:
            return 0
        return len(set(self.raw_restraints['group_id']))

    def _get_rtdtype(self, row):
        '''Determine crosslink types:
        Self/Heteromeric; Upper/Lower/Harmonic; Threshold

        Returns a (restraint type, distance, link type) tuple.
        Raises ValueError for an unknown restraint type.
        '''
        rt = row['restraint_type']
        lt = 'Self-links' if row['intra_entity'] else 'Heteromeric links'

        if rt == 'upper bound':
            d = row['distance_upper_limit']
        elif rt == 'lower bound':
            d = row['distance_lower_limit']
        elif rt == 'harmonic':
            d = row['distance_limit']
        else:
            raise ValueError('Wrong restraint type')

        return (rt, d, lt)

    def _get_ertype(self, row):
        '''Get extended restraint type, that includes
        linker, residue names, atom names'''
        # Sort names of residues and atoms so that both orientations of
        # the same crosslink map to the same type
        (r1n_, a1n_), (r2n_, a2n_) = sorted(
            [(row['resname1'], row['name1']),
             (row['resname2'], row['name2'])]
        )

        # Construct the extended restraint type
        return (
            row['chemistry'],
            r1n_, a1n_,
            r2n_, a2n_,
            row['restraint_type'],
            row['distance_limit'],
        )

    def _enumerate_types(self, get_type) -> dict:
        '''Map each distinct row type to an integer, numbered in order
        of first appearance in the restraint table'''
        types = {}
        for _, row in self.raw_restraints.iterrows():
            types.setdefault(get_type(row), len(types))
        return types

    def _assign_types(self, column: str, get_type, types: dict) -> None:
        '''Write the integer id of every row type into the given column'''
        for index, row in self.raw_restraints.iterrows():
            self.raw_restraints.at[index, column] = types[get_type(row)]

    @property
    def ertypes(self):
        '''Get all extended restraint types'''
        return self._enumerate_types(self._get_ertype)

    @property
    def rtdtypes(self):
        '''Get all crosslink types'''
        return self._enumerate_types(self._get_rtdtype)

    def _assign_ertypes(self):
        '''Assign extended restraint types'''
        self._assign_types('restraint_enum', self._get_ertype, self.ertypes)

    def _assign_rtdtypes(self):
        '''Assign crosslink types'''
        self._assign_types('restraint_rtd', self._get_rtdtype, self.rtdtypes)

    @property
    def ertypes_df(self):
        '''Table of extended restraint types with the number of
        restraints of each type in the 'Count' column'''
        fields_ = [
            'Linker',
            'Residue 1', 'Atom 1',
            'Residue 2', 'Atom 2',
            'Restraint type',
            'Distance, Å',
        ]

        # The property rescans the whole table, so evaluate it only once
        ertypes_ = self.ertypes
        ertypes = pd.DataFrame(
            list(ertypes_.keys()),
            index=list(ertypes_.values()),
            columns=fields_)

        ertypes['Count'] = None
        for index in ertypes.index:
            n = len(
                self.raw_restraints[
                    self.raw_restraints['restraint_enum'] == index]
            )
            ertypes.at[index, 'Count'] = n

        return ertypes

    def _measure_restraints(self, models: list) -> pd.DataFrame:
        '''Measure every loaded restraint in every model

        Returns one row per (model, restraint) pair holding the
        restraint data plus 'distance_euclidean' and 'model_number'.
        Rows whose distance could not be measured are dropped.
        '''
        restraints = []

        for m in models:
            m_ = get_hierarchy_from_model(m)
            for _, row in self.raw_restraints.iterrows():
                # Store as much information as we can
                nrow = row.to_dict()
                nrow.update({
                    'distance_euclidean': self._measure_restraint(m_, row),
                    'model_number': int(m._id),
                })
                restraints.append(nrow)

        restraints = pd.DataFrame(restraints)

        if len(restraints) > 0:
            # Drop missing restraints; report how many, not the mask
            count_missing = int(restraints['distance_euclidean'].isna().sum())
            logging.debug(
                'Dropped %d crosslink restraints with empty distances',
                count_missing
            )
            restraints = restraints.dropna(subset=['distance_euclidean'])

        return restraints

    def _get_particle(self, model, chid, rid, an):
        '''Look up a particle in the chain/residue/atom hierarchy

        Returns None (and logs at debug level) when the entry is not
        one of the supported particle types.
        '''
        particle = model[chid][rid][an]
        if not isinstance(particle, self.supported_particles):
            logging.debug(f'Atom {chid} {rid} {an} is empty')
            return None
        return particle

    def _measure_restraint(self, model, row) -> Union[float, None]:
        '''Euclidean distance between the two ends of a crosslink

        Returns None if either end is missing from the model.
        '''
        # Check that we have all necessary atoms
        a1 = self._get_particle(
            model, row['chain1'], row['resnum1'], row['name1'])
        a2 = self._get_particle(
            model, row['chain2'], row['resnum2'], row['name2'])

        if a1 is None or a2 is None:
            return None

        # Calculate distance
        a1_ = np.array([a1.x, a1.y, a1.z])
        a2_ = np.array([a2.x, a2.y, a2.z])
        return np.linalg.norm(a2_ - a1_)

    def validate_model(self, model: ihm.model.Model) -> dict:
        '''Restraint satisfaction statistics for a single model'''
        data = self._measure_restraints([model])
        return self._process_satisfaction_data(data)

    def validate_ensemble(self, models: list) -> dict:
        '''Restraint satisfaction statistics for a list of models

        A restraint group counts as satisfied if it is satisfied in at
        least one of the models.
        '''
        data = self._measure_restraints(models)
        return self._process_satisfaction_data(data)

    def _process_satisfaction_data(self, data: pd.DataFrame) -> dict:
        '''Convert per-category counts into percentages

        Returns {category: {'Satisfaction': percent or None,
        'Count': number of restraint groups}}; the percentage is None
        for a category without restraint groups.
        '''
        out = {}
        for k, v in self._process_restraint_groups(data).items():
            count = v['Count']
            pct = v['Satisfied'] / count * 100.0 if count > 0 else None
            out[k] = {'Satisfaction': pct, 'Count': count}
        return out

    def _process_restraint_groups(self, data: pd.DataFrame,
                                  mode='entity') -> dict:
        '''Count restraint groups and satisfied restraint groups

        Returns {category: {'Satisfied': n, 'Count': n}} with an 'All'
        entry plus one '<entity type>/<chain type>' entry per category
        present in the data. The ``mode`` argument is not used.
        '''
        stats = {'All': {'Satisfied': 0, 'Count': 0}}

        # A table built from zero measurements has no columns at all
        if data.empty or 'group_id' not in data.columns:
            return stats

        for rgs_ in set(data['group_id']):
            data_ = data[data['group_id'] == rgs_]

            rg_chain_type = self._get_restraint_group_chain_type(data_)
            rg_entity_type = self._get_restraint_group_entity_type(data_)
            assert rg_chain_type is not None
            assert rg_entity_type is not None
            rg_type = f'{rg_entity_type}/{rg_chain_type}'

            # Check per model: one satisfying model is enough
            good_ = int(any(
                self._is_restraint_group_satisfied(
                    data_[data_['model_number'] == model])
                for model in set(data_['model_number'])
            ))

            if rg_type not in stats:
                stats[rg_type] = {'Satisfied': 0, 'Count': 0}

            for key in (rg_type, 'All'):
                stats[key]['Satisfied'] += good_
                stats[key]['Count'] += 1

        return stats

    @staticmethod
    def _harmonic_tolerance(row) -> float:
        '''Sum of psi, sigma1 and sigma2, ignoring missing values'''
        tolerance = 0.0
        for key in ('psi', 'sigma1', 'sigma2'):
            value = row[key]
            # Unset values would otherwise break the sum
            if pd.notna(value):
                tolerance += value
        return tolerance

    def _is_restraint_group_satisfied(self, data: pd.DataFrame) -> bool:
        '''Check whether a restraint group is satisfied in one model

        If the group's 'group_restraint_all' flag is true, all
        restraints must be satisfied, otherwise any single one is
        enough. Raises ValueError for an unknown restraint type.
        '''
        satisfied_restraints = np.zeros(len(data), dtype=bool)
        conditional_flag_all = list(set(data['group_restraint_all']))[0]

        for i, (_, row) in enumerate(data.iterrows()):
            rtype = row['restraint_type']
            ed = row['distance_euclidean']
            threshold = row['distance_limit']

            if rtype == 'upper bound':
                cmp = (ed - threshold) <= 0
            elif rtype == 'lower bound':
                cmp = (ed - threshold) >= 0
            # Check with Ben
            elif rtype == 'harmonic':
                cmp = np.isclose(
                    ed, threshold,
                    rtol=self._harmonic_tolerance(row)
                )
            else:
                # Never silently reuse the previous row's result
                raise ValueError('Wrong restraint type')

            satisfied_restraints[i] = cmp

        if conditional_flag_all:
            return bool(satisfied_restraints.all())
        return bool(satisfied_restraints.any())

    @staticmethod
    def _classify_flags(flags: pd.Series, all_true: str,
                        all_false: str, mixed: str) -> str:
        '''Pick a label depending on whether the boolean flags are all
        true, all false or mixed'''
        if flags.all():
            return all_true
        if (~flags).all():
            return all_false
        return mixed

    def _get_restraint_group_chain_type(self, data: pd.DataFrame) -> str:
        '''Label a restraint group by its 'intra_chain' flags'''
        return self._classify_flags(
            data['intra_chain'],
            'Intramolecular', 'Intermolecular', 'Ambiguous')

    def _get_restraint_group_entity_type(self, data: pd.DataFrame) -> str:
        '''Label a restraint group by its 'intra_entity' flags'''
        return self._classify_flags(
            data['intra_entity'],
            'Self-links', 'Heteromeric links', 'Ambiguous entity')
class StereoChemistryValidator(Validator):
    '''Stereochemistry validator; currently a placeholder that performs
    no checks'''

    def __init__(self):
        super().__init__()

    def validate_model(self, model: ihm.model.Model) -> dict:
        '''Validate the model against the data

        The atomic-representation test is evaluated, but no validation
        is carried out yet and nothing is returned.
        '''
        _atomic = is_model_atomic(model)  # result is not used so far
        return None
# asym_id, seq_id, atom_id
def get_hierarchy_from_atoms(atoms):
    """Index atoms as ``tree[asym id][seq id][atom id]``.

    The result is a nested defaultdict, so looking up a missing key
    yields (and stores) an empty nested mapping instead of raising.
    """
    def make_node():
        # Every level lazily creates the next one
        return defaultdict(make_node)

    tree = make_node()
    for atom in atoms:
        residue = tree[atom.asym_unit.id][atom.seq_id]
        residue[atom.atom_id] = atom
    return tree
# asym_id, seq_id, atom_id
def get_hierarchy_from_model(model):
    """Index a model's particles as ``tree[asym id][seq id][atom id]``.

    Atoms are stored under their own atom id. Every residue covered by
    a 'by-residue' representation segment gets a 'CA' slot, which is
    filled by a sphere spanning exactly that single residue (the slot
    stays None if there is no such sphere). The result is a nested
    defaultdict, so looking up a missing key yields (and stores) an
    empty nested mapping instead of raising.
    """
    def make_node():
        # Every level lazily creates the next one
        return defaultdict(make_node)

    tree = make_node()

    for atom in model.get_atoms():
        tree[atom.asym_unit.id][atom.seq_id][atom.atom_id] = atom

    # Reserve an empty CA slot for each residue-level bead
    for segment in model.representation:
        if segment.granularity != 'by-residue':
            continue
        first = segment.asym_unit.seq_id_range[0]
        last = segment.asym_unit.seq_id_range[1]
        for seq_id in range(first, last + 1):
            tree[segment.asym_unit.asym.id][seq_id]['CA'] = None

    for sphere in model.get_spheres():
        # Consider only by-residue spheres
        covered = list(set(sphere.seq_id_range))
        if len(covered) == 1:
            residue = tree[sphere.asym_unit.id][covered[0]]
            if residue['CA'] is None:
                residue['CA'] = sphere

    return tree
def is_model_mixed(model: ihm.model.Model) -> bool:
    """Check if the model representation uses more than one granularity"""
    distinct = {segment.granularity for segment in model.representation}
    return len(distinct) > 1
def is_model_atomic(model: ihm.model.Model) -> bool:
    """Check if every segment of the model representation is 'by-atom'

    A model without any representation segments is not atomic.
    """
    distinct = {segment.granularity for segment in model.representation}
    return distinct == {'by-atom'}
def is_model_cg(model: ihm.model.Model) -> bool:
    """Check if model is coarse-grained

    A model is coarse-grained when it has at least one representation
    segment and none of its segments is atomic, i.e. every granularity
    is one of 'by-residue', 'multi-residue' or 'by-feature'. Requiring
    'by-residue' and 'by-feature' to both be present would reject
    models built from a single kind of bead.
    """
    coarse = {'by-residue', 'multi-residue', 'by-feature'}
    granularities = {r.granularity for r in model.representation}
    # An empty set is a subset of anything, so rule it out explicitly
    return bool(granularities) and granularities <= coarse