Some feature additions to saved_model_cli AOT compile & bugfixes in freeze graph

1. saved_model_cli now properly identifies read-only and non-read-only variables in a
   graph and only freezes the read-only variables (and marks the others as not read-only).
2. Fixed bugs in convert_variables_to_constants where the blacklist/whitelist was not
   properly respected for chains of operations. While the VarHandleOp node was properly
   blacklisted, downstream ops that used the resources would be converted as if the
   variable had been frozen. So, for example, the following graph would break:

   VarHandleOp -> Identity -> [first arg in] ResourceAssign

   The attrs of the Identity op would be changed to DT_FLOAT though it should stay as
   DT_RESOURCE.
3. Added support for freezing of *Nd ops (ResourceGatherNd, ResourceScatterNd).

PiperOrigin-RevId: 293239272
Change-Id: I06de3d139c5585a93ba585f076edf92137e4c48a
This commit is contained in:
Eugene Brevdo 2020-02-04 15:12:01 -08:00 committed by TensorFlower Gardener
parent 3262f347ae
commit 96c84224c8
7 changed files with 645 additions and 425 deletions

View File

@ -514,7 +514,7 @@ def _convert_variables_to_constants_v2_impl(func,
# Get dtype and data for non-variable Placeholders (ex. values for 1.X
# Const ops that are loaded as Placeholders in 2.0)
_save_placeholder(node.name, node.attr["dtype"])
elif node.op in ["ReadVariableOp", "ResourceGather"]:
elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
# Get dtype and data for Placeholder ops associated with ReadVariableOp
# and ResourceGather ops. There can be an Identity in between the
# resource op and Placeholder. Store the dtype for the Identity ops.
@ -570,6 +570,15 @@ def _convert_variables_to_constants_v2_impl(func,
output_node.attr["Taxis"].CopyFrom(axis_dtype)
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
elif input_node.op == "ResourceGatherNd":
output_node.op = "GatherNd"
output_node.name = input_node.name
output_node.input.extend(
[input_node.input[0], input_node.input[1]])
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
# Update the function names and argument types for the conditional ops.
elif input_node.op in _CONDITIONAL_OPS:
_populate_if_op(output_node, input_node, function_data)

View File

@ -237,7 +237,9 @@ def tensor_shape_from_node_def_name(graph, input_name):
return shape
def _update_resource_identities(resource_identities, output_graph_def):
def _update_resource_identities(resource_identities, output_graph_def,
variable_names_whitelist,
variable_names_blacklist):
"""Updates the type of DT_RESOURCE Identity ops.
Updates the type of the `resource_identities` to the type of the node that
@ -248,6 +250,10 @@ def _update_resource_identities(resource_identities, output_graph_def):
resource_identities: List of NodeDef protos that are Identity ops with the
type DT_RESOURCE.
output_graph_def: GraphDef proto.
variable_names_whitelist: The set of variable names to convert (by default,
all variables are converted).
variable_names_blacklist: The set of variable names to omit converting
to constants.
"""
# Identify the nodes in the graph and the nodes consuming each node.
map_name_to_node = {}
@ -280,14 +286,27 @@ def _update_resource_identities(resource_identities, output_graph_def):
"This Identity's type was changed from DT_RESOURCE during graph "
"freezing.")
if input_node.attr["T"].type != dtypes.resource:
if input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY:
if (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY
and _should_convert(
input_node.input[0],
variable_names_whitelist,
variable_names_blacklist)):
node.attr["T"].CopyFrom(input_node.attr["T"])
node.attr["_debugging"].s = debugging_message
elif input_node.op == "VarHandleOp":
elif (input_node.op == "VarHandleOp"
and _should_convert(
input_node.name,
variable_names_whitelist,
variable_names_blacklist)):
node.attr["T"].CopyFrom(input_node.attr["dtype"])
node.attr["_debugging"].s = debugging_message
def _should_convert(name, whitelist, blacklist):
return ((whitelist is None or name in whitelist)
and (blacklist is None or name not in blacklist))
@deprecation.deprecated(
date=None,
instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
@ -315,6 +334,10 @@ def convert_variables_to_constants(sess,
Returns:
GraphDef containing a simplified version of the original.
Raises:
RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
blacklisted AND whitelisted for freezing.
"""
get_input_name = lambda node, index=0: node.input[index].split(":")[0]
@ -344,44 +367,60 @@ def convert_variables_to_constants(sess,
variable_names = []
variable_dict_names = []
resource_op_types = {}
for node in inference_graph.node:
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
variable_name = node.name
if ((variable_names_whitelist is not None and
variable_name not in variable_names_whitelist) or
(variable_names_blacklist is not None and
variable_name in variable_names_blacklist)):
if not _should_convert(
variable_name, variable_names_whitelist, variable_names_blacklist):
continue
variable_dict_names.append(variable_name)
if node.op == "VarHandleOp":
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
else:
variable_names.append(variable_name + ":0")
elif node.op in ["ReadVariableOp", "ResourceGather"]:
elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
# There can be one or more Identity or control flow ops in between the
# ReadVariableOp and VarHandleOp. Store the ops with the associated
# dtypes.
source_op_names = [get_input_name(node)]
candidate_resource_op_types = {}
while (source_op_names and map_name_to_node[source_op_names[0]].op in
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
source_op_name = source_op_names.pop()
current_node = map_name_to_node[source_op_name]
if source_op_name not in resource_op_types:
resource_op_types[source_op_name] = node.attr["dtype"]
if (source_op_name not in resource_op_types and
source_op_name not in candidate_resource_op_types):
candidate_resource_op_types[source_op_name] = node.attr["dtype"]
source_op_names.append(get_input_name(current_node))
if current_node == "Merge":
merge_resource_name = get_input_name(current_node, index=1)
if merge_resource_name not in resource_op_types:
resource_op_types[merge_resource_name] = node.attr["dtype"]
if (merge_resource_name not in resource_op_types
and merge_resource_name not in candidate_resource_op_types):
candidate_resource_op_types[merge_resource_name] = (
node.attr["dtype"])
source_op_names.append(
get_input_name(map_name_to_node[merge_resource_name]))
should_convert_all = None
for source_node in source_op_names:
if map_name_to_node[source_node].op != "VarHandleOp":
raise ValueError("Cannot find the variable that is an input "
"to the ReadVariableOp.")
should_convert_node = _should_convert(
source_node, variable_names_whitelist, variable_names_blacklist)
if should_convert_all is None:
should_convert_all = should_convert_node
elif should_convert_all != should_convert_node:
raise RuntimeError(
"Found DT_RESOURCE node whose ancestor Variables are both "
"blacklisted AND whitelisted for freezing. Originating "
"descendant node: {}. Ancestor variables: {}.".format(
node.name, source_op_names))
if should_convert_all in (None, True):
resource_op_types.update(candidate_resource_op_types)
# Gets map of variables and the associated data.
if variable_names:
@ -391,6 +430,15 @@ def convert_variables_to_constants(sess,
variables_data_map = dict(zip(variable_dict_names, returned_variables))
logging.info("Froze %d variables.", len(returned_variables))
def _should_convert_ancestor(node):
  """Returns True if the variable feeding `node` should be frozen.

  Follows `node`'s first input backwards through ops that merely pass the
  DT_RESOURCE handle along (Identity/control-flow ops) until it reaches the
  originating node, then applies the whitelist/blacklist to that node's name.
  """
  input_node = map_name_to_node[_node_name(node.input[0])]
  # Skip over resource-typed pass-through ops to find the real ancestor.
  while (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and
         input_node.attr["T"].type == dtypes.resource):
    input_node = map_name_to_node[_node_name(input_node.input[0])]
  return _should_convert(input_node.name,
                         variable_names_whitelist,
                         variable_names_blacklist)
# Reconstruct the graph with constants in place of variables.
output_graph_def = graph_pb2.GraphDef()
how_many_converted = 0
@ -413,7 +461,8 @@ def convert_variables_to_constants(sess,
if str(attr_name) != "_output_shapes":
output_node.attr[attr_name].CopyFrom(input_node.attr[attr_name])
output_node.attr["T"].CopyFrom(resource_op_types[input_node.name])
elif input_node.op == "ReadVariableOp":
elif (input_node.op == "ReadVariableOp"
and _should_convert_ancestor(input_node)):
# The first branch converts all VarHandleOps of ResourceVariables to
# constants, so we need to convert the associated ReadVariableOps to
# Identity ops.
@ -423,7 +472,8 @@ def convert_variables_to_constants(sess,
output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
elif input_node.op == "ResourceGather":
elif (input_node.op == "ResourceGather"
and _should_convert_ancestor(input_node)):
# The first branch converts all VarHandleOps of ResourceGather to
# constants, so we need to convert the associated ResourceGather to Gather
# ops with a Const axis feeding into it.
@ -444,9 +494,19 @@ def convert_variables_to_constants(sess,
output_node.attr["Taxis"].CopyFrom(axis_dtype)
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
elif (input_node.op == "ResourceGatherNd"
and _should_convert_ancestor(input_node)):
output_node.op = "GatherNd"
output_node.name = input_node.name
output_node.input.extend(
[input_node.input[0], input_node.input[1]])
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
if "_class" in input_node.attr:
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
else:
output_node.CopyFrom(input_node)
output_graph_def.node.extend([output_node])
output_graph_def.node.append(output_node)
# Update the types of the DT_RESOURCE Identity nodes that do not have an
# associated ReadVariableOp.
@ -455,7 +515,10 @@ def convert_variables_to_constants(sess,
if node.op == "Identity" and node.attr["T"].type == dtypes.resource:
resource_identities.append(node)
if resource_identities:
_update_resource_identities(resource_identities, output_graph_def)
_update_resource_identities(resource_identities,
output_graph_def,
variable_names_whitelist,
variable_names_blacklist)
output_graph_def.library.CopyFrom(inference_graph.library)
logging.info("Converted %d variables to const ops.", how_many_converted)

View File

@ -36,6 +36,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_state_ops
@ -418,6 +419,53 @@ class ConvertVariablesToConstantsTest(test.TestCase):
output = self.evaluate(output_node)
self.assertNear(2.0, output, 0.00001)
def test_resource_variable_can_be_written_after_blacklisting(self):
  """Freezing with a blacklisted variable must keep its write path valid."""
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=True):
      variable_node = variable_scope.get_variable(
          "variable_node", initializer=1.0)
      another_variable = variable_scope.get_variable(
          "unused_variable_node", initializer=2.0)
      # Each evaluation of output_node first adds another_variable into
      # variable_node (via the control dependency), then reads it.
      with ops.control_dependencies([
          variable_node.assign(another_variable + variable_node)]):
        output_node = array_ops.identity(variable_node, name="output_node")
      initializer_name = variable_node.initializer.name
    with session.Session() as sess:
      self.evaluate(variable_node.initializer)
      self.evaluate(another_variable.initializer)
      output = self.evaluate(output_node)
      self.assertNear(3.0, output, 0.00001)
      variable_graph_def = sess.graph.as_graph_def()

      # Test variable name black list. This should result in the variable
      # not being a const. Furthermore, the paths that read from and assign
      # to the blacklisted variable should continue to be valid.
      constant_graph_def_with_blacklist = (
          graph_util.convert_variables_to_constants(
              sess,
              variable_graph_def, ["output_node", initializer_name],
              variable_names_blacklist=set(["variable_node"])))

      variable_node = None
      for node in constant_graph_def_with_blacklist.node:
        if node.name == "variable_node":
          variable_node = node
      # The blacklisted variable must survive freezing as a VarHandleOp.
      self.assertIsNotNone(variable_node)
      self.assertEqual(variable_node.op, "VarHandleOp")

  # Now we make sure another_variable is now a constant, but the original
  # variable is not, and that the graph can be executed and the blacklisted
  # variable can be updated with each execution.
  with ops.Graph().as_default():
    _ = importer.import_graph_def(constant_graph_def_with_blacklist, name="")
    with session.Session() as sess:
      output_node = sess.graph.get_tensor_by_name("output_node:0")
      self.evaluate(sess.graph.get_operation_by_name(initializer_name))
      output = self.evaluate(output_node)
      self.assertNear(3.0, output, 0.00001)
      # Running again re-applies the assign, so the value keeps growing —
      # proof that the variable stayed writable after freezing.
      output = self.evaluate(output_node)
      self.assertNear(5.0, output, 0.00001)
def _inline_functions(self, graph_def, arrays):
meta_graph = export_meta_graph(graph_def=graph_def)
fetch_collection = meta_graph_pb2.CollectionDef()

View File

@ -322,9 +322,19 @@ py_library(
srcs = ["saved_model_cli.py"],
srcs_version = "PY2AND3",
deps = [
":saved_model_aot_compile",
":saved_model_utils",
"//tensorflow/python",
"//tensorflow/python/debug:local_cli_wrapper",
],
)
py_library(
name = "saved_model_aot_compile",
srcs = ["saved_model_aot_compile.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python",
"//tensorflow/python:tf_optimizer",
] + if_xla_available(
["//tensorflow/compiler/tf2xla:tf2xla_proto_py"],

View File

@ -0,0 +1,472 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper utilities for AOT compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import hashlib
import os
import pipes
import shlex
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
# The tfcompile wrapper only exists when TensorFlow was built with XLA
# support, so import lazily and defer the failure until the tool is used.
try:
  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA. You may need to build tensorflow with flag '
      '--define=with_xla_support=true. Original error: {}'.format(str(e)))
else:
  _pywrap_tfcompile_import_error = None


# Ops that only ever read a resource variable.  A variable whose (transitive)
# consumers all fall in this set is treated as read-only and may be frozen.
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp',
    'IsVariableInitializedOp',
    'ResourceGather',
    'ResourceGatherNd',
    'VariableShape',
)

# Ops that forward a resource handle unchanged; variable-usage analysis
# continues the traversal through them.
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')
def _shlex_quote(s):
  """Returns `s` quoted for safe interpolation into a shell command line."""
  quote_fn = pipes.quote if six.PY2 else shlex.quote
  return quote_fn(s)
def _sysconfig_module():
  """Returns tf.sysconfig when usable (i.e., inside a pip package), else None."""
  try:
    sysconfig_lib.get_include()
  except ImportError:
    # Not running from an installed pip package; sysconfig is unusable.
    return None
  return sysconfig_lib
def _parse_tensor_name(name):
"""Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
if ':' in name and not name.endswith(':'):
node_name = name[:name.rfind(':')]
output_slot = int(name[name.rfind(':') + 1:])
return node_name, output_slot
else:
return name, None
# Template for the generated Makefile include fragment.  Filled in by
# _xla_makefile_string with the TF include path, the directory holding the
# compiled objects, and the C++ flags needed to match TensorFlow's ABI.
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""
def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    # Installed pip package: headers come from the package's include dir.
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      this_file = __file__
      # Resolve chains of symlinks until we reach the real file.
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      # Walk three directory levels up (presumably to the source root).
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    # Sanity-check the guessed source root by probing for a known header.
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      # Best effort: warn but still emit a makefile with the guessed path.
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))
    # NOTE(review): unlike the sysconfig branch, `base` is not shlex-quoted
    # here — confirm paths with spaces are not expected for bazel runs.
    tensorflow_includes = base

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
def _get_variable_nodes_from_graph_def(graph_def):
  """Classifies every `VarHandleOp` in `graph_def` as read-only or modified.

  Args:
    graph_def: An instance of `GraphDef`.  This GraphDef *must* have already
      been optimized by Grappler.  In particular, function inlining must
      have already happened.

  Returns:
    A dict mapping string names of variables to tuples `(node_def, modified)`,
    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
    is a python bool describing whether the variable is modified during runtime.
  """
  handle_nodes = {n.name: n for n in graph_def.node if n.op == 'VarHandleOp'}

  # Map each node name to the list of nodes consuming one of its outputs
  # (control inputs, prefixed with '^', do not carry the resource handle).
  consumers = collections.defaultdict(list)
  for node in graph_def.node:
    for input_name in node.input:
      if not input_name.startswith('^'):
        consumers[input_name].append(node)

  result = {}
  for name, handle_node in handle_nodes.items():
    modified = False
    pending = list(consumers[name])
    seen = set()
    while pending:
      consumer = pending.pop()
      if consumer.name in seen:
        continue
      seen.add(consumer.name)
      if consumer.op in _PASS_THROUGH_VARIABLE_OPS:
        # The handle is forwarded unchanged; keep following its consumers.
        pending.extend(consumers.get(consumer.name, []))
      elif consumer.op not in _READ_ONLY_VARIABLE_OPS:
        # Any other consumer may write to the variable.
        modified = True
        break
    result[name] = (handle_node, modified)
  return result
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Returns a copy of `signature_def` without inputs missing from `graph_def`.

  Optimization or freezing may remove nodes that the signature still refers
  to; such inputs are dropped (with a warning) so that the compiled signature
  only mentions nodes that actually exist.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  remaining_names = {n.name for n in graph_def.node}
  pruned = meta_graph_pb2.SignatureDef()
  pruned.CopyFrom(signature_def)
  for key, tensor_info in signature_def.inputs.items():
    node_name, _ = _parse_tensor_name(tensor_info.name)
    if node_name not in remaining_names:
      logging.warn(
          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
          'while freezing the graph. Removing it from the compiled signatures.'
          .format(key, node_name))
      del pruned.inputs[key]
  return pruned
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=()):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files. Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler before compilation.  All
  variables *not* listed in `variables_to_feed` are frozen as constants into
  the graph and binary objects; the listed variables instead become inputs
  and outputs of the compiled class, and the C++ caller must set their
  values manually.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen. If `None`, then we will extract all the
      variables in the graph and mark them as to-feed. The default behavior is
      an empty tuple: all variables must be frozen.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs, or if a name in `variables_to_feed` is
      not found in the optimized graph.
  """
  if _pywrap_tfcompile_import_error:
    # TF was built without XLA; surface the deferred import failure now.
    raise _pywrap_tfcompile_import_error

  # Validate and pull out the requested signature.
  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  # Dump intermediate graphs to the temp dir when logging verbosely, to aid
  # debugging of the optimize/freeze pipeline below.
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  # Classify every variable as read-only or modified; modified variables are
  # excluded from freezing via the blacklist further below.
  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed (i.e., don't freeze) every variable in the graph.
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            # Keep modified variables as variables; only read-only ones are
            # frozen into constants.  (Each value is a (node, modified) tuple,
            # so node_modified[1] is the `modified` bool.)
            variable_names_blacklist=[
                name for (name, node_modified) in all_variables.items()
                if node_modified[1]
            ],
        ))

  # Freezing may have pruned nodes the signature referenced.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Derive a stable entry-point name from the config + graph contents so
  # different compilations don't collide when linked together.
  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquified by entry_point.
      gen_program_shape=False)
def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
  # Grappler determines the graph outputs from the 'train_op' collection, so
  # register every signature input and output tensor there before optimizing.
  optimization_input = copy.deepcopy(meta_graph_def)
  fetches = meta_graph_pb2.CollectionDef()
  tensor_infos = list(signature_def.inputs.values())
  tensor_infos += list(signature_def.outputs.values())
  for tensor_info in tensor_infos:
    fetches.node_list.value.append(tensor_info.name)
  optimization_input.collection_def['train_op'].CopyFrom(fetches)
  return tf_optimizer.OptimizeGraph(
      config_pb2.ConfigProto(), optimization_input)
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants.

  Mutates `graph_def` in place: every signature input node that is a
  Placeholder (with a fully defined shape) is rewritten into a zeros
  constant of the same name, dtype and shape; any extra nodes created by
  `zeros()` are appended to the graph.

  Args:
    graph_def: A `GraphDef` proto, modified in place.
    signature_def: A `SignatureDef` proto whose inputs identify the
      placeholder nodes to replace.

  Raises:
    RuntimeError: If a signature input node is missing from `graph_def`.
    ValueError: If an input's shape is not fully defined.
  """
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  processed_nodes = set([])
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    # Several signature inputs may refer to the same node; rewrite it once.
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(name_to_node_map.keys())))
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      # Non-placeholder inputs (e.g. already-constant-folded) are left as-is.
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(name, tensor_name, shape))
    # Build the replacement constant in a scratch graph so its NodeDef can be
    # copied over the placeholder's NodeDef (keeping the same node name).
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  for name, input_ in signature_def.inputs.items():
    name = name.replace('/', '_')
    name = 'feed_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(input_.name)
    # BUGFIX: tensor names with no explicit output slot (e.g. 'node' rather
    # than 'node:0') yield output_index=None; treat them as output 0 instead
    # of raising TypeError in int().
    output_index = int(output_index or 0)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=input_.dtype,
            shape=input_.tensor_shape))
  for name, output_ in signature_def.outputs.items():
    name = name.replace('/', '_')
    name = 'fetch_{}'.format(name)
    (node_name, output_index) = _parse_tensor_name(output_.name)
    output_index = int(output_index or 0)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=output_.dtype,
            shape=output_.tensor_shape))
  for (node, modified) in variable_nodes_to_feed:
    name = node.name.replace('/', '_')
    name = 'param_{}'.format(name)
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name=name,
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            # A variable is marked read-only exactly when it is never written.
            readonly=not modified))
  return config

