Source code for strawberryfields.io.xir_io
# Copyright 2019-2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains functions for loading and saving Strawberry Fields
:class:`~.Program` objects from/to XIR scripts and Strawberry Fields
code.
"""
# pylint: disable=protected-access,too-many-nested-blocks
from decimal import Decimal
from typing import Iterable, List, Sequence
import numpy as np
import xir
import strawberryfields.parameters as sfpar
from strawberryfields.program import Program
from strawberryfields.tdm import TDMProgram, is_ptype
from strawberryfields import ops
def get_expanded_statements(prog: xir.Program) -> Sequence[xir.Statement]:
    """Get a list of statements with all gate definitions expanded.

    Every statement applying a gate that is defined in ``prog`` (i.e., has a body in
    ``prog.gates``) is replaced, recursively, by the statements of that body. The wire and
    parameter names of the definition are substituted by the wires and arguments of the
    applying statement. Parameters in a gate body that are not names of the definition's
    parameters (e.g., numeric literals) are kept unchanged.

    Args:
        prog (xir.Program): XIR program with statements and definitions

    Returns:
        list[xir.Statement]: list of expanded XIR statements
    """

    def substitute(param, mapping):
        # gate-body parameters that are not declared parameter names (e.g. literals) pass
        # through unchanged; TypeError covers unhashable literals such as lists
        try:
            return mapping[param]
        except (KeyError, TypeError):
            return param

    def expand_statements(statements: Sequence[xir.Statement]) -> Sequence[xir.Statement]:
        flattened_statements = []
        for op in statements:
            sub_statements = expand_statements(prog.gates.get(op.name, []))
            if not sub_statements:
                # not a (non-empty) user-defined gate; keep the statement as-is
                flattened_statements.append(op)
                continue

            wire_mapping = dict(zip(prog.search("gate", "wires", op.name), op.wires))
            if isinstance(op.params, dict):
                # keyword arguments are already keyed by the declared parameter names
                param_mapping = dict(op.params)
            else:
                param_mapping = dict(zip(prog.search("gate", "params", op.name), op.params))

            # create a new statement object with substituted parameters and wires
            for stmt in sub_statements:
                wires = tuple(wire_mapping[w] for w in stmt.wires)
                if isinstance(stmt.params, dict):
                    params = {k: substitute(v, param_mapping) for k, v in stmt.params.items()}
                else:
                    params = [substitute(p, param_mapping) for p in stmt.params]
                flattened_statements.append(xir.Statement(stmt.name, params, wires))
        return flattened_statements

    return expand_statements(prog.statements)
# pylint: disable=too-many-branches
def from_xir(xir_prog: xir.Program) -> Program:
    """Convert an XIR Program to a Strawberry Fields program.

    Gate definitions are expanded first (see :func:`get_expanded_statements`); every resulting
    statement must name an operation listed in ``strawberryfields.ops.__all__``. The ``_name_``,
    ``_target_`` (falling back to ``target``, the key written by :func:`to_xir`), ``shots`` and
    ``cutoff_dim`` options are copied onto the returned program.

    Args:
        xir_prog (xir.Program): the input XIR program object

    Returns:
        Program: corresponding Strawberry Fields program

    Raises:
        ValueError: if the XIR program is empty
        NameError: if an applied quantum operation is not defined in Strawberry Fields
    """

    def convert_scalar(value):
        # XIR stores real/complex literals as Decimal/DecimalComplex; cast to Python types
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, xir.DecimalComplex):
            return complex(value)
        return value

    def convert_positional(value):
        # strings are Iterable but must be passed through untouched (``_listr`` rejects them)
        if isinstance(value, Iterable) and not isinstance(value, str):
            return np.array(_listr(value))
        return convert_scalar(value)

    # only script-level statements are part of `xir_prog.statements`, which can only have integer
    # wires, leading to `xir_prog.wires` only containing integer wire labels
    if not xir_prog.wires:
        raise ValueError(
            "The XIR program is empty and cannot be transformed "
            "into a Strawberry Fields program."
        )

    num_of_modes = int(max(xir_prog.wires)) + 1
    name = xir_prog.options.get("_name_", "sf_from_xir")
    prog = Program(num_of_modes, name=name)

    # append the quantum operations
    with prog.context as q:
        for op in get_expanded_statements(xir_prog):
            # only names in the ops.py ``__all__`` attribute are defined quantum operations
            if op.name not in ops.__all__:
                raise NameError(f"Quantum operation {op.name!r} not defined!")
            gate = getattr(ops, op.name)

            regrefs = [q[i] for i in op.wires]

            if not op.params:
                gate() | regrefs  # pylint:disable=expression-not-assigned,pointless-statement
            elif isinstance(op.params, dict):
                # par_convert turns symbolic expressions into expressions containing the
                # corresponding MeasuredParameter and FreeParameter instances
                vals = sfpar.par_convert([convert_scalar(v) for v in op.params.values()], prog)
                params = dict(zip(op.params.keys(), vals))
                gate(**params) | regrefs  # pylint:disable=expression-not-assigned
            else:
                params = sfpar.par_convert([convert_positional(p) for p in op.params], prog)
                gate(*params) | regrefs  # pylint:disable=expression-not-assigned

    # ``to_xir`` stores the target under "target"; accept both spellings
    prog._target = xir_prog.options.get(  # pylint: disable=protected-access
        "_target_", xir_prog.options.get("target")
    )
    if "shots" in xir_prog.options:
        prog.run_options["shots"] = xir_prog.options["shots"]
    if "cutoff_dim" in xir_prog.options:
        prog.backend_options["cutoff_dim"] = xir_prog.options["cutoff_dim"]
    return prog
# pylint: disable=too-many-branches
def from_xir_to_tdm(xir_prog: xir.Program) -> TDMProgram:
    """Convert an XIR Program to a ``TDMProgram``.

    The TDM gate arguments are taken from the XIR constants whose names are ptypes
    (``p0``, ``p1``, ...), and statement parameters that are ptype strings are replaced by the
    corresponding loop parameter. The ``_name_``, ``target`` (falling back to ``_target_``),
    ``shots`` and ``cutoff_dim`` options are copied onto the returned program.

    Args:
        xir_prog (xir.Program): the input XIR program object

    Returns:
        TDMProgram: corresponding ``TDMProgram``

    Raises:
        ValueError: if the number of modes 'N' is missing from the XIR program options
        NameError: if an applied quantum operation is not defined in Strawberry Fields
    """

    def convert_scalar(value):
        # XIR stores real/complex literals as Decimal/DecimalComplex; cast to Python types
        if isinstance(value, Decimal):
            return float(value)
        if isinstance(value, xir.DecimalComplex):
            return complex(value)
        return value

    N = xir_prog.options.get("N")
    if not N:
        raise ValueError("Number of modes 'N' is missing from the XIR program options.")

    prog = TDMProgram(N, name=xir_prog.options.get("_name_", "xir"))

    # extract the tdm gate arguments from the xir program constants; ``_listr`` builds new
    # (nested) lists with Decimal/DecimalComplex cast to float/complex, so the constants
    # stored on ``xir_prog`` are not mutated
    args = [_listr(val) for key, val in xir_prog.constants.items() if is_ptype(key)]

    # append the quantum operations
    with prog.context(*args) as (p, q):
        for op in get_expanded_statements(xir_prog):
            # only names in the ops.py ``__all__`` attribute are defined quantum operations
            if op.name not in ops.__all__:
                raise NameError(f"Quantum operation {op.name!r} not defined!")
            gate = getattr(ops, op.name)

            regrefs = [q[int(i)] for i in op.wires]

            if not op.params:
                gate() | regrefs  # pylint:disable=expression-not-assigned,pointless-statement
            elif isinstance(op.params, dict):
                # par_convert turns symbolic expressions into expressions containing the
                # corresponding MeasuredParameter and FreeParameter instances
                vals = sfpar.par_convert([convert_scalar(v) for v in op.params.values()], prog)
                params = {
                    key: p[int(val[1:])] if isinstance(val, str) and is_ptype(val) else val
                    for key, val in zip(op.params.keys(), vals)
                }
                gate(**params) | regrefs  # pylint:disable=expression-not-assigned
            else:
                params = []
                for param in op.params:
                    if isinstance(param, (list, np.ndarray)):
                        params.append(np.array(_listr(param)))
                    elif isinstance(param, str) and is_ptype(param):
                        # ptype strings ("p0", "p1", ...) refer to the loop parameters
                        params.append(p[int(param[1:])])
                    else:
                        params.append(convert_scalar(param))
                params = sfpar.par_convert(params, prog)
                gate(*params) | regrefs  # pylint:disable=expression-not-assigned

    # ``to_xir`` stores the target under "target"; also accept "_target_" (used by ``from_xir``)
    prog._target = xir_prog.options.get(  # pylint: disable=protected-access
        "target", xir_prog.options.get("_target_")
    )
    if "shots" in xir_prog.options:
        prog.run_options["shots"] = xir_prog.options["shots"]
    if "cutoff_dim" in xir_prog.options:
        prog.backend_options["cutoff_dim"] = xir_prog.options["cutoff_dim"]
    return prog
def to_xir(prog: Program, **kwargs) -> xir.Program:
    """Convert a Strawberry Fields Program to an XIR Program.

    Program metadata is stored as XIR options (``_name_``, ``target``, ``cutoff_dim`` and
    ``shots``, when set). For a :class:`~.TDMProgram`, the ``_type_`` and ``N`` options and the
    ``p0, p1, ...`` constants holding the TDM parameters are added as well. Each command of
    ``prog.circuit`` becomes one XIR statement.

    Args:
        prog (Program): the Strawberry Fields program

    Keyword Args:
        add_decl (bool): Whether gate and output declarations should be added to
            the XIR program. Each operation name is declared at most once.
            Default is ``False``.

    Returns:
        xir.Program
    """
    xir_prog = xir.Program()
    add_decl = kwargs.get("add_decl", False)

    if isinstance(prog, TDMProgram):
        xir_prog.add_option("_type_", "tdm")
        xir_prog.add_option("N", prog.N)
        for i, p in enumerate(prog.tdm_params):
            xir_prog.add_constant(f"p{i}", _listr(p))

    if prog.name:
        xir_prog.add_option("_name_", prog.name)
    if prog.target:
        xir_prog.add_option("target", prog.target)
    if "cutoff_dim" in prog.backend_options:
        xir_prog.add_option("cutoff_dim", prog.backend_options["cutoff_dim"])
    if "shots" in prog.run_options:
        xir_prog.add_option("shots", prog.run_options["shots"])

    # TDM loop variables are serialized by name; regular programs have none
    loop_vars = getattr(prog, "loop_vars", ())

    # fill in the quantum circuit
    for cmd in prog.circuit or []:
        name = cmd.op.__class__.__name__
        wires = tuple(r.ind for r in cmd.reg)

        if "Measure" in name:
            # declare each measurement only once, mirroring the gate declarations below
            if add_decl and name not in [d.name for d in xir_prog.declarations.get("out", [])]:
                output_decl = xir.Declaration(name, type_="out", wires=wires)
                xir_prog.add_declaration(output_decl)

            params = {}
            if cmd.op.p:
                # argument is quadrature phase
                a = cmd.op.p[0]
                params["phi"] = a.name if a in loop_vars else a

            # special case to take into account 'select' keyword argument
            if cmd.op.select is not None:
                params["select"] = cmd.op.select

            # special case to take into account 'dark_counts' keyword argument
            if name == "MeasureFock" and cmd.op.dark_counts is not None:
                params["dark_counts"] = cmd.op.dark_counts
        else:
            if add_decl and name not in [d.name for d in xir_prog.declarations["gate"]]:
                decl_params = [f"p{i}" for i, _ in enumerate(cmd.op.p)]
                gate_decl = xir.Declaration(
                    name, type_="gate", params=decl_params, wires=tuple(range(len(wires)))
                )
                xir_prog.add_declaration(gate_decl)

            params = []
            for a in cmd.op.p:
                if sfpar.par_is_symbolic(a):
                    # try to evaluate symbolic parameter
                    try:
                        a = sfpar.par_evaluate(a)
                    except sfpar.ParameterError:
                        if a in loop_vars or a.is_symbol:
                            # tdm loop parameter or pure symbol (free parameter): use its name
                            a = a.name
                        else:
                            # assume a symbolic function; replace all free parameters with
                            # their string representations
                            symbolic_func = a.copy()
                            for s in symbolic_func.free_symbols:
                                symbolic_func = symbolic_func.subs(s, s.name)
                            a = str(symbolic_func)
                elif isinstance(a, str):
                    pass
                elif isinstance(a, Iterable):
                    # if an iterable, make sure it only consists of lists and Python types
                    a = _listr(a)
                params.append(a)

        op = xir.Statement(name, params, wires)
        xir_prog.add_statement(op)

    return xir_prog
def _listr(mixed_iterable: Iterable) -> List:
"""Casts a nested iterable to a list recursively, maintaining the same shape.
Any iterable will be cast to a list, including casting all internal types to native Python
types (e.g., ``Decimal`` and ``np.floating`` to ``float``); Python strings will be cast to
lists of strings containing a single character each.
.. warning:
Currently, strings cannot be passed to the function; an error will be raised if
``mixed_iterable`` is of type ``str``.
"""
if isinstance(mixed_iterable, str):
raise TypeError("Strings cannot be passed to _listr().")
list_ = []
for l in mixed_iterable:
# if string, then create a list of chars
if isinstance(l, str):
list_.append(l)
elif isinstance(l, Iterable):
list_.append(_listr(l))
else:
if isinstance(l, (Decimal, np.floating)):
list_.append(float(l))
elif isinstance(l, (xir.DecimalComplex, np.complexfloating)):
list_.append(complex(l))
else:
try:
# if a NumPy-like object, extract the internal object
# with native Python type (e.g., `np.int` to Python `int`)
list_.append(l.item())
except AttributeError:
list_.append(l)
return list_
_modules/strawberryfields/io/xir_io
Download Python script
Download Notebook
View on GitHub