Source code for ihm_validation.get_plots

# -*- coding: utf-8 -*-
#
# get_plots.py - Generate overview plots
#
# Copyright (C) 2019-2025 Arthur Zalevsky, Sai Ganesan, Benjamin M. Webb, Brinda Vallat
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
Generate overview plots
"""

import os
from pathlib import Path
import utility
from mmcif_io import GetInputInformation
import bokeh
import numpy as np
from bokeh.io import output_file, curdoc, export_svg, show
from bokeh.models import (ColumnDataSource, Legend, LegendItem, FactorRange,
                          Div, BasicTickFormatter)
from bokeh.palettes import linear_palette, Greys256, Blues256, Oranges256, Greens256
from bokeh.plotting import figure, save
from bokeh.models.widgets import Tabs, Panel
from bokeh.layouts import row
from bokeh.core.validation import silence
from bokeh.core.validation.warnings import MISSING_RENDERERS, EMPTY_LAYOUT
from bokeh.transform import factor_cmap
from bokeh.layouts import gridplot, column
silence(MISSING_RENDERERS, True)
silence(EMPTY_LAYOUT, True)

import json
from bokeh.embed import json_item

MAXPLOTS = 256

class Plots(GetInputInformation):
    """Build and save the "quality at a glance" overview plots.

    Plots are grouped in three panels: model quality (MQ), data quality (DQ)
    and fit-to-data quality (FQ).  Every panel that has at least one plot is
    written to ``imageDirName`` as an SVG image and a Bokeh JSON item.
    """

    def __init__(self, *args, imageDirName, **kwargs):
        """Initialise the parent class and remember the output directory.

        Args:
            *args: forwarded unchanged to ``GetInputInformation``.
            imageDirName: directory where plot files are written
                (keyword-only).
            **kwargs: forwarded unchanged to ``GetInputInformation``.
        """
        super().__init__(*args, **kwargs)
        self.imageDirName = imageDirName

    def plot_quality_at_glance(self,
                               molprobity_data: dict = None,
                               exv_data: dict = None,
                               sas_data_quality: dict = None,
                               sas_fit: dict = None,
                               cx_data_quality: dict = None,
                               cx_fit: dict = None,
                               em_data_quality: dict = None,
                               em_fit: dict = None
                               ) -> dict:
        """Create and save the overview plots for all available assessments.

        Every argument is optional; ``None`` or an empty container means that
        the corresponding assessment is not available and is skipped.

        Args:
            molprobity_data: mapping ``model id -> scores`` where each value
                provides the keys ``'Sidechain outliers'``,
                ``'Ramachandran outliers'`` and ``'Clashscore'``.
            exv_data: mapping with the columns ``'Model ID'`` and
                ``'Excluded Volume Satisfaction (%)'``.
            sas_data_quality: mapping ``dataset id (str) -> Rg values``; the
                first value is labelled ``P(r)``, the second ``Guinier``.
            sas_fit: mapping ``dataset id (str) -> chi-squared values``.
            cx_data_quality: iterable of mappings providing ``'pride_id'``
                and ``['stats']['ms']`` with the keys ``total``,
                ``mapped_entities``, ``mapped_entities_pct``, ``matched``
                and ``matched_pct``.
            cx_fit: three-level nested mapping whose innermost values provide
                ``['cx_stats']['All']['Satisfied']``.
            em_data_quality: iterable of mappings providing ``'emdbid'`` and
                ``['data_stats']['resolution']``.
            em_fit: iterable of mappings providing ``'emdbid'`` and
                ``'fit_stats'`` (``model id -> {'q_score': {'average': ..}}``).

        Returns:
            dict with the boolean flags ``'MQ'``, ``'DQ'`` and ``'FQ'``; a
            flag is ``True`` when the corresponding panel was saved.  If the
            excluded volume values cannot be converted to numbers, plotting
            stops and all flags are returned as ``False``.
        """
        glance_plots = {
            'MQ': False,
            'DQ': False,
            'FQ': False,
        }

        # MODEL QUALITY
        mq_plots = []
        if molprobity_data is not None:
            # every model has clashscore, rama outliers, and rota outliers
            mq_plots.append(self._molprobity_panel(molprobity_data))

        if exv_data:
            exv_panel = self._exv_panel(exv_data)
            if exv_panel is None:
                # Data is not plot-able. Stop here, as before, but honour the
                # declared return type instead of returning None.
                return glance_plots
            mq_plots.append(exv_panel)

        if mq_plots:
            grid = gridplot(mq_plots, ncols=1,
                            merge_tools=True,
                            toolbar_location="above")
            self.save_plots(grid, 'quality_at_glance_MQ')
            glance_plots['MQ'] = True

        # DATA QUALITY
        # this section will be updated with more data assessments,
        # as and when it is complete
        dq_plots = []

        # Truthiness check: the default value of the argument is None
        if sas_data_quality:
            # Don't forget to update palette if adding new metric
            rg_methods = {0: 'P(r)', 1: 'Guinier'}
            rg_palette = linear_palette(Blues256, len(rg_methods) + 2)[1:-1]
            labels, values, colors = [], [], []
            for dataset, rg_values in sas_data_quality.items():
                for idx, rg in enumerate(rg_values):
                    labels.append(rg_methods[idx] + ' (' + dataset + ')')
                    values.append(float(rg))
                    colors.append(rg_palette[idx])
            # max() below needs at least one value
            if values:
                dq_plots.append(self._summary_hbar(
                    labels, values,
                    [str(value) + ' nm' for value in values],
                    colors,
                    x_range=(0, max(values) + 1),
                    plot_height=90 + len(values) * 20,
                    title="Data Quality for SAS: Rg Analysis",
                    x_label='Distance [nm]'))

        # If crosslinking-MS data is available
        if cx_data_quality:
            dq_plots.append(self._cx_data_quality_panel(cx_data_quality))

        if em_data_quality:
            labels, values = [], []
            for dataset in em_data_quality:
                emdbid = dataset['emdbid']
                try:
                    resolution = float(dataset['data_stats']['resolution'])
                except (ValueError, TypeError):
                    # skip datasets without a numeric resolution
                    continue
                labels.append(emdbid)
                values.append(resolution)
            if values:
                dq_plots.append(self._summary_hbar(
                    labels, values,
                    [f'{value:.2f} Å' for value in values],
                    linear_palette(Greens256, len(values) + 2)[1:-1],
                    x_range=(0, 80),
                    plot_height=95 + len(values) * 20,
                    title="3DEM resolution",
                    x_label='Resolution [Å]'))

        if dq_plots:
            grid = gridplot(dq_plots, ncols=1,
                            merge_tools=True,
                            toolbar_location="above")
            self.save_plots(grid, 'quality_at_glance_DQ')
            glance_plots['DQ'] = True

        # FIT TO DATA QUALITY
        # this section will be updated with more data assessments,
        # as and when it is complete
        fq_plots = []

        # Truthiness check: the default value of the argument is None
        if sas_fit:
            labels, values = [], []
            for dataset, fits in sas_fit.items():
                for idx, chi2 in enumerate(fits):
                    labels.append(' \u03C7\u00b2 Fit ' + str(int(idx + 1))
                                  + ' (' + dataset + ')')
                    values.append(float(chi2))
            # max() below needs at least one value
            if values:
                fq_plots.append(self._summary_hbar(
                    labels, values,
                    [str(value) for value in values],
                    linear_palette(Blues256, len(values) + 2)[1:-1],
                    x_range=(0, max(values) + 1),
                    plot_height=100 + len(values) * 20,
                    title="Fit to SAS Data: \u03C7\u00b2 Fit",
                    x_label='Fit value'))

        if cx_fit:
            labels, values = [], []
            group_index = 0
            for level1 in cx_fit.values():
                for level2 in level1.values():
                    for group in level2.values():
                        # skipped groups still consume a number, so labels
                        # stay aligned with the position in the input
                        group_index += 1
                        try:
                            satisfied = float(
                                group['cx_stats']['All']['Satisfied'])
                        except (ValueError, TypeError):
                            continue
                        labels.append(f'Model group/Ensemble {group_index}')
                        values.append(satisfied)
            if values:
                fq_plots.append(self._summary_hbar(
                    labels, values,
                    [f'{value} %' for value in values],
                    linear_palette(Oranges256, len(values) + 2)[1:-1],
                    x_range=(0, 102),
                    plot_height=95 + len(values) * 20,
                    title="Crosslink satisfaction",
                    x_label='Satisfaction rate [%]'))

        if em_fit:
            labels, values = [], []
            for dataset in em_fit:
                emdbid = dataset['emdbid']
                for mid, stats in dataset['fit_stats'].items():
                    try:
                        q_score = float(stats['q_score']['average'])
                    except (ValueError, TypeError):
                        continue
                    labels.append(f'Model {mid}/{emdbid}')
                    values.append(q_score)
            if values:
                fq_plots.append(self._summary_hbar(
                    labels, values,
                    [f'{value:.3f}' for value in values],
                    linear_palette(Greens256, len(values) + 2)[1:-1],
                    x_range=(-1, 1),
                    plot_height=95 + len(values) * 20,
                    title="Q-score",
                    x_label='Q-score'))

        if fq_plots:
            grid = gridplot(fq_plots, ncols=1,
                            toolbar_location="above")
            self.save_plots(grid, 'quality_at_glance_FQ')
            glance_plots['FQ'] = True

        return glance_plots

    @staticmethod
    def _style_axis_text(fig):
        """Apply the common 14pt / non-italic text style to both axes."""
        for axis in (fig.xaxis, fig.yaxis):
            axis.axis_label_text_font_size = "14pt"
            axis.major_label_text_font_size = "14pt"
            axis.axis_label_text_font_style = 'normal'

    @staticmethod
    def _titled_scroll_panel(heading, plots):
        """Stack per-model plots in a fixed-size grid under a bold heading."""
        grid = gridplot(plots, ncols=1,
                        merge_tools=True,
                        toolbar_location='right')
        # NOTE(review): which child of the layout is the plot grid depends on
        # toolbar_location and on the bokeh version - confirm that index 1 is
        # the intended element.
        scroll_box = grid.children[1]
        scroll_box.css_classes = ['scrollable']
        scroll_box.sizing_mode = 'fixed'
        scroll_box.height = 450
        scroll_box.width = 800
        title = Div(text=f"<p>{heading}</p>",
                    style={"font-size": "1.5em",
                           "font-weight": "bold",
                           "text-align": "center",
                           "width": '100%'},
                    width=800)
        return column(title, grid)

    @staticmethod
    def _summary_hbar(labels, values, legends, colors, x_range, plot_height,
                      title, x_label):
        """Draw one horizontal bar chart with one bar and legend entry per label."""
        source = ColumnDataSource(data=dict(
            Scores=labels, counts=values, legends=legends, color=colors))
        fig = figure(y_range=labels, x_range=x_range,
                     plot_height=plot_height, plot_width=800, title=title)
        bars = fig.hbar(y='Scores', right='counts', color='color',
                        height=1.0, source=source, line_color='white')
        fig.ygrid.grid_line_color = None
        fig.xaxis.axis_label = x_label
        legend = Legend(
            items=[LegendItem(label=text, renderers=[bars], index=idx)
                   for idx, text in enumerate(legends)],
            location='center', orientation='vertical',
            label_text_font_size="14pt")
        fig.add_layout(legend, 'right')
        # bars are drawn bottom-up, the legend is listed top-down
        fig.legend.items.reverse()
        fig.legend.border_line_width = 0
        fig.legend.label_text_font_size = "14pt"
        fig.title.text_font_size = '14pt'
        fig.title.vertical_align = 'top'
        fig.title.align = "center"
        Plots._style_axis_text(fig)
        fig.output_backend = "svg"
        return fig

    def _molprobity_panel(self, molprobity_data):
        """Plot MolProbity outliers per model, export each plot as SVG, return the panel."""
        models = list(molprobity_data.keys())
        metrics = [
            'Sidechain outliers',
            'Ramachandran outliers',
            'Clashscore',
        ]
        n_metrics = len(metrics)
        factors = [(f"Model {model}", metric)
                   for model in models for metric in metrics]
        # flat tuple, ordered like ``factors``
        values = tuple(scores[metric]
                       for scores in molprobity_data.values()
                       for metric in metrics)
        palette = linear_palette(Greys256, n_metrics + 2)[1:-1]
        # a common upper limit keeps the bars of all models comparable
        _, upper = utility.calc_optimal_range(values)

        plots = []
        for idx, model in enumerate(models):
            chunk = slice(idx * n_metrics, (idx + 1) * n_metrics)
            fig = figure(
                y_range=FactorRange(*factors[chunk]),
                # Force left limit at zero
                x_range=(0, upper),
                plot_height=120,
                plot_width=700,
            )
            fig.hbar(y=factors[chunk],
                     right=values[chunk],
                     line_color="white",
                     # colour by the second level of the (model, metric) factor
                     fill_color=factor_cmap('y', palette=palette,
                                            factors=metrics, start=1, end=2))
            self._style_axis_text(fig)
            fig.xaxis.axis_label = 'Outliers'
            # draw the group (model) label like the metric labels
            y_axis = fig.left[0]
            y_axis.group_text_color = y_axis.major_label_text_color
            y_axis.group_text_font_style = y_axis.major_label_text_font_style
            y_axis.group_text_font_size = y_axis.major_label_text_font_size
            y_axis.group_label_orientation = 'horizontal'
            fig.title.vertical_align = 'top'
            fig.title.align = "center"
            fig.output_backend = "svg"
            fname = Path(self.imageDirName,
                         f"{self.ID_f}_{model}_quality_at_glance_MQ_mp.svg")
            export_svg(fig, filename=fname)
            plots.append(fig)

        return self._titled_scroll_panel(
            'Model Quality: Molprobity Analysis', plots)

    def _exv_panel(self, exv_data):
        """Plot excluded volume satisfaction per model; return None if values are not numeric."""
        numplots = min(len(exv_data['Model ID']), MAXPLOTS)
        models = exv_data['Model ID'][:numplots]
        satisfaction = exv_data['Excluded Volume Satisfaction (%)'][:numplots]

        # make sure data is plot-able
        try:
            values = [float(item) for item in satisfaction]
        except (ValueError, TypeError):
            return None

        labels = [f'Model {model}' for model in models]
        # format the converted numbers: the raw values may be strings
        legends = [f'{value:.2f} %' for value in values]
        # match with darkest molprobity color
        colors = [linear_palette(Greys256, 5)[1]] * len(values)

        plots = []
        for idx, model in enumerate(models):
            one = slice(idx, idx + 1)
            fig = figure(y_range=FactorRange(factors=labels[one]),
                         x_range=(0, 102),
                         plot_height=90, plot_width=700)
            fig.xaxis.ticker.desired_num_ticks = 3
            bar = fig.hbar(y=labels[one], right=values[one],
                           color=colors[one], height=1.0,
                           line_color='white')
            fig.xaxis.axis_label = 'Satisfaction rate [%]'
            legend = Legend(
                items=[LegendItem(label=text, renderers=[bar], index=j)
                       for j, text in enumerate(legends[one])],
                location='center',
                label_text_font_size='14pt',
                orientation='vertical')
            fig.add_layout(legend, 'right')
            fig.legend.border_line_width = 0
            fig.xaxis.major_label_text_font_size = "14pt"
            fig.yaxis.major_label_text_font_size = "14pt"
            fig.xaxis.axis_label_text_font_style = 'normal'
            fig.yaxis.axis_label_text_font_style = 'normal'
            fig.title.vertical_align = 'top'
            fig.title.align = "center"
            fig.output_backend = "svg"
            fig.min_border_top = 20
            fname = Path(self.imageDirName,
                         f"{self.ID_f}_{model}_quality_at_glance_MQ_exv.svg")
            export_svg(fig, filename=fname)
            plots.append(fig)

        return self._titled_scroll_panel(
            'Model Quality: Excluded Volume Analysis', plots)

    def _cx_data_quality_panel(self, cx_data_quality):
        """Plot total / mapped / matched residue pairs for every crosslinking-MS dataset."""
        # Don't forget to update palette if adding new metric
        metrics = ['Total', 'Mapped to matching entities', 'Matched']
        n_metrics = len(metrics)

        datasets = []
        legends = []
        values = []
        for entry in cx_data_quality:
            ms = entry['stats']['ms']
            datasets.append(entry['pride_id'])
            legends.append(f'{ms["total"]}')
            legends.append(
                f'{ms["mapped_entities"]} ({ms["mapped_entities_pct"]:.2f}%)')
            legends.append(f'{ms["matched"]} ({ms["matched_pct"]:.2f}%)')
            values.extend(
                (ms['total'], ms['mapped_entities'], ms['matched']))
        # flat tuple, ordered like ``factors``
        values = tuple(values)
        factors = [(dataset, metric)
                   for dataset in datasets for metric in metrics]
        palette = linear_palette(Oranges256, n_metrics + 2)[1:-1]
        # a common range keeps the bars of all datasets comparable
        lower, upper = utility.calc_optimal_range(values)

        plots = []
        for idx in range(len(datasets)):
            chunk = slice(idx * n_metrics, (idx + 1) * n_metrics)
            fig = figure(
                y_range=FactorRange(*factors[chunk]),
                x_range=(lower, upper),
                plot_height=95 + 3 * 20,
                plot_width=700,
                title="Crosslinking-MS Data Quality",
            )
            fig.xaxis.ticker.desired_num_ticks = 3
            bars = fig.hbar(y=factors[chunk],
                            right=values[chunk],
                            line_color="white",
                            # colour by the second level of the factor
                            fill_color=factor_cmap('y', palette=palette,
                                                   factors=metrics,
                                                   start=1, end=2))
            legend = Legend(
                items=[LegendItem(label=text, renderers=[bars], index=j)
                       for j, text in enumerate(legends[chunk])],
                location='center', orientation='vertical',
                label_text_font_size="14pt")
            # bars are drawn bottom-up, the legend is listed top-down
            legend.items = legend.items[::-1]
            fig.add_layout(legend, 'right')
            fig.legend.border_line_width = 0
            # set labels and fonts
            self._style_axis_text(fig)
            fig.yaxis.major_label_text_align = 'right'
            fig.xaxis.axis_label = 'Residue pairs'
            fig.left[0].group_text_font_size = '14px'
            fig.left[0].group_label_orientation = 'horizontal'
            fig.title.text_font_size = '14pt'
            fig.title.vertical_align = 'top'
            fig.title.align = "center"
            fig.output_backend = "svg"
            plots.append(fig)

        return gridplot(plots, ncols=1,
                        merge_tools=True,
                        toolbar_location=None)

    def save_plots(self, p, plot_name: str) -> dict:
        """Export a plot as SVG and as a Bokeh JSON item.

        The files are written to ``self.imageDirName`` and are named
        ``<ID_f>_<plot_name>`` with the suffixes ``.svg`` and ``.json``.
        No HTML file is written; its path is only returned for reference.

        Args:
            p: Bokeh plot or layout to export.
            plot_name: name of the plot, used in the file names.

        Returns:
            dict with the paths under the keys ``'html'``, ``'svg'`` and
            ``'json'``.
        """
        fname_html = Path(self.imageDirName, f"{self.ID_f}_{plot_name}.html")
        fname_svg = fname_html.with_suffix('.svg')
        fname_json = fname_html.with_suffix('.json')

        export_svg(p, filename=fname_svg)

        with open(fname_json, 'w') as f:
            json.dump(json_item(p), f)

        return {'html': fname_html, 'svg': fname_svg, 'json': fname_json}