View File

@ -25,12 +25,8 @@ from __future__ import print_function
import argparse
import collections
import copy
import hashlib
import os
import pipes
import re
import shlex
import sys
import numpy as np
@ -38,31 +34,22 @@ import six
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_aot_compile
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib
_XLA_DEBUG_OPTIONS_URL = (
@ -70,100 +57,10 @@ _XLA_DEBUG_OPTIONS_URL = (
'tensorflow/compiler/xla/debug_options_flags.cc')
try:
from tensorflow.python import _pywrap_tfcompile # pylint: disable=g-import-not-at-top
except ImportError as e:
_pywrap_tfcompile_import_error = ImportError(
'Unable to import _pywrap_tfcompile; you must build TensorFlow '
'with XLA. You may need to build tensorflow with flag '
'--define=with_xla_support=true. Original error: {}'.format(str(e)))
else:
_pywrap_tfcompile_import_error = None
# Set of ops to blacklist.
_OP_BLACKLIST = set(['WriteFile', 'ReadFile', 'PrintV2'])
def _shlex_quote(s):
  """Return `s` quoted for safe use in a shell command line.

  Dispatches on the Python major version: `pipes.quote` on Python 2,
  `shlex.quote` on Python 3.
  """
  quote = pipes.quote if six.PY2 else shlex.quote
  return quote(s)
def _sysconfig_module():
  """Load tf.sysconfig if available and working (i.e., inside a pip package)."""
  # Probe get_include(); outside an installed pip package this raises
  # ImportError, in which case the module is unusable for our purposes.
  try:
    sysconfig_lib.get_include()
  except ImportError:
    return None
  return sysconfig_lib
def _parse_tensor_name(name):
"""Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
if ':' in name and not name.endswith(':'):
node_name = name[:name.rfind(':')]
output_slot = int(name[name.rfind(':') + 1:])
return node_name, output_slot
else:
return name, None
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""
def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """

  def _source_tree_root():
    # Locate the repository root for a local bazel run.  Bazel often
    # symlinks this script, so resolve the symlink chain first, then walk
    # three directories up from the file's location.
    this_file = __file__
    if os.path.islink(this_file):
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      return os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    try:
      return test.test_src_dir_path('')
    except KeyError:  # Can't find TEST_SRCDIR in environment path.
      return os.path.realpath(
          os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))

  sysconfig = _sysconfig_module()
  if sysconfig:
    # Installed pip package: sysconfig knows where the headers live.
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    base = _source_tree_root()
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  output_dir, _ = os.path.split(output_prefix)
  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
