# Source code for strawberryfields.apps.plot
# Copyright 2019 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Tools for visualizing graphs, subgraphs, point processes, and vibronic spectra.
Visualization requires installation of the Plotly library, which is not a dependency of
Strawberry Fields. Plotly can be installed using ``pip install plotly`` or by visiting their
`installation instructions <https://plot.ly/python/getting-started/#installation>`__.
"""
# pylint: disable=import-outside-toplevel
from typing import Optional, Tuple
import networkx as nx
import numpy as np
def _node_coords(g: nx.Graph, l: dict) -> dict:
    """Converts coordinates for the graph nodes for plotting purposes.

    Args:
        g (nx.Graph): input graph
        l (dict): Dictionary mapping each node to its ``(x, y)`` coordinate pair. Can be
            generated using a NetworkX `layout <https://networkx.github.io/documentation/latest/
            reference/drawing.html#module-networkx.drawing.layout>`__

    Returns:
        dict[str, list]: lists of x and y coordinates, ordered as in ``g.nodes()``, accessed
        via the keys ``"x"`` and ``"y"``
    """
    # Materialize the node order once so both coordinate lists line up index-for-index.
    nodes = list(g.nodes())
    return {"x": [l[n][0] for n in nodes], "y": [l[n][1] for n in nodes]}
def _edge_coords(g: nx.Graph, l: dict) -> dict:
    """Builds line-segment coordinates for drawing the edges of a graph.

    Each edge contributes its start point, its end point and a ``None`` break, so that a single
    Plotly line trace draws every edge as a separate segment.

    Args:
        g (nx.Graph): input graph
        l (dict): Dictionary mapping each node to its ``(x, y)`` coordinate pair. Can be
            generated using a NetworkX `layout <https://networkx.github.io/documentation/latest/
            reference/drawing.html#module-networkx.drawing.layout>`__

    Returns:
        dict[str, list]: x and y coordinate lists under the keys ``"x"`` and ``"y"``, with
        ``None`` separating consecutive edges
    """
    xs, ys = [], []
    for u, v in g.edges():
        (x0, y0), (x1, y1) = l[u], l[v]
        # ``None`` breaks the line so edges are not joined to one another.
        xs.extend((x0, x1, None))
        ys.extend((y0, y1, None))
    return {"x": xs, "y": ys}
# Message for the ImportError raised by the plotting functions below when Plotly is missing.
plotly_error = (
    "Plotly required for using this function. "
    "It can be installed using pip install plotly "
    "or visiting https://plot.ly/python/getting-started/#installation"
)

# Base colour palette (hex RGB strings).
GREEN = "#3e9651"
RED = "#cc2529"
GREY = "#737373"
LIGHT_GREY = "#CDCDCD"
VERY_LIGHT_GREY = "#F2F2F2"

# Colours assigned to each kind of graph element.
graph_node_colour, graph_edge_colour = GREEN, LIGHT_GREY
subgraph_node_colour = subgraph_edge_colour = RED

# Plotly marker sizes for nodes of the full graph and of the highlighted subgraph.
graph_node_size, subgraph_node_size = 14, 16
def graph(g: nx.Graph, s: Optional[list] = None, plot_size: Tuple = (500, 500)):  # pragma: no cover
    """Creates a plot of the input graph.

    This function can plot the input graph only, or the graph with a specified subgraph
    highlighted. Graphs are plotted using the Kamada-Kawai layout; with the default square
    ``plot_size`` the figure is 500 x 500 pixels.

    When not running inside a Jupyter notebook, this function sets Plotly's default renderer
    (``plotly.io.renderers.default``) to ``"browser"`` so that ``fig.show()`` opens the figure
    in a web browser. Note that this modifies global Plotly state.

    **Example usage:**

    >>> graph = nx.complete_graph(10)
    >>> fig = plot.graph(graph, [0, 1, 2, 3])
    >>> fig.show()

    .. image:: ../../_static/complete_graph.png
        :width: 40%
        :align: center
        :target: javascript:void(0);

    Args:
        g (nx.Graph): input graph
        s (list): optional list of nodes comprising the subgraph to highlight
        plot_size (tuple[int, int]): size of the plot in pixels, given as a pair of integers
            ``(x_size, y_size)``

    Returns:
        Figure: figure for graph and optionally highlighted subgraph

    Raises:
        ImportError: if Plotly is not installed
    """
    try:
        import plotly.graph_objects as go
        import plotly.io as pio
    except ImportError as e:
        raise ImportError(plotly_error) from e

    # ``get_ipython`` only exists when running under IPython (hence the NameError fallback);
    # Jupyter kernels report their shell class as ``ZMQInteractiveShell``.
    try:
        in_notebook = get_ipython().__class__.__name__ == "ZMQInteractiveShell"
    except NameError:
        in_notebook = False

    if not in_notebook:
        pio.renderers.default = "browser"

    pos = nx.kamada_kawai_layout(g)

    g_edges = go.Scatter(
        **_edge_coords(g, pos),
        line=dict(width=1, color=graph_edge_colour),
        hoverinfo="none",
        mode="lines",
    )
    g_nodes = go.Scatter(
        **_node_coords(g, pos),
        mode="markers",
        hoverinfo="text",
        marker=dict(color=graph_node_colour, size=graph_node_size, line_width=2),
        text=[str(i) for i in g.nodes()],
    )

    layout = go.Layout(
        showlegend=False,
        hovermode="closest",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        margin=dict(b=0, l=0, r=0, t=25),
        height=plot_size[1],
        width=plot_size[0],
        plot_bgcolor="#ffffff",
    )

    # Traces are drawn in list order: all edges sit beneath all nodes, and each highlighted
    # subgraph trace sits above its full-graph counterpart.
    if s is None:
        data = [g_edges, g_nodes]
    else:
        sub = g.subgraph(s)
        s_edges = go.Scatter(
            **_edge_coords(sub, pos),
            line=dict(width=2, color=subgraph_edge_colour),
            hoverinfo="none",
            mode="lines",
        )
        s_nodes = go.Scatter(
            **_node_coords(sub, pos),
            mode="markers",
            hoverinfo="text",
            marker=dict(color=subgraph_node_colour, size=subgraph_node_size, line_width=2),
            text=[str(i) for i in sub.nodes()],
        )
        data = [g_edges, s_edges, g_nodes, s_nodes]

    return go.Figure(data=data, layout=layout)
def subgraph(s: nx.Graph, plot_size: Tuple = (500, 500)):  # pragma: no cover
    """Creates a plot of the input subgraph.

    Subgraphs are plotted using the Kamada-Kawai layout; with the default square ``plot_size``
    the figure is 500 x 500 pixels.

    **Example usage:**

    >>> graph = nx.complete_graph(10)
    >>> subgraph = graph.subgraph([0, 1, 2, 3])
    >>> fig = plot.subgraph(subgraph)
    >>> fig.show()

    .. image:: ../../_static/complete_subgraph.png
        :width: 40%
        :align: center
        :target: javascript:void(0);

    Args:
        s (nx.Graph): input subgraph
        plot_size (tuple[int, int]): size of the plot in pixels, given as a pair of integers
            ``(x_size, y_size)``

    Returns:
        Figure: figure for subgraph

    Raises:
        ImportError: if Plotly is not installed
    """
    try:
        import plotly.graph_objects as go
    except ImportError as e:
        raise ImportError(plotly_error) from e

    pos = nx.kamada_kawai_layout(s)

    s_edges = go.Scatter(
        **_edge_coords(s, pos),
        line=dict(width=1.5, color=subgraph_edge_colour),
        hoverinfo="none",
        mode="lines",
    )
    # Subgraph colours, but the standard (full-graph) node size.
    s_nodes = go.Scatter(
        **_node_coords(s, pos),
        mode="markers",
        hoverinfo="text",
        marker=dict(color=subgraph_node_colour, size=graph_node_size, line_width=2),
        text=[str(i) for i in s.nodes()],
    )

    layout = go.Layout(
        showlegend=False,
        hovermode="closest",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        margin=dict(b=0, l=0, r=0, t=25),
        height=plot_size[1],
        width=plot_size[0],
        plot_bgcolor="#ffffff",
    )

    # Edges first so that nodes are drawn on top of them.
    return go.Figure(data=[s_edges, s_nodes], layout=layout)
def points(
    R: np.ndarray,
    sample: Optional[list] = None,
    plot_size: Tuple = (500, 500),
    point_size: float = 30,
):  # pragma: no cover
    """Creates a plot of two-dimensional points given their input coordinates. Sampled
    points can be optionally highlighted among all points.

    **Example usage:**

    >>> R = np.random.normal(0, 1, (50, 2))
    >>> sample = [1] * 10 + [0] * 40  # select first ten points
    >>> plot.points(R, sample).show()

    .. image:: ../../_static/normal_pp.png
        :width: 40%
        :align: center
        :target: javascript:void(0);

    Args:
        R (np.array): Coordinate matrix. Rows of this array are the coordinates of the points;
            the first two columns are used as the x and y coordinates. Any array-like accepted
            by ``np.asarray`` works.
        sample (list[int]): optional sequence (list or array) aligned with the rows of ``R``;
            points whose entry is positive are highlighted
        plot_size (tuple[int, int]): size of the plot in pixels, given as a pair of integers
            ``(x_size, y_size)``
        point_size (float): Plotly marker size of each point; the marker outline width is
            ``point_size / 20``

    Returns:
        Figure: figure of points with optionally highlighted sample

    Raises:
        ImportError: if Plotly is not installed
    """
    try:
        import plotly.graph_objects as go
    except ImportError as e:
        raise ImportError(plotly_error) from e

    R = np.asarray(R)

    layout = go.Layout(
        showlegend=False,
        hovermode="closest",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        margin=dict(b=0, l=0, r=0, t=25),
        height=plot_size[1],
        width=plot_size[0],
        plot_bgcolor="white",
    )

    p = go.Scatter(
        x=R[:, 0],
        y=R[:, 1],
        mode="markers",
        hoverinfo="text",
        marker=dict(
            color=VERY_LIGHT_GREY, size=point_size, line=dict(color="black", width=point_size / 20)
        ),
        text=[str(i) for i in range(len(R))],
    )

    # An explicit length check (rather than ``if sample:``) also supports NumPy arrays, whose
    # truth value is ambiguous.
    if sample is None or len(sample) == 0:
        return go.Figure(data=[p], layout=layout)

    sampled_points = [i for i, count in enumerate(sample) if count > 0]
    idx = np.asarray(sampled_points, dtype=int)

    samp = go.Scatter(
        x=R[idx, 0],
        y=R[idx, 1],
        mode="markers",
        hoverinfo="text",
        marker=dict(color=RED, size=point_size, line=dict(color="black", width=point_size / 20)),
        text=[str(i) for i in sampled_points],
    )

    # Sampled points are drawn after (on top of) the full point set.
    return go.Figure(data=[p, samp], layout=layout)
def spectrum(
    energies: list,
    gamma: float = 100.0,
    xmin: Optional[float] = None,
    xmax: Optional[float] = None,
):  # pragma: no cover
    """Plots a vibronic spectrum based on input sampled energies.

    **Example usage:**

    >>> formic = data.Formic()
    >>> e = qchem.vibronic.energies(formic, formic.w, formic.wp)
    >>> full_spectrum = plot.spectrum(e, xmin=-1000, xmax=8000)
    >>> full_spectrum.show()

    .. image:: ../../_static/formic_spectrum.png
        :width: 50%
        :align: center
        :target: javascript:void(0);

    Args:
        energies (list[float]): a list of sampled energies
        gamma (float): full width at half maximum of the unit-height Lorentzian placed at each
            sampled energy
        xmin (float): minimum limit of the x axis; by default the smallest energy minus 10% of
            the energy range
        xmax (float): maximum limit of the x axis; by default the largest energy plus 10% of
            the energy range

    Returns:
        Figure: spectrum in the form of a histogram of energies with a Lorentzian-like curve

    Raises:
        ValueError: if fewer than two energies are given, or if ``xmax`` is not greater than
            ``xmin``
        ImportError: if Plotly is not installed
    """
    if len(energies) < 2:
        raise ValueError("Number of sampled energies must be at least two")

    emin = min(energies)
    emax = max(energies)
    spread = emax - emin

    # Pad the default axis limits by 10% of the energy range. When all energies coincide the
    # range is zero, so pad by the line width instead to keep the axis non-degenerate.
    padding = 0.1 * spread if spread > 0 else max(gamma, 1.0)
    if xmin is None:
        xmin = emin - padding
    if xmax is None:
        xmax = emax + padding
    if xmax <= xmin:
        raise ValueError("xmax must be greater than xmin")

    try:
        import plotly.graph_objects as go
    except ImportError as e:
        raise ImportError(plotly_error) from e

    # Histogram with bins at least ~5 energy units wide. At least one bin is required:
    # ``int(spread) // 5`` is 0 for spreads below 5, which np.histogram rejects.
    n_bins = max(int(spread) // 5, 1)
    counts, edges = np.histogram(energies, n_bins)
    # np.histogram returns n_bins + 1 edges for n_bins counts; place each bar at its bin centre.
    bin_centres = (edges[:-1] + edges[1:]) / 2

    bar_width = (xmax - xmin) * 0.005
    line_width = 3.0

    # Sum of unit-height Lorentzians (half width gamma / 2) centred on each sampled energy,
    # evaluated on a grid of about one point per energy unit (and at least two points).
    X = np.linspace(xmin, xmax, max(int(xmax - xmin), 2))
    half_width_sq = (gamma / 2) ** 2
    L = np.zeros_like(X)
    for e in energies:
        L += half_width_sq / ((X - e) ** 2 + half_width_sq)

    text_font = dict(color="black", family="Computer Modern")
    # Axis title sizes go through ``title.font``; the old ``titlefont`` attribute is deprecated
    # and removed in recent Plotly releases.
    title_font = dict(text_font, size=30)
    axis_style = dict(
        tickfont=dict(text_font, size=20),
        showline=True,
        linecolor="black",
        mirror=True,
    )

    layout = go.Layout(
        yaxis=dict(title={"text": "Counts", "font": title_font}, **axis_style, rangemode="tozero"),
        xaxis=dict(
            title={"text": "Energy (cm<sup>-1</sup>)", "font": title_font},
            **axis_style,
            range=[xmin, xmax],
        ),
        plot_bgcolor="white",
        margin=dict(t=25),
        bargap=0.04,
        showlegend=False,
    )

    bars = go.Bar(
        x=bin_centres.tolist(), y=counts.tolist(), width=bar_width, marker=dict(color=GREY)
    )
    line = go.Scatter(x=X, y=L, mode="lines", line=dict(color=GREEN, width=line_width))

    return go.Figure([bars, line], layout=layout)
# _modules/strawberryfields/apps/plot
# Download Python script
# Download Notebook
# View on GitHub