Source code for descmap.analysis.plot

"""Functionality to plot interactive graphs"""
import os
import numpy as np
from pathlib import Path
import chart_studio.plotly as py
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import matplotlib
import matplotlib.pyplot as plt

[docs]def plot_contour(path_out, x, y, z, title, x_label, y_label, hover_label, x_lit=None, y_lit=None, lit_label=None, zmin_cutoff=None, zmax_cutoff=None, x_scaling=None, y_scaling=None, y_scaling_upper=None, y_scaling_lower=None): """Creates a Contour plot using Plotly Parameters ---------- path_out : str Name of file (ending with .html extension) x : list x data to plot y : list y data to plot z : list z data to plot title : str Title of plot x_label : str x axis label y_label : str y axis label zmin_cutoff : float, optional Minimum cutoff z value. The minimum of ``z`` will be used if it is higher than ``zmin_cutoff`` or if ``zmin_cutoff`` is not specified. zmax_cutoff : float, optional Maximum cutoff z value. The maximum of ``z`` will be used if it is higher than ``zmax_cutoff`` of ir ``zmax_cutoff`` is not specified. """ layout={'title': {'text': title}, 'xaxis': {'title': x_label, 'tickformat': '0.2f', 'ticks': 'outside', 'mirror': True, 'showline': True}, 'yaxis': {'title': y_label, 'tickformat': '0.2f', 'ticks': 'outside', 'mirror': True, 'showline': True, 'linewidth': 2.}, 'legend': {'x': 0., 'y': 1}} # Process zmin_cutoff zmin_cutoff, zmax_cutoff = _get_cutoff(data=z, min_cutoff=zmin_cutoff, max_cutoff=zmax_cutoff) fig = go.Figure(go.Contour(x=x, y=y, z=z, hovertext=hover_label, zmin=zmin_cutoff, zmax=zmax_cutoff, name=title, showlegend=False, colorscale="Viridis"), layout=layout) if x_lit is not None: fig.add_trace(go.Scatter(x=x_lit, y=y_lit, hovertext=lit_label, mode='markers+text', text=lit_label, name='Literature', textposition="bottom center", legendgroup='Literature', showlegend=True, textfont={'size': 15, 'color': 'white'}, marker={'size': 10, 'color': 'white', 'line': {'width': 2, 'color': 'black'}})) # Add scaling line between descriptors if y_scaling is not None: # Add line to plot fig.add_trace(go.Scatter(x=x_scaling, y=y_scaling, mode='lines', name='Descriptor Correlation', legendgroup='Intervals', showlegend=False, line = {'width': 1, 'color': 'white'})) minor_scaling_kwargs = {'mode': 'lines', 'name': 'Prediction Intervals', 'legendgroup': 'Intervals', 'line': {'width': 1, 'color': 'gray'}} if y_scaling_upper is not None: fig.add_trace(go.Scatter(x=x_scaling, y=y_scaling_upper, showlegend=True, **minor_scaling_kwargs)) if y_scaling_lower is not None: fig.add_trace(go.Scatter(x=x_scaling, y=y_scaling_lower, showlegend=False, **minor_scaling_kwargs)) # x_linspace = np.linspace(np.min(x_lit), np.max(x_lit)) # y_linspace = lit_lin_model.predict(X=x_linspace.reshape(-1, 1)).flatten() # upper_interval, lower_interval = get_prediction_interval(y_data=y_lit_model.flatten(), # y_model=y_lit_pred, # x_data=x_lit, # x_linspace=x_linspace, # alpha=0.99) # upper_interval = upper_interval + y_linspace # lower_interval = lower_interval + y_linspace # x_ell, y_ell = get_ellipse(x=x_lit.values, # y=y_lit.values, # x_model=x_lit_model, # y_model=y_lit_pred, # slope=lit_lin_model.coef_[0][0], # n=50) # fig.add_trace(go.Scatter(x=x_ell, # y=y_ell, # mode='lines', # name='Ellipse')) # fig.update_layout(legend_orientation="h") fig.write_html(path_out) fig.write_image(path_out.replace('html', 'png'), scale=10, width=6, height=8) fig.write_image(path_out.replace('html', 'svg'), scale=10, width=6, height=8)
[docs]def plot_density(path_out, jobs_data, desc_labels, conv_data, selec_data, yield_data, reactant_name, product_name, hover_label, lit_data=None, design_space_mask=None, conv_min_cutoff=0., conv_max_cutoff=1., selec_min_cutoff=0., selec_max_cutoff=1., yield_min_cutoff=0., yield_max_cutoff=1.): """Plots a density map normalized to x axis slices Parameters ---------- path_out : str Name of file (ending with .html extension) x : list x data to plot y : list y data to plot title : str Title of plot x_label : str x axis label y_label : str y axis label kwargs : keyword arguments Keyword arguments for `numpy.histogram2d`_ .. _`numpy.histogram2d`: https://numpy.org/doc/1.18/reference/generated/numpy.histogram2d.html """ fig = make_subplots(cols=len(desc_labels), rows=3) for i, desc_label in enumerate(desc_labels, start=1): # Get x data x = jobs_data[desc_label] QoI_data = (conv_data, selec_data, yield_data) QoI_labels = ('{} Conv'.format(reactant_name), '{} Selectivity'.format(product_name), '{} Yield'.format(product_name)) ymin_cutoffs = (conv_min_cutoff, selec_min_cutoff, yield_min_cutoff) ymax_cutoffs = (conv_max_cutoff, selec_max_cutoff, yield_max_cutoff) for j, (y_data, y_label, ymin_cutoff, ymax_cutoff) in \ enumerate(zip(QoI_data, QoI_labels, ymin_cutoffs, ymax_cutoffs), start=1): # Since legends are grouped, only display one toggle option if i == 1 and j == 1: showlegend = True else: showlegend = False ymin_cutoff, ymax_cutoff = _get_cutoff(data=y_data, min_cutoff=ymin_cutoff, max_cutoff=ymax_cutoff) trace_name = '{}<br />vs.<br />{}'.format(y_label, desc_label) # hist_trace = go.Histogram2dContour(x=x, y=y_data, # name=trace_name, # showscale=False, # colorscale='Blues') scatter_trace = go.Scatter(x=x, y=y_data, mode='markers', hovertext=hover_label, name='MKM Data', marker={'size': 3, 'color': 'rgba(0., 0., 0., 0.5)'}, legendgroup='MKM Data', showlegend=showlegend) '''Find Pareto Optimal for Design Space''' # Get unique values x_unique = sorted(np.unique(x)) y_largest = [] y_smallest = [] y_largest_design = [] y_smallest_design = [] for x_val in x_unique: # Find indices corresponding to x value k = np.where(x_val == x)[0] y_largest_val = np.max(y_data[k]) y_smallest_val = np.min(y_data[k]) y_largest.append(y_largest_val) y_smallest.append(y_smallest_val) # Find indices inside design space k = np.where(np.logical_and(x_val == x, np.array(design_space_mask) == True))[0] y_largest_design_val = np.max(y_data[k]) y_smallest_design_val = np.min(y_data[k]) y_largest_design.append(y_largest_design_val) y_smallest_design.append(y_smallest_design_val) pareto_largest_trace = go.Scatter(x=x_unique, y=y_largest, fill=None, mode='lines', line_color='#1f77b4', showlegend=False, legendgroup='QoI Range (All)', name='QoI Range (All)') pareto_smallest_trace = go.Scatter(x=x_unique, y=y_smallest, fill='tonexty', mode='lines', line_color='#1f77b4', showlegend=showlegend, legendgroup='QoI Range (All)', name='QoI Range (All)') pareto_largest_design_trace = go.Scatter(x=x_unique, y=y_largest_design, fill=None, mode='lines', legendgroup='QoI Range (Prediction Intervals)', line_color='#ff7f0e', showlegend=False, name='QoI Range (Prediction Intervals)') pareto_smallest_design_trace = go.Scatter(x=x_unique, y=y_smallest_design, fill='tonexty', mode='lines', legendgroup='QoI Range (Prediction Intervals)', line_color='#ff7f0e', showlegend=showlegend, name='QoI Range (Prediction Intervals)') fig.add_traces(data=[scatter_trace, pareto_largest_trace, pareto_smallest_trace, pareto_largest_design_trace, pareto_smallest_design_trace], cols=(i, i, i, i, i), rows=(j, j, j, j, j)) # Update axes fig.update_xaxes(title_text='{} (eV)'.format(desc_label), col=i, row=j, ticks='outside', tickformat='.2f', showline=True, mirror=True, linewidth=1., linecolor='black') fig.update_yaxes(tickformat='.2f', showline=True, mirror=True, row=j, col=i, linewidth=1., linecolor='black', range=[ymin_cutoff, ymax_cutoff]) # Add Y Label if i == 1: for j, y_label in enumerate(QoI_labels, start=1): fig.update_yaxes(title_text=y_label, col=i, row=j, ticks='outside') # Add literature data # if lit_data is not None: # x_lit = lit_data[desc_label] # for j in range(1, 4): # for name, x_lit_point in x_lit.iteritems(): # fig.add_shape({'type': 'line', # 'xref': 'x', # 'yref': 'paper', # 'x0': x_lit_point, # 'x1': x_lit_point, # 'y0': , # 'y1': 1, # }, col=i, row=j) fig.write_html(path_out) # fig.write_image(path_out.replace('html', 'png'), # scale=10, width=8, height=14) # fig.write_image(path_out.replace('html', 'svg'), # scale=10, width=8, height=14) fig.update_layout(width=500.*len(desc_labels), height=800.) fig.write_image(path_out.replace('html', 'svg'))
[docs]def plot_1d_volcano(path_out, x, y, title, x_label, y_label, hover_label, x_lit=None, lit_label=None, ymin_cutoff=None, ymax_cutoff=None): """Creates a Contour plot using Plotly Parameters ---------- path_out : str Name of file (ending with .html extension) x : list x data to plot y : list y data to plot title : str Title of plot x_label : str x axis label y_label : str y axis label ymin_cutoff : float, optional Minimum cutoff y value. The minimum of ``y`` will be used if it is higher than ``ymin_cutoff`` or if ``ymin_cutoff`` is not specified. ymax_cutoff : float, optional Maximum cutoff y value. The maximum of ``y`` will be used if it is higher than ``ymax_cutoff`` of ir ``ymax_cutoff`` is not specified. """ layout={'title': {'text': title}, 'xaxis': {'title': x_label, 'tickformat': '0.2f', 'ticks': 'outside', 'mirror': True, 'showline': True}, 'yaxis': {'title': y_label, 'tickformat': '0.2f', 'ticks': 'outside', 'mirror': True, 'showline': True, 'linewidth': 2.}, 'legend': {'x': 0., 'y': 1}} ymin_cutoff, ymax_cutoff = _get_cutoff(data=y, min_cutoff=ymin_cutoff, max_cutoff=ymax_cutoff) fig = go.Figure(go.Scatter(x=x, y=y, hovertext=hover_label, mode='lines'), layout=layout) fig.update_yaxes(range=[ymin_cutoff, ymax_cutoff]) if x_lit is not None: interp_fn = interpolate.interp1d(x=x, y=y) y_lit = interp_fn(x_lit) fig.add_trace(go.Scatter(x=x_lit, y=y_lit, hovertext=lit_label, mode='markers+text', text=lit_label, name='Literature', textposition="top right", legendgroup='Literature', showlegend=True, textfont={'size': 15, 'color': 'black'}, marker={'size': 10, 'color': 'white', 'line': {'width': 2, 'color': 'black'}})) fig.update_layout(legend_orientation="h") fig.write_html(path_out) fig.write_image(path_out.replace('html', 'png'), scale=10, width=6, height=8) fig.write_image(path_out.replace('html', 'svg'), scale=10, width=6, height=8)
[docs]def plot_1d_simple(path_out, x, y, title, x_label, y_label, ticketformat, ymin_cutoff = None, ymax_cutoff = None): """Creates a simple 1d volcano plot using Plotly Parameters ---------- path_out : str Name of file (ending with .html extension) x : list x data to plot y : list y data to plot title : str Title of plot x_label : str x axis label y_label : str y axis label ymin_cutoff : float, optional Minimum cutoff y value. The minimum of ``y`` will be used if it is higher than ``ymin_cutoff`` or if ``ymin_cutoff`` is not specified. ymax_cutoff : float, optional Maximum cutoff y value. The maximum of ``y`` will be used if it is higher than ``ymax_cutoff`` of ir ``ymax_cutoff`` is not specified. """ layout={'title': {'text': title}, 'xaxis': {'title': x_label, 'tickformat': ticketformat, 'ticks': 'outside', 'mirror': True, 'showline': True}, 'yaxis': {'title': y_label, 'tickformat': ticketformat, 'ticks': 'outside', 'mirror': True, 'showline': True, 'linewidth': 2.}, 'legend': {'x': 0., 'y': 1}} fig = go.Figure(go.Scatter(x=x, y=y, mode='markers+text', marker=dict(color='red', size=12), text=y, textposition='bottom center', textfont=dict(size=12)), layout=layout) fig.update_yaxes(range = [ymin_cutoff, ymax_cutoff]) fig.write_html(path_out) fig.write_image(path_out.replace('html', 'png'), scale=10, width=6, height=8) fig.write_image(path_out.replace('html', 'svg'), scale=10, width=6, height=8)
[docs]def plot_ols(path_out, x, y, title, x_label, y_label, tickformat, ymin_cutoff = None, ymax_cutoff = None): """Creates ordinary least squares (OLS) regression plot using Plotly Parameters ---------- path_out : str Name of file (ending with .html extension) x : list x data to plot y : list y data to plot title : str Title of plot x_label : str x axis label y_label : str y axis label ymin_cutoff : float, optional Minimum cutoff y value. The minimum of ``y`` will be used if it is higher than ``ymin_cutoff`` or if ``ymin_cutoff`` is not specified. ymax_cutoff : float, optional Maximum cutoff y value. The maximum of ``y`` will be used if it is higher than ``ymax_cutoff`` of ir ``ymax_cutoff`` is not specified. tickformat: str Specify number of decimals to keep in ticks """ layout={'title': {'text': title}, 'xaxis': {'title': x_label, 'tickformat': tickformat, 'ticks': 'outside', 'mirror': True, 'showline': True}, 'yaxis': {'title': y_label, 'tickformat': tickformat, 'ticks': 'outside', 'mirror': True, 'showline': True, 'linewidth': 2.}, 'legend': {'x': 0., 'y': 1}} fig = go.Figure(px.scatter(x=x, y=y, trendline='ols', trendline_color_override='red')) fig.update_yaxes(range = [ymin_cutoff, ymax_cutoff]) fig.update_layout(layout) fig.write_html(path_out) fig.write_image(path_out.replace('html', 'png'), scale=10, width=6, height=8) fig.write_image(path_out.replace('html', 'svg'), scale=10, width=6, height=8)
[docs]def plot_coverage(jobs_name, cov_dict_list, species_list, path): """Creates surface coverage plots using Matplotlib Parameters ---------- jobs_name : list List of descriptor names cov_dict_list : list List of dictionaries containing surface species species_list : list List of surface species path : str Path to save output figures """ n_jobs = len(jobs_name) for i in range(n_jobs): gcn = jobs_name[i] cov_dict = cov_dict_list[i] plt.figure() for j in species_list: spe_cov = cov_dict[j] r_time = cov_dict['Time (s)'] if max(spe_cov) >= 1e-5: # only plot species with coverage larger than 1e-5 plt.plot(r_time, spe_cov, label = j) gcn_string = f'{gcn:.3f}' plt.legend(loc = 'best') plt.xlabel('Reaction Time (s)') plt.ylabel('Surface Coverage (ML)') plt.savefig(path + 'GCN{}_coverage.png'.format(gcn_string), transparent = False)
def _get_cutoff(data, min_cutoff=None, max_cutoff=None): """Helper method to calculate cutoff values for plotting. Parameters ---------- data : pandas.Series object Data to evaluate min_cutoff : float, optional Minimum cutoff value. If not specified, the minimum axes value will be based on ``data``. max_cutoff : float, optional Maximum cutoff value. If not specified, the maximum axes value will be based on ``data``. Returns ------- min_cutoff : float Minimum cutoff value max_cutoff : float Maximum cutoff value """ # Process min cutoff temp_min = np.floor(data.min()) if min_cutoff is None: min_cutoff = temp_min elif min_cutoff < temp_min: min_cutoff = temp_min # Process max_cutoff temp_max = np.ceil(data.max()) if max_cutoff is None: max_cutoff = temp_max elif max_cutoff > temp_max: max_cutoff = temp_max return (min_cutoff, max_cutoff)