def _show_tag_sets(saved_model_dir):
"""Prints the tag-sets stored in SavedModel directory.
@ -178,47 +75,6 @@ def _show_tag_sets(saved_model_dir):
print('%r' % ', '.join(sorted(tag_set)))
def _get_variable_nodes_from_graph_def(graph_def):
"""Get the list of Variable nodes from `graph_def`.
Args:
graph_def: An instance of `GraphDef`.
Returns:
A list of `NodeDef` corresponding to variables in the graph.
"""
variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
for f in graph_def.library.function:
variables += [n for n in f.node_def if n.op == 'VarHandleOp']
return variables
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Identify the inputs in the signature no longer in graph_def, prune them.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new pruned `SignatureDef`.
  """
  surviving_nodes = set(n.name for n in graph_def.node)
  pruned_signature_def = meta_graph_pb2.SignatureDef()
  pruned_signature_def.CopyFrom(signature_def)
  for k, tensor_info in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(tensor_info.name)
    if tensor_name in surviving_nodes:
      continue
    # The node backing this input was removed (e.g. by freezing); drop the
    # input from the signature so the compiled interface stays consistent.
    logging.warn(
        'Signature input key \'{}\', tensor name \'{}\', has been pruned '
        'while freezing the graph. Removing it from the compiled signatures.'
        .format(k, tensor_name))
    del pruned_signature_def.inputs[k]
  return pruned_signature_def
