Source code for ihm_validation.cx

# -*- coding: utf-8 -*-
#
# cx.py - Crosslinking-MS validation for PDB-IHM
#
# Copyright (C) 2023-2025 Arthur Zalevsky <aozalevsky@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""
Crosslinking-MS validation for PDB-IHM
"""

from mmcif_io import GetInputInformation
from utility import get_hierarchy_from_model, NA
import pandas as pd
import logging
import ihm
import numpy as np
from pathlib import Path
from bokeh.plotting import save
from bokeh.layouts import gridplot
from bokeh.models import Range1d
from bokeh.models.widgets import Panel, Tabs
from bokeh.resources import CDN
import iqplot
import json
from bokeh.embed import json_item
from bokeh.io import export_svgs
import requests
import pickle
import pyhmmer
import time
import utility
import xml.etree.ElementTree as ET

pd.options.mode.chained_assignment = None


[docs] class CxValidation(GetInputInformation): ID = None driver = None def __init__(self, mmcif_file, cache): super().__init__(mmcif_file) self.cache = cache self.nos = self.get_number_of_models() self.dataset = self.get_dataset_comp() # Only atomic structures are supported so far # self.struct = prody.parseMMCIF(mmcif_file, header=False) self.entities = set([x.details for x in self.system.asym_units]) self.chains = set([x.id for x in self.system.asym_units]) self.get_cx_data()
[docs] def select_atom_by_asym_id_seq_id_atom_id(atoms, asym_id, seq_id, atom_id): result = None for a in atoms: if a.asym_unit.id == asym_id: if a.seq_id == seq_id: if a.atom_id == atom_id: result = a return result
[docs] def get_models(self): # parse all models models = {} gim = 0 for istg, stg in enumerate(self.system.state_groups): for ist, st in enumerate(stg): for img, mg in enumerate(st): for im, m in enumerate(mg): gim += 1 m_ = get_hierarchy_from_model(m) models[gim] = m_ return models
[docs] def get_raw_restraints(self): """Get all restraints""" allr = [] rid = -1 # Iterate over all restraints datasets for restr_ in self.system.restraints: # We are interested only in Chemical crosslinks if type(restr_) != ihm.restraint.CrossLinkRestraint: continue # Iterate over all crosslinks in the dataset for xl in restr_.cross_links: # we have same ID for measurments in multiple models rid += 1 # get corresponding experimental crosslink exl = xl.experimental_cross_link # Extract residue names from atoms try: r1n = exl.residue1.comp.id r2n = exl.residue2.comp.id except IndexError as e: logging.error('Missing residue') logging.error(e) continue # Select atoms if xl.granularity == 'by-atom': a1n = xl.atom1 a2n = xl.atom2 elif xl.granularity == 'by-residue': # Crosslink applied to the specific residue # represented by the alpha carbon atom a1n = 'CA' a2n = 'CA' elif xl.granularity == 'by-feature': a1n = 'coarse-grained' a2n = 'coarse-grained' else: logging.debug(Exception('Unsupported xl granularity')) continue try: r1 = xl.asym1.residue(exl.residue1.seq_id) r2 = xl.asym2.residue(exl.residue2.seq_id) except IndexError as e: logging.error('Missing residue') logging.error(e) continue intra_chain = False if xl.asym1.id == xl.asym2.id: intra_chain = True intra_entity = False if r1.asym.entity.description == r2.asym.entity.description: intra_entity = True r_ = { 'chemistry': restr_.linker.auth_name, 'restraint_id': int(xl._id), 'group_id': int(exl._id), 'chain1': xl.asym1.id, 'resnum1': r1.seq_id, 'resnum1_auth': r1.auth_seq_id, 'resname1': r1n, 'name1': a1n, 'chain2': xl.asym2.id, 'resnum2': r2.seq_id, 'resnum2_auth': r2.auth_seq_id, 'resname2': r2n, 'name2': a2n, 'distance_limit': xl.distance.distance, 'distance_lower_limit': xl.distance.distance_lower_limit, 'distance_upper_limit': xl.distance.distance_upper_limit, 'restraint_type': xl.distance.restraint_type, 'psi': xl.psi, 'sigma1': xl.sigma1, 'sigma2': xl.sigma2, 'group_restraint_all': xl.restrain_all, 'entity_name1': r1.asym.entity.description, 'entity_name2': r2.asym.entity.description, 'intra_chain': intra_chain, 'intra_entity': intra_entity, # New property to distinguish restraint types 'restraint_enum': None, # New property to select restraint group/threshold types 'restraint_rtd': None, # Geom properties 'state_group': None, 'state': None, 'model_group': None, 'model_number': None, 'distance_euclidean': None, } allr.append(r_) # Convert data to pandas dataframe allr = pd.DataFrame(allr) return allr
[docs] def get_rtdtype(self, row): rt = row['restraint_type'] lt = None if row['intra_entity']: lt = 'Self-links' else: lt = 'Heteromeric links' if rt == 'upper bound': d = row['distance_upper_limit'] elif rt == 'lower bound': d = row['distance_lower_limit'] elif rt == 'harmonic': d = row['distance_limit'] else: raise ValueError('Wrong restraint type') return (rt, d, lt)
[docs] def get_ertype(self, row): r1n = row['resname1'] a1n = row['name1'] r2n = row['resname2'] a2n = row['name2'] # Sort names of residues and atoms (r1n_, a1n_), (r2n_, a2n_) = sorted( [(r1n, a1n), (r2n, a2n)] ) # Construct the extended restraint type rtype_ = ( row['chemistry'], r1n_, a1n_, r2n_, a2n_, row['restraint_type'], # row['distance_lower_limit'], row['distance_limit'], # row['distance_upper_limit'], ) return rtype_
[docs] def get_ertypes(self): ertypes = {} # enumerate restraints types for index, row in self.raw_restraints.iterrows(): rtype_ = self.get_ertype(row) if rtype_ not in ertypes: ertypes[rtype_] = len(ertypes) return ertypes
[docs] def get_rtdtypes(self): ertypes = {} # enumerate restraints types for index, row in self.raw_restraints.iterrows(): rtype_ = self.get_rtdtype(row) if rtype_ not in ertypes: ertypes[rtype_] = len(ertypes) return ertypes
[docs] def assign_ertypes(self): for index, row in self.raw_restraints.iterrows(): rtype_ = self.get_ertype(row) ertype_ = self.ertypes[rtype_] self.raw_restraints.at[index, 'restraint_enum'] = ertype_
[docs] def assign_rtdtypes(self): for index, row in self.raw_restraints.iterrows(): rtype_ = self.get_rtdtype(row) ertype_ = self.rtdtypes[rtype_] self.raw_restraints.at[index, 'restraint_rtd'] = ertype_
[docs] def get_measured_restraints(self): restraints = [] gistg = 0 gist = 0 gimg = 0 gim = 0 for istg, stg in enumerate(self.system.state_groups): gistg += 1 for ist, st in enumerate(stg): gist += 1 for img, mg in enumerate(st): gimg += 1 for im, m in enumerate(mg): gim += 1 m_ = self.models[gim] logging.info(f'Assessing crosslinking-MS for MODEL {gim}') for index, row in self.raw_restraints.iterrows(): d = self.measure_restraint(m_, row) ndata = { # Store as much information as we can 'distance_euclidean': d, 'model_number': gim, 'model_group': gimg, 'state': gist, 'state_group': gistg, } nrow = row.to_dict() nrow.update(ndata) restraints.append(nrow) restraints = pd.DataFrame(restraints) return restraints
[docs] def measure_restraint(self, model, row): allowed_particle_types = (ihm.model.Atom, ihm.model.Sphere) # Check that we have all necessary atoms rid = row['restraint_id'] gid = row['group_id'] chid = row['chain1'] rid = row['resnum1'] an = row['name1'] a1 = model[chid][rid][an] if not isinstance(a1, allowed_particle_types): a1 = None if a1 is None: logging.warning(f'Restraint {rid}: Atom {chid} {rid} {an} is empty') chid = row['chain2'] rid = row['resnum2'] an = row['name2'] a2 = model[chid][rid][an] if not isinstance(a2, allowed_particle_types): a2 = None if a2 is None: logging.warning(f'Restraint {rid}: Atom {chid} {rid} {an} is empty') if a1 is None or a2 is None: d = None elif row['name1'] == 'coarse-grained' or row['name2'] == 'coarse-grained': if row['name1'] == row['name2'] == 'coarse-grained': # Calculate distance between spheres a1_ = np.array([a1.x, a1.y, a1.z]) r1_ = a1.radius a2_ = np.array([a2.x, a2.y, a2.z]) r2_ = a2.radius # In case spheres overlap d = max(0, np.linalg.norm(a2_ - a1_) - (r1_ + r2_)) else: logging.warning(r'Incompatible crosslinking-MS granularities') d = None else: # Assume atomic distances # Calculate distance a1_ = np.array([a1.x, a1.y, a1.z]) a2_ = np.array([a2.x, a2.y, a2.z]) d = np.linalg.norm(a2_ - a1_) return d
[docs] def quality_check(self, data: pd.DataFrame) -> None: """Check consistency of crosslink restraints""" gids = list(set(data['group_id'])) for gid in gids: data_ = data[data['group_id'] == gid] self.check_conditional_flag(data_)
[docs] def check_conditional_flag(self, data: pd.DataFrame) -> None: """Check consistency of conditional flags in a restraint group""" gid = list(set(data['group_id']))[0] conditional_flags = list(set(data['group_restraint_all'])) if len(conditional_flags) != 1: raise ValueError( f'Mixed conditional flags in crosslink restraint group {gid}' )
[docs] def get_number_of_restraints(self) -> int: return len(self.raw_restraints)
[docs] def get_number_of_restraint_groups(self) -> int: nrg = len(set(self.raw_restraints['group_id'])) return nrg
[docs] def get_cx_data(self) -> (pd.DataFrame, pd.DataFrame): """Extract crosslinking-MS data from mmcif file""" output = (None, None) raw_restraints = self.get_raw_restraints() if len(raw_restraints) > 0: self.quality_check(raw_restraints) self.raw_restraints = raw_restraints self.ertypes = self.get_ertypes() self.assign_ertypes() self.ertypes_df = self.get_ertypes_df() self.rtdtypes = self.get_rtdtypes() self.assign_rtdtypes() self.models = self.get_models() measured_restraints = self.get_measured_restraints() if len(measured_restraints) > 0: # Drop missing restraints count_missing = measured_restraints[ 'distance_euclidean'].isna() logging.debug( f'Dropped {count_missing} crosslink restraints with ' 'empty distances' ) measured_restraints.dropna( subset=['distance_euclidean'], inplace=True) self.measured_restraints = measured_restraints output = (self.ertypes_df, self.measured_restraints) return output
[docs] def get_ertypes_df(self): # Exctract subset of data as restraint types fields_ = [ 'Linker', 'Residue 1', 'Atom 1', 'Residue 2', 'Atom 2', 'Restraint type', # 'Lower_limit, Ã…', 'Distance, Ã…', # 'Upper_limit, Ã…' ] ertypes = pd.DataFrame( self.ertypes.keys(), index=self.ertypes.values(), columns=fields_) ertypes['Count'] = None for index, row in ertypes.iterrows(): n = len( self.raw_restraints[ self.raw_restraints['restraint_enum'] == index] ) ertypes.at[index, 'Count'] = n return ertypes
[docs] def get_ertypes_df_html(self): return self.ertypes_df.to_dict('tight')
[docs] @staticmethod def format_pct_count(pct, count): if pct is None: pct_ = NA pctr_ = NA elif isinstance(pct, float): pct_ = f'{pct:.2f}' pctr_ = f'{100.0 - pct:.2f}' else: raise Exception('Wrong percentage of violated restraints') stats = { 'Satisfied': pct_, 'Violated': pctr_, 'Count': count } return stats
[docs] def get_best_distance_per_restraint(self, data): # Verify, that there is only one type of restraints in the group best = None rtype = data.iloc[0]['restraint_type'] dists = data['distance_euclidean'].to_numpy() threshold = data.iloc[0]['distance_limit'] if rtype == 'upper bound': best = min(dists) elif rtype == 'lower bound': diff = (dists - threshold) >= 0 if diff.any(): best = min(dists[diff]) else: best = max(dists) elif rtype == 'harmonic': best = dists[np.argmin(np.abs(dists - threshold))] return best
[docs] def get_best_distances_per_model_group(self): rtd_groups = {} rids = list(set(self.measured_restraints['restraint_id'])) gistg = 0 gist = 0 gimg = 0 for istg, stg in enumerate(self.system.state_groups): gistg += 1 for ist, st in enumerate(stg): gist += 1 for img, mg in enumerate(st): gimg += 1 rtd_groups[gimg] = {} for rid in rids: data_ = self.measured_restraints[ (self.measured_restraints['restraint_id'] == rid) & (self.measured_restraints['model_group'] == gimg)] if len(data_) == 0: continue rtdtype = self.get_rtdtype(data_.iloc[0]) d = self.get_best_distance_per_restraint(data_) if rtdtype not in rtd_groups[gimg]: rtd_groups[gimg][rtdtype] = [] rtd_groups[gimg][rtdtype].append(d) rtd_groups[gimg] = dict( sorted( rtd_groups[gimg].items(), key=lambda x: f'{x[0][2]}_{x[0][0]}_{x[0][1]}' ) ) return rtd_groups
[docs] def is_restraint_group_satisfied(self, data): # Verify, that there is only one type of restraints in the group satisfied_restraints = np.zeros(len(data), dtype=bool) conditional_flag_all = list(set(data['group_restraint_all']))[0] for i, (index, row) in enumerate(data.iterrows()): rtype = row['restraint_type'] ed = row['distance_euclidean'] threshold = row['distance_limit'] if rtype == 'upper bound': cmp = (ed - threshold) <= 0 elif rtype == 'lower bound': cmp = (ed - threshold) >= 0 # Check with Ben elif rtype == 'harmonic': atol = 1e-08 tol_keys = ['sigma1', 'sigma2'] for k in tol_keys: if row[k] is not None: atol += row[k] cmp = np.isclose( ed, threshold, atol=atol ) satisfied_restraints[i] = cmp if conditional_flag_all: satisfied = satisfied_restraints.all() else: satisfied = satisfied_restraints.any() return satisfied
[docs] def get_restraint_group_chain_type(self, data): rg_type = None if data['intra_chain'].all(): rg_type = 'Intramolecular' elif (~data['intra_chain']).all(): rg_type = 'Intermolecular' elif data['intra_chain'].any() and (~data['intra_chain']).any(): rg_type = 'Ambiguous' return rg_type
[docs] def get_restraint_group_entity_type(self, data): rg_type = None if data['intra_entity'].all(): rg_type = 'Self-links' elif (~data['intra_entity']).all(): rg_type = 'Heteromeric links' elif data['intra_entity'].any() and (~data['intra_entity']).any(): rg_type = 'Ambiguous entity' return rg_type
[docs] def process_restraint_groups(self, data, mode='entity'): rgs = set(data['group_id']) stats = {'All': {'satisfied': 0, 'total': 0}} for rgs_ in rgs: good_ = 0 rg_chain_type = None rg_entity_type = None data_ = data[data['group_id'] == rgs_] rg_chain_type = self.get_restraint_group_chain_type(data_) rg_entity_type = self.get_restraint_group_entity_type(data_) assert rg_chain_type is not None assert rg_entity_type is not None rg_type = f'{rg_entity_type}/{rg_chain_type}' # Check per model models = list(set(data_['model_number'])) for model in models: data__ = data_[data_['model_number'] == model] if self.is_restraint_group_satisfied(data__): good__ = 1 else: good__ = 0 good_ = good_ or good__ if rg_type not in stats: stats[rg_type] = {'satisfied': 0, 'total': 0} stats[rg_type]['satisfied'] += good_ stats[rg_type]['total'] += 1 stats['All']['satisfied'] += good_ stats['All']['total'] += 1 return stats
[docs] def get_stats_per_model_group(self): stats = {} gistg = 0 gist = 0 gimg = 0 # gim = 0 for istg, stg in enumerate(self.system.state_groups): gistg += 1 stats[gistg] = {} for ist, st in enumerate(stg): gist += 1 stats[gistg][gist] = {} for img, mg in enumerate(st): gimg += 1 stats[gistg][gist][gimg] = { 'cx_stats': None, 'ens_stats': None} cx_stats = {} data_ = self.measured_restraints[ self.measured_restraints['model_group'] == gimg] stats_ = self.process_restraint_groups(data_) for k, v in stats_.items(): pct = None count = v['total'] if count > 0: pct = v['satisfied'] / v['total'] * 100.0 cx_stats[k] = self.format_pct_count(pct, count) stats[gistg][gist][gimg]['cx_stats'] = cx_stats ens_stats = {} found_ensemble = False for i, e in enumerate(self.system.ensembles, 1): if e.model_group == mg: found_ensemble = True break if found_ensemble: try: ens_stats['ensemble_id'] = i ens_stats['num_models_deposited'] = \ e.num_models_deposited ens_stats['num_models'] = e.num_models except AttributeError: logging.error(f'Ens: {i} | Missing model_group') except TypeError: ens_stats['ensemble_id'] = i # Wait for fix in the python-ihm ens_stats['num_models_deposited'] = len(mg) ens_stats['num_models'] = e.num_models else: ens_stats['ensemble_id'] = NA ens_stats['num_models_deposited'] = len(mg) ens_stats['num_models'] = len(mg) stats[gistg][gist][gimg]['ens_stats'] = ens_stats return stats
[docs] def get_per_model_satifaction_rates(self) -> list: out = [] gistg = 0 gist = 0 gimg = 0 gim = 0 tabs_ = [] for istg, stg in enumerate(self.system.state_groups): gistg += 1 for ist, st in enumerate(stg): gist += 1 for img, mg in enumerate(st): gimg += 1 for im, m in enumerate(mg): gim += 1 data_ = self.measured_restraints[ self.measured_restraints['model_number'] == gim] stats_ = self.process_restraint_groups(data_) for k, v in stats_.items(): pct = None count = v['total'] if count > 0: pct = v['satisfied'] / v['total'] * 100.0 # out_stats[k].append(pct) r = { 'model_number': gim, 'Category': k, 'Satisfaction': pct } if k == 'All': out.append(pct) return(out)
[docs] def plot_satisfaction_per_ensemble(self, imgDirname='.'): def scatter_plot(stats): xmin, xmax = -2, 102 jitter = None if len(stats['model_number'].unique()) > 1: jitter = 'jitter' p = iqplot.stripbox( data=stats, q='Satisfaction', cats='Category', spread=jitter, frame_width=350, frame_height=max(100, len(set(stats['Category'])) * 50), marker_kwargs=dict(alpha=0.5, size=7) ) p.output_backend = "svg" p.x_range = Range1d(xmin, xmax) p.xaxis.axis_label = 'Satisfaction rate [%]' p.xaxis.axis_label_text_font_size = "14pt" p.yaxis.axis_label_text_font_size = "14pt" p.xaxis.major_label_text_font_size = "14pt" p.yaxis.major_label_text_font_size = "14pt" p.xaxis.axis_label_text_font_style = 'normal' p.yaxis.axis_label_text_font_style = 'normal' p.title.text_font_size = "14pt" p.yaxis.major_label_text_align = 'right' p.yaxis.group_text_align = 'right' p.yaxis.subgroup_text_align = 'right' p.min_border_bottom = 75 return p gistg = 0 gist = 0 gimg = 0 gim = 0 tabs_ = [] for istg, stg in enumerate(self.system.state_groups): gistg += 1 for ist, st in enumerate(stg): gist += 1 for img, mg in enumerate(st): gimg += 1 out_stats_ = [] for im, m in enumerate(mg): gim += 1 data_ = self.measured_restraints[ self.measured_restraints['model_number'] == gim] stats_ = self.process_restraint_groups(data_) for k, v in stats_.items(): pct = None count = v['total'] if count > 0: pct = v['satisfied'] / v['total'] * 100.0 # out_stats[k].append(pct) r = { 'model_number': gim, 'Category': k, 'Satisfaction': pct } out_stats_.append(r) if len(out_stats_) == 0: continue out_stats = pd.DataFrame(out_stats_) p = scatter_plot(out_stats) title = f'Satisfaction rates in Model Group {gimg}' p.title.text = title col = gridplot( [p], ncols=1, toolbar_location='right', # sizing_mode='scale_width' ) tab = Panel(child=col, title=f'Model Group {gimg}') tabs_.append(tab) tabs = Tabs(tabs=tabs_) title = 'cx_ensemble_satisfaction' return self.save_plots(tabs, title, imgDirname)
[docs] def plot_distograms_per_model_group(self, imgDirname='.'): """plot all restraints in the dataset""" data = self.get_best_distances_per_model_group() tabs_ = [] for gimg in data.keys(): plots = [] if len(data[gimg]) == 0: continue for (rt, d, lt), dists in data[gimg].items(): data_ = pd.DataFrame(dists, columns=['Crosslinks']) xmax = 1 bins = [0, 1] if len(dists) > 0: xmax = int(np.ceil(max(dists))) bins = np.linspace(0, xmax, xmax + 1) elif len(dists) > 10: bins = 'freedman-diaconis' p = iqplot.histogram( data=data_, q='Crosslinks', density=False, bins=bins, style="step_filled", frame_width=500, frame_height=100, # sizing_mode='scale_width', ) p.yaxis.ticker.desired_num_ticks = 3 p.output_backend = "svg" title = f"Model Group {gimg}; {lt}: {rt}, {d:.1f} Ã…" p.title.text_font_size = "14pt" p.xaxis.axis_label_text_font_size = "14pt" p.yaxis.axis_label_text_font_size = "14pt" p.xaxis.major_label_text_font_size = "14pt" p.yaxis.major_label_text_font_size = "14pt" p.xaxis.axis_label_text_font_style = 'normal' p.yaxis.axis_label_text_font_style = 'normal' p.yaxis.major_label_text_align = 'right' p.ray( x=d, y=0, line_color='black', angle=np.pi / 2, line_width=2 ) p.xaxis.axis_label = 'Euclidean distance [Ã…]' p.yaxis.axis_label = 'Count' p.title.text = title p.min_border_bottom = 75 plots.append(p) col = gridplot( plots, ncols=1, toolbar_location='right', # sizing_mode='scale_width' ) tab = Panel( child=col, title=f'Model Group {gimg}',) tabs_.append(tab) tabs = Tabs(tabs=tabs_, # sizing_mode='scale_width' ) title = 'cx_distograms' return self.save_plots(tabs, title, imgDirname)
[docs] def save_plots(self, plot, title, imgDirname='.'): stem = f'{self.ID_f}_{title}' imgpath = Path( imgDirname, f'{stem}.html') save( plot, imgpath, resources=CDN, title=title, ) imgpath_json = Path( imgDirname, f'{stem}.json') with open(imgpath_json, 'w') as f: json.dump(json_item(plot, title), f) imgpath_svg = Path( imgDirname, f'{stem}.svg') svgs = export_svgs(plot, filename=imgpath_svg, webdriver=self.driver, timeout=15) svgs = [Path(x).name for x in svgs] return (imgpath, imgpath_json, svgs)
[docs] @staticmethod def request_pride(url: str) -> dict: ''' pull data from PRIDE using crosslinking PDB-IHM API ''' result = None r = requests.get(url) if not r.ok: # Wait until cold request completes and go to DB cache logging.info(f'Retrying pulling {url} from PRIDE') time.sleep(60) r = requests.get(url) if r.ok: try: result = r.json() except JSONDecodeError: pass return result
[docs] def get_sequences_pride(self, pid: str) -> dict: '''get sequences from PRIDE entry''' result = None url = f"https://www.ebi.ac.uk/pride/ws/archive/crosslinking/v2/pdbdev/projects/{pid}/sequences" result = self.request_pride(url) return result
[docs] def get_residue_pairs_pride(self, pid: str, page_size: int = 99) -> dict: '''get sequences from PRIDE entry''' url = f"https://www.ebi.ac.uk/pride/ws/archive/crosslinking/v2/pdbdev/projects/{pid}/residue-pairs/based-on-reported-psm-level/passing" page = 1 url_ = f"{url}?page={page}&page_size={page_size}" result = self.request_pride(url_) rps = [] if result is not None and 'page' in result: session = requests.Session() max_page = int(result['page']["total_pages"]) page_size = int(result['page']["page_size"]) rps = [] for i in range(1, max_page + 1): url_ = f"{url}?page={i}&page_size={page_size}" rps_ = session.get(url_).json()['data'] rps.extend(rps_) return rps
[docs] def get_pride_data(self, code): ''' get data from PRIDE ''' cache_fn = Path(self.cache, f'{code}.pkl') data = None # Check if we already requested the data if Path(cache_fn).is_file(): logging.info(f'Found {code} in cache! {cache_fn}') with open(cache_fn, 'rb') as f: data = pickle.load(f) elif not Path(cache_fn).is_file(): ms_seqs = self.get_sequences_pride(code) ms_res_pairs = self.get_residue_pairs_pride(code) if ms_seqs is not None and len(ms_res_pairs) > 0: data = { 'pride_id': code, 'sequences': ms_seqs, 'residue_pairs': ms_res_pairs } with open(cache_fn, 'wb') as f: pickle.dump(data, f) else: logging.info(f'PRIDE data for {code} is incomplete') return data
[docs] def get_pride_ids(self) -> list: ''' get a list of all PRIDE ids from entry ''' ids = [] for i, d in enumerate(self.system.orphan_datasets): if isinstance(d, ihm.dataset.CXMSDataset): if isinstance(d.location, ihm.location.PRIDELocation) or \ isinstance(d.location, ihm.location.ProteomeXchangeLocation): try: pid = d.location.access_code ids.append(pid) except AttributeError: pass # Try to automatically convert jPOST ids to PRIDE ids if isinstance(d.location, ihm.location.JPOSTLocation): try: pid_ = d.location.access_code r = requests.get(f'https://repository.jpostdb.org/xml/{pid_}.0.xml') xml = ET.fromstring(r.content) pid = xml.find('Project').attrib['pxid'] ids.append(pid) logging.info(f'Found PRIDE ID {pid} for JPOST ID {pid_}') # blanket catch because there are too many # potential network exceptions except Exception as e: logging.error(e) pass return ids
[docs] def validate_pride_data(self, data: dict) -> tuple : """Match sequences, residues pairs and return stats""" out = (None, None, None) # Unpack pride data pid = data['pride_id'] ms_res_pairs = data['residue_pairs'] ms_seqs_ = data['sequences'] if len(ms_seqs_) == 0 or len(ms_res_pairs) == 0: return out # Get sequences from mmcif entry mmcif_seqs = {} mmcif_seqs_descriptions = {} for e in self.system.entities: if e.is_polymeric: seq = ''.join([x.code_canonical for x in e.sequence]) desc = e.description seq_ = pyhmmer.easel.TextSequence( sequence=seq, name=e._id.encode('utf-8'), description=desc.encode('utf-8') ).digitize(pyhmmer.easel.Alphabet.amino()) mmcif_seqs[e._id] = seq_ mmcif_seqs_descriptions[e._id] = desc # Get sequences from crosslinking-MS data ms_seqs = [ pyhmmer.easel.TextSequence( sequence=x['sequence'], name=x['id'].encode('utf-8'), source=x['file'].encode('utf-8') ).digitize(pyhmmer.easel.Alphabet.amino()) for x in ms_seqs_ ] # Match sequences using pyHMMER # select only 1st best match matched_seqs = {} matched_seqs_mapping = {} matched_seqs_ids = {} for k, v in mmcif_seqs.items(): matches_ = list(pyhmmer.hmmer.phmmer(v, ms_seqs))[0] if len(matches_) > 0: best_hit = list(pyhmmer.hmmer.phmmer(v, ms_seqs))[0][0] matched_seqs[k] = best_hit mapping_, _ = self.pyhmmer_alignment_to_map(best_hit) matched_seqs_mapping[k] = mapping_ matched_seqs_ids[k] = best_hit.best_domain.hit.name.decode("utf-8") else: logging.warning(f"Couldn't match mmCIF entity {k} to any entities in {pid}") matched_seqs[k] = None matched_seqs_ids[k] = None matched_mmcif_entities = list(matched_seqs_ids.keys()) matched_ms_seqs = list(matched_seqs_ids.values()) # Filter residue pairs from MS data ms_rps_filtered = [] for r in ms_res_pairs: eid1 = r['prot1'] rid1 = r['pos1'] eid2 = r['prot2'] rid2 = r['pos2'] sxl = tuple(sorted(((eid1, rid1), (eid2, rid2)))) ms_rps_filtered.append(sxl) ms_rps_filtered = set(ms_rps_filtered) # Select MS residue pairs from matched sequences ms_rps_mmcif_entities = 0 sxls = [] for r in ms_rps_filtered: (eid1, rid1), (eid2, rid2) = r if eid1 in matched_ms_seqs and eid2 in matched_ms_seqs: sxls.append(r) sxls = set(sxls) ms_rps_mmcif_entities = len(sxls) # Select residue pairs from the entry exls = [] for restr_ in self.system.restraints: # We are interested only in Chemical crosslinks if type(restr_) != ihm.restraint.CrossLinkRestraint: continue # Iterate over all crosslinks in the dataset for xl in restr_.cross_links: eid1 = xl.experimental_cross_link.residue1.entity._id rid1 = xl.experimental_cross_link.residue1.seq_id eid2 = xl.experimental_cross_link.residue2.entity._id rid2 = xl.experimental_cross_link.residue2.seq_id exl = tuple(sorted(((eid1, rid1), (eid2, rid2)))) exls.append(exl) exls = list(set(exls)) # Find corresponding entry - MS data crosslinks mmcif_rps_ms_entities = 0 rps_mapping = [] emxls = [] for rps in exls: (eid1, rid1), (eid2, rid2) = rps if eid1 in matched_mmcif_entities and eid2 in matched_mmcif_entities: mmcif_rps_ms_entities += 1 eid1_ = matched_seqs_ids[eid1] eid2_ = matched_seqs_ids[eid2] try: rid1_ = matched_seqs_mapping[eid1][rid1] except KeyError: logging.debug(f"Can't map residue {eid1} {rid1} to {eid1_}") continue try: rid2_ = matched_seqs_mapping[eid2][rid2] except KeyError: logging.debug(f"Can't map residue {eid2} {rid2} to {eid2_}") continue if eid1_ is None or eid2_ is None: continue exl = tuple(sorted(((eid1_, rid1_), (eid2_, rid2_)))) if exl in sxls: # Good matching crosslinks rps_mapping.append((rps, exl)) else: # Residue pairs unique to the entry rps_mapping.append((rps, None)) emxls.append(exl) emxls = set(emxls) # Add non-mapped residue pairs from MS data for exl in list(sxls): if exl not in emxls: rps_mapping.append((None, exl)) # Calculate some stats rps_both = len(set(emxls) & set(sxls)) mmcif_rps = len(exls) ms_rps = len(ms_rps_filtered) # Prepare output out = { 'pride_id': pid, 'entities_ms': len(ms_seqs), 'entities': len(mmcif_seqs), 'matches': [], 'stats': { 'entry': { 'total': mmcif_rps, 'mapped_entities': mmcif_rps_ms_entities, 'mapped_entities_pct': mmcif_rps_ms_entities / mmcif_rps * 100., 'matched': rps_both, 'matched_pct': rps_both / mmcif_rps * 100., }, 'ms': { 'total': ms_rps, 'mapped_entities': ms_rps_mmcif_entities, 'mapped_entities_pct': ms_rps_mmcif_entities / ms_rps * 100., 'matched': rps_both, 'matched_pct': rps_both / ms_rps * 100., } } } # Add stats about matches for k, v in matched_seqs.items(): if v is not None: match_ = { 'entity': k, 'entity_desc': mmcif_seqs_descriptions[k], 'entity_ms': v.best_domain.hit.name.decode("utf-8"), 'e-value': v.best_domain.c_evalue, 'exact_match': v.best_domain.alignment.target_sequence == v.best_domain.alignment.hmm_sequence.upper(), } else: match_ = { 'entity': k, 'entity_desc': mmcif_seqs_descriptions[k], 'entity_ms': utility.NA, 'e-value': utility.NA, 'exact_match': utility.NA, } out['matches'].append(match_) return (out, matched_seqs, rps_mapping)
[docs] def validate_all_pride_data(self) -> list: '''perform data quality validation for all crosslinking-MS datasets''' codes = self.get_pride_ids() outs = [] for code in codes: data = self.get_pride_data(code) if data is not None: out, _, __ = self.validate_pride_data(data) if out is not None: outs.append(out) return outs
[docs] @staticmethod def get_pyhmmer_version(): """return pyhmmer version""" return pyhmmer.__version__
[docs] @staticmethod def pyhmmer_alignment_to_map(hit) -> (dict, list): """Convert HMMER alignment into residue map""" # gaps in HMMER text format GAPS = set(['-', '.']) mapping_raw = [] mapping_short = {} aln = hit.best_domain.alignment mmcif_start = aln.hmm_from mmcif_seq = aln.hmm_sequence db_start = aln.target_from db_seq = aln.target_sequence ii = mmcif_start - 1 ij = db_start - 1 # iterate over alignment # residue indices start from 1 for i, (aai, aaj) in enumerate(zip(mmcif_seq.upper(), db_seq.upper())): if aai not in GAPS: ii += 1 if aaj not in GAPS: ij += 1 if len(set([aai, aaj]) & GAPS) == 0: mapping_raw.append(((ii, aai), (ij, aaj))) mapping_short[ii] = ij # return dict to map residue indices and raw mapping data return mapping_short, mapping_raw