def _show_signature_def_map_keys(saved_model_dir, tag_set):
"""Prints the keys for each SignatureDef in the SignatureDef map.
@ -943,7 +799,7 @@ def aot_compile_cpu(args):
variables_to_feed = None # We will identify them after.
else:
variables_to_feed = args.variables_to_feed.split(',')
aot_compile_cpu_meta_graph_def(
saved_model_aot_compile.aot_compile_cpu_meta_graph_def(
checkpoint_path=checkpoint_path,
meta_graph_def=saved_model_utils.get_meta_graph_def(
args.dir, args.tag_set),
@ -954,210 +810,6 @@ def aot_compile_cpu(args):
cpp_class=args.cpp_class)
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=()):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files. Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants, before compilation happens.

  If the `freeze_graph` is `True`, all variables are embedded as constants
  into the graph and binary objects. If it is `False`, then the variable
  values become inputs and outputs of the compiled class and the C++
  caller must set these values manually.

  NOTE(review): `freeze_graph` above is not a parameter of this function;
  the code below always freezes via `convert_variables_to_constants` and
  only `variables_to_feed` controls which variables are listed in the
  tf2xla config — confirm intended wording against the caller.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user; these won't be frozen. If `None`, then we will extract all the
      variables in the graph and mark them as to-feed. The default behavior is
      an empty tuple: all variables must be frozen.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  # Fail fast when TensorFlow was built without XLA (see module-level
  # import of _pywrap_tfcompile).
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error

  # --- Validate the requested signature. ---
  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  # Intermediate artifacts (graphs, config.pbtxt) go to a temp directory;
  # at INFO verbosity we also dump the pre-transformation graphs for
  # debugging.
  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  # --- Decide which variables are fed (left as inputs) vs. frozen. ---
  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    # Feed every variable found in the graph.
    variable_nodes_to_feed = list(all_variables)
  else:
    not_in_graph = (
        set(variables_to_feed).difference([x.name for x in all_variables]))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, set([x.name for x in all_variables])))
    all_variables_map = dict((x.name, x) for x in all_variables)
    variable_nodes_to_feed = [
        all_variables_map[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
        ))

  # Freezing may have removed some signature inputs; drop them.
  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  # --- Write the frozen graph and tf2xla config, then invoke tfcompile. ---
  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # Digest of config + graph makes the entry point name unique per build.
  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
  # Grappler determines the fetch nodes from the 'train_op' collection, so
  # register every signature input and output tensor there before running it.
  optimizable = copy.deepcopy(meta_graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  tensor_infos = list(signature_def.inputs.values())
  tensor_infos.extend(signature_def.outputs.values())
  for tensor_info in tensor_infos:
    fetch_collection.node_list.value.append(tensor_info.name)
  optimizable.collection_def['train_op'].CopyFrom(fetch_collection)
  return tf_optimizer.OptimizeGraph(config_pb2.ConfigProto(), optimizable)
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants.

  Mutates `graph_def` in place: each signature input node that is a
  Placeholder (and has a fully-defined shape) is rewritten into a zeros
  constant of the same name, dtype and shape.  Non-placeholder inputs are
  logged and left untouched.

  Args:
    graph_def: A `GraphDef` instance; modified in place.
    signature_def: A `SignatureDef` whose inputs name the placeholders.

  Raises:
    RuntimeError: If a signature input tensor is not present in `graph_def`.
    ValueError: If a placeholder input does not have a fully defined shape
      (a concrete zeros constant cannot be built otherwise).
  """
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  # Several signature keys may alias the same underlying node; convert each
  # node only once.
  processed_nodes = set([])
  for name, input_ in signature_def.inputs.items():
    tensor_name, _ = _parse_tensor_name(input_.name)
    if tensor_name in processed_nodes:
      continue
    processed_nodes.add(tensor_name)
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(name_to_node_map.keys())))
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(name, tensor_name, shape))
    # Build the zeros constant in a scratch graph (named like the original
    # node), then copy its NodeDef over the placeholder's.
    temp_graph = ops_lib.Graph()
    with temp_graph.as_default():
      const = array_ops.zeros(
          shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)
    # Sometimes zeros() also creates additional nodes
    for op in temp_graph.get_operations():
      if op.name == const.op.name:  # We just inserted this one.
        continue
      graph_def.node.append(op.node_def)
      name_to_node_map[op.node_def.name] = op.node_def
def add_show_subparser(subparsers):
"""Add parser for `show`."""
show_msg = (
@ -1482,61 +1134,6 @@ def create_parser():
return parser
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List NodeDefs corresponding to VarHandleOp,
      the list of variables to feed.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  def _entry(prefix, key, tensor_info):
    # Derive a C-identifier-friendly name and a TensorId for one signature
    # entry; the output slot must be explicit in the tensor name.
    sanitized = '{}_{}'.format(prefix, key.replace('/', '_'))
    node_name, output_index = _parse_tensor_name(tensor_info.name)
    tid = tensor_id(node_name=node_name, output_index=int(output_index))
    return tid, sanitized

  for key, input_ in signature_def.inputs.items():
    tid, feed_name = _entry('feed', key, input_)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tid,
            name=feed_name,
            type=input_.dtype,
            shape=input_.tensor_shape))

  for key, output_ in signature_def.outputs.items():
    tid, fetch_name = _entry('fetch', key, output_)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tid,
            name=fetch_name,
            type=output_.dtype,
            shape=output_.tensor_shape))

  for node in variable_nodes_to_feed:
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name='param_{}'.format(node.name.replace('/', '_')),
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=True))

  return config
def main():
logging.set_verbosity(logging.INFO)
parser = create_parser()

View File

@ -733,6 +733,7 @@ Defined Functions:
def __init__(self):
self.var = variables.Variable(1.0, name='my_var')
self.write_var = variables.Variable(1.0, name='write_var')
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
@ -752,20 +753,32 @@ Defined Functions:
del y
return {'res': x + self.var}
@def_function.function(input_signature=[
    tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
    tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
])
def func_write(self, x, y):
  """Assigns `x + self.var` into `self.write_var` and returns it as 'res'."""
  del y  # Unused by the computation; kept only to match the input signature.
  self.write_var.assign(x + self.var)
  return {'res': self.write_var}
@parameterized.named_parameters(
('VariablesToFeedNone', '', 'func2'),
('VariablesToFeedAll', 'all', 'func2'),
('VariablesToFeedMyVar', 'my_var', 'func2'),
('VariablesToFeedNoneLargeConstant', '', 'func3'))
('VariablesToFeedNoneLargeConstant', '', 'func3'),
('WriteToWriteVar', 'all', 'func_write'),
)
def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed, func):
if not test.is_built_with_xla():
self.skipTest('Skipping test because XLA is not compiled in.')
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
dummy_model = self.AOTCompileDummyModel()
func = dummy_model.func2 if func == 'func2' else dummy_model.func3
func = getattr(dummy_model, func)
with self.cached_session():
self.evaluate(dummy_model.var.initializer)
self.evaluate(dummy_model.write_var.initializer)
save.save(dummy_model, saved_model_dir, signatures={'func': func})
self.parser = saved_model_cli.create_parser()
@ -793,7 +806,15 @@ Defined Functions:
# arg_y got filtered out as it's not used by the output.
self.assertNotIn('arg_feed_y_data', header_contents)
if variables_to_feed:
self.assertIn('var_param_my_var', header_contents)
# Read-only-variables' setters preserve constness.
self.assertIn('set_var_param_my_var_data(const float', header_contents)
self.assertNotIn('set_var_param_my_var_data(float', header_contents)
if func == dummy_model.func_write:
# Writeable variables setters do not preserve constness.
self.assertIn('set_var_param_write_var_data(float', header_contents)
self.assertNotIn(
'set_var_param_write_var_data(const float', header_contents)
makefile_contents = file_io.read_file_to_string(
'{}_makefile.inc'.format(output_prefix))
self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)