diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py index 2884a0a809b..f990dae9966 100644 --- a/tensorflow/python/framework/convert_to_constants.py +++ b/tensorflow/python/framework/convert_to_constants.py @@ -514,7 +514,7 @@ def _convert_variables_to_constants_v2_impl(func, # Get dtype and data for non-variable Placeholders (ex. values for 1.X # Const ops that are loaded as Placeholders in 2.0) _save_placeholder(node.name, node.attr["dtype"]) - elif node.op in ["ReadVariableOp", "ResourceGather"]: + elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]: # Get dtype and data for Placeholder ops associated with ReadVariableOp # and ResourceGather ops. There can be an Identity in between the # resource op and Placeholder. Store the dtype for the Identity ops. @@ -570,6 +570,15 @@ def _convert_variables_to_constants_v2_impl(func, output_node.attr["Taxis"].CopyFrom(axis_dtype) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) + elif input_node.op == "ResourceGatherNd": + output_node.op = "GatherNd" + output_node.name = input_node.name + output_node.input.extend( + [input_node.input[0], input_node.input[1]]) + output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"]) + output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"]) + if "_class" in input_node.attr: + output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) # Update the function names and argument types for the conditional ops. 
elif input_node.op in _CONDITIONAL_OPS: _populate_if_op(output_node, input_node, function_data) diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index 6b88d7b02c3..b8d434278e3 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -237,7 +237,9 @@ def tensor_shape_from_node_def_name(graph, input_name): return shape -def _update_resource_identities(resource_identities, output_graph_def): +def _update_resource_identities(resource_identities, output_graph_def, + variable_names_whitelist, + variable_names_blacklist): """Updates the type of DT_RESOURCE Identity ops. Updates the type of the `resource_identities` to the type of the node that @@ -248,6 +250,10 @@ def _update_resource_identities(resource_identities, output_graph_def): resource_identities: List of NodeDef protos that are Identity ops with the type DT_RESOURCE. output_graph_def: GraphDef proto. + variable_names_whitelist: The set of variable names to convert (by default, + all variables are converted). + variable_names_blacklist: The set of variable names to omit converting + to constants. """ # Identify the nodes in the graph and the nodes consuming each node. 
map_name_to_node = {} @@ -280,14 +286,27 @@ def _update_resource_identities(resource_identities, output_graph_def): "This Identity's type was changed from DT_RESOURCE during graph " "freezing.") if input_node.attr["T"].type != dtypes.resource: - if input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY: + if (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY + and _should_convert( + input_node.input[0], + variable_names_whitelist, + variable_names_blacklist)): node.attr["T"].CopyFrom(input_node.attr["T"]) node.attr["_debugging"].s = debugging_message - elif input_node.op == "VarHandleOp": + elif (input_node.op == "VarHandleOp" + and _should_convert( + input_node.name, + variable_names_whitelist, + variable_names_blacklist)): node.attr["T"].CopyFrom(input_node.attr["dtype"]) node.attr["_debugging"].s = debugging_message +def _should_convert(name, whitelist, blacklist): + return ((whitelist is None or name in whitelist) + and (blacklist is None or name not in blacklist)) + + @deprecation.deprecated( date=None, instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`") @@ -315,6 +334,10 @@ def convert_variables_to_constants(sess, Returns: GraphDef containing a simplified version of the original. + + Raises: + RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both + blacklisted AND whitelisted for freezing. 
""" get_input_name = lambda node, index=0: node.input[index].split(":")[0] @@ -344,44 +367,60 @@ def convert_variables_to_constants(sess, variable_names = [] variable_dict_names = [] resource_op_types = {} + for node in inference_graph.node: if node.op in ["Variable", "VariableV2", "VarHandleOp"]: variable_name = node.name - if ((variable_names_whitelist is not None and - variable_name not in variable_names_whitelist) or - (variable_names_blacklist is not None and - variable_name in variable_names_blacklist)): + if not _should_convert( + variable_name, variable_names_whitelist, variable_names_blacklist): continue variable_dict_names.append(variable_name) if node.op == "VarHandleOp": variable_names.append(variable_name + "/Read/ReadVariableOp:0") else: variable_names.append(variable_name + ":0") - elif node.op in ["ReadVariableOp", "ResourceGather"]: + elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]: # There can be one or more Identity or control flow ops in between the # ReadVariableOp and VarHandleOp. Store the ops with the associated # dtypes. 
source_op_names = [get_input_name(node)] + candidate_resource_op_types = {} while (source_op_names and map_name_to_node[source_op_names[0]].op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY): source_op_name = source_op_names.pop() current_node = map_name_to_node[source_op_name] - if source_op_name not in resource_op_types: - resource_op_types[source_op_name] = node.attr["dtype"] + if (source_op_name not in resource_op_types and + source_op_name not in candidate_resource_op_types): + candidate_resource_op_types[source_op_name] = node.attr["dtype"] source_op_names.append(get_input_name(current_node)) if current_node == "Merge": merge_resource_name = get_input_name(current_node, index=1) - if merge_resource_name not in resource_op_types: - resource_op_types[merge_resource_name] = node.attr["dtype"] + if (merge_resource_name not in resource_op_types + and merge_resource_name not in candidate_resource_op_types): + candidate_resource_op_types[merge_resource_name] = ( + node.attr["dtype"]) source_op_names.append( get_input_name(map_name_to_node[merge_resource_name])) + should_convert_all = None for source_node in source_op_names: if map_name_to_node[source_node].op != "VarHandleOp": raise ValueError("Cannot find the variable that is an input " "to the ReadVariableOp.") + should_convert_node = _should_convert( + source_node, variable_names_whitelist, variable_names_blacklist) + if should_convert_all is None: + should_convert_all = should_convert_node + elif should_convert_all != should_convert_node: + raise RuntimeError( + "Found DT_RESOURCE node whose ancestor Variables are both " + "blacklisted AND whitelisted for freezing. Originating " + "descendant node: {}. Ancestor variables: {}.".format( + node.name, source_op_names)) + if should_convert_all in (None, True): + resource_op_types.update(candidate_resource_op_types) # Gets map of variables and the associated data. 
if variable_names: @@ -391,6 +430,15 @@ def convert_variables_to_constants(sess, variables_data_map = dict(zip(variable_dict_names, returned_variables)) logging.info("Froze %d variables.", len(returned_variables)) + def _should_convert_ancestor(node): + input_node = map_name_to_node[_node_name(node.input[0])] + while (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and + input_node.attr["T"].type == dtypes.resource): + input_node = map_name_to_node[_node_name(input_node.input[0])] + return _should_convert(input_node.name, + variable_names_whitelist, + variable_names_blacklist) + # Reconstruct the graph with constants in place of variables. output_graph_def = graph_pb2.GraphDef() how_many_converted = 0 @@ -413,7 +461,8 @@ def convert_variables_to_constants(sess, if str(attr_name) != "_output_shapes": output_node.attr[attr_name].CopyFrom(input_node.attr[attr_name]) output_node.attr["T"].CopyFrom(resource_op_types[input_node.name]) - elif input_node.op == "ReadVariableOp": + elif (input_node.op == "ReadVariableOp" + and _should_convert_ancestor(input_node)): # The first branch converts all VarHandleOps of ResourceVariables to # constants, so we need to convert the associated ReadVariableOps to # Identity ops. @@ -423,7 +472,8 @@ def convert_variables_to_constants(sess, output_node.attr["T"].CopyFrom(input_node.attr["dtype"]) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) - elif input_node.op == "ResourceGather": + elif (input_node.op == "ResourceGather" + and _should_convert_ancestor(input_node)): # The first branch converts all VarHandleOps of ResourceGather to # constants, so we need to convert the associated ResourceGather to Gather # ops with a Const axis feeding into it. 
@@ -444,9 +494,19 @@ def convert_variables_to_constants(sess, output_node.attr["Taxis"].CopyFrom(axis_dtype) if "_class" in input_node.attr: output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) + elif (input_node.op == "ResourceGatherNd" + and _should_convert_ancestor(input_node)): + output_node.op = "GatherNd" + output_node.name = input_node.name + output_node.input.extend( + [input_node.input[0], input_node.input[1]]) + output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"]) + output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"]) + if "_class" in input_node.attr: + output_node.attr["_class"].CopyFrom(input_node.attr["_class"]) else: output_node.CopyFrom(input_node) - output_graph_def.node.extend([output_node]) + output_graph_def.node.append(output_node) # Update the types of the DT_RESOURCE Identity nodes that do not have an # associated ReadVariableOp. @@ -455,7 +515,10 @@ def convert_variables_to_constants(sess, if node.op == "Identity" and node.attr["T"].type == dtypes.resource: resource_identities.append(node) if resource_identities: - _update_resource_identities(resource_identities, output_graph_def) + _update_resource_identities(resource_identities, + output_graph_def, + variable_names_whitelist, + variable_names_blacklist) output_graph_def.library.CopyFrom(inference_graph.library) logging.info("Converted %d variables to const ops.", how_many_converted) diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py index c28167a68c3..1f04993b343 100644 --- a/tensorflow/python/framework/graph_util_test.py +++ b/tensorflow/python/framework/graph_util_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from 
tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_state_ops @@ -418,6 +419,53 @@ class ConvertVariablesToConstantsTest(test.TestCase): output = self.evaluate(output_node) self.assertNear(2.0, output, 0.00001) + def test_resource_variable_can_be_written_after_blacklisting(self): + with ops.Graph().as_default(): + with variable_scope.variable_scope("", use_resource=True): + variable_node = variable_scope.get_variable( + "variable_node", initializer=1.0) + another_variable = variable_scope.get_variable( + "unused_variable_node", initializer=2.0) + with ops.control_dependencies([ + variable_node.assign(another_variable + variable_node)]): + output_node = array_ops.identity(variable_node, name="output_node") + initializer_name = variable_node.initializer.name + with session.Session() as sess: + self.evaluate(variable_node.initializer) + self.evaluate(another_variable.initializer) + output = self.evaluate(output_node) + self.assertNear(3.0, output, 0.00001) + variable_graph_def = sess.graph.as_graph_def() + + # Test variable name black list. This should result in the variable + # not being a const. Furthermore, the paths that read from and assign + # to the blacklisted variable should continue to be valid. + constant_graph_def_with_blacklist = ( + graph_util.convert_variables_to_constants( + sess, + variable_graph_def, ["output_node", initializer_name], + variable_names_blacklist=set(["variable_node"]))) + + variable_node = None + for node in constant_graph_def_with_blacklist.node: + if node.name == "variable_node": + variable_node = node + self.assertIsNotNone(variable_node) + self.assertEqual(variable_node.op, "VarHandleOp") + + # Now we make sure another_variable is now a constant, but the original + # variable is not, that the graph can be executed, and that the + # variable can be updated with each execution.
+ with ops.Graph().as_default(): + _ = importer.import_graph_def(constant_graph_def_with_blacklist, name="") + with session.Session() as sess: + output_node = sess.graph.get_tensor_by_name("output_node:0") + self.evaluate(sess.graph.get_operation_by_name(initializer_name)) + output = self.evaluate(output_node) + self.assertNear(3.0, output, 0.00001) + output = self.evaluate(output_node) + self.assertNear(5.0, output, 0.00001) + def _inline_functions(self, graph_def, arrays): meta_graph = export_meta_graph(graph_def=graph_def) fetch_collection = meta_graph_pb2.CollectionDef() diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index f85f3b7ec06..07e2ac0e6b9 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -322,9 +322,19 @@ py_library( srcs = ["saved_model_cli.py"], srcs_version = "PY2AND3", deps = [ + ":saved_model_aot_compile", ":saved_model_utils", "//tensorflow/python", "//tensorflow/python/debug:local_cli_wrapper", + ], +) + +py_library( + name = "saved_model_aot_compile", + srcs = ["saved_model_aot_compile.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python", "//tensorflow/python:tf_optimizer", ] + if_xla_available( ["//tensorflow/compiler/tf2xla:tf2xla_proto_py"], diff --git a/tensorflow/python/tools/saved_model_aot_compile.py b/tensorflow/python/tools/saved_model_aot_compile.py new file mode 100644 index 00000000000..7c594bb238f --- /dev/null +++ b/tensorflow/python/tools/saved_model_aot_compile.py @@ -0,0 +1,472 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper utilities for AOT compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import copy +import hashlib +import os +import pipes +import shlex + +import six + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import versions +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import sysconfig as sysconfig_lib +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib + +try: + from tensorflow.python import _pywrap_tfcompile # pylint: disable=g-import-not-at-top +except ImportError as e: + _pywrap_tfcompile_import_error = ImportError( + 'Unable to import _pywrap_tfcompile; you must build TensorFlow ' + 'with XLA. You may need to build tensorflow with flag ' + '--define=with_xla_support=true. 
Original error: {}'.format(str(e))) +else: + _pywrap_tfcompile_import_error = None + + +_READ_ONLY_VARIABLE_OPS = ( + 'ReadVariableOp', + 'IsVariableInitializedOp', + 'ResourceGather', + 'ResourceGatherNd', + 'VariableShape', +) + +_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN') + + +def _shlex_quote(s): + if six.PY2: + return pipes.quote(s) + else: + return shlex.quote(s) + + +def _sysconfig_module(): + """Load tf.sysconfig if available and working (i.e., inside a pip package).""" + try: + _ = sysconfig_lib.get_include() + except ImportError: + return None + return sysconfig_lib + + +def _parse_tensor_name(name): + """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0).""" + if ':' in name and not name.endswith(':'): + node_name = name[:name.rfind(':')] + output_slot = int(name[name.rfind(':') + 1:]) + return node_name, output_slot + else: + return name, None + + +_XLA_MAKEFILE_TEMPLATE = """ +INC = -I{tensorflow_includes} +LIB = -L{compiled_dir} +CXXFLAGS = {cxx_flags} +""" + + +def _xla_makefile_string(output_prefix): + """Returns a Makefile string with variables for using XLA binary object files. + + Attempts to identify the right include header paths when run from either + an installed TensorFlow pip package, or from bazel run. + + Args: + output_prefix: A string containing the output prefix for the XLA AOT + compiled header + object files. + + Returns: + A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`. + """ + sysconfig = _sysconfig_module() + output_dir, _ = os.path.split(output_prefix) + if sysconfig: + tensorflow_includes = _shlex_quote(sysconfig.get_include()) + else: + # Try hard to find the real source directory if this is a local bazel run. 
+ if os.path.islink(__file__): + this_file = __file__ + while os.path.islink(this_file): + this_file = os.readlink(this_file) + base = os.path.realpath( + os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3))) + else: + try: + base = test.test_src_dir_path('') + except KeyError: # Can't find TEST_SRCDIR in environment path. + base = os.path.realpath( + os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3))) + expected_header = os.path.join( + base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h') + if not os.path.exists(expected_header): + logging.error( + 'Could not find includes path. Missing file: {}' + .format(expected_header)) + tensorflow_includes = base + + return _XLA_MAKEFILE_TEMPLATE.format( + tensorflow_includes=tensorflow_includes, + compiled_dir=_shlex_quote(output_dir), + cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format( + versions.CXX11_ABI_FLAG)) + + +def _get_variable_nodes_from_graph_def(graph_def): + """Get the list of Variable nodes from `graph_def`. + + Args: + graph_def: An instance of `GraphDef`. This GraphDef *must* + have already been optimized by Grappler. In particular, function + inlining must have already happened. + + Returns: + A dict mapping string names of variables to tuples `(node_def, modified)`, + where `node_def` is the `NodeDef` corresponding to variable, and `modified` + is a python bool describing whether the variable is modified during runtime. 
+ """ + variables = [n for n in graph_def.node if n.op == 'VarHandleOp'] + variable_name_map = dict((n.name, n) for n in variables) + child_map = collections.defaultdict(lambda: []) + for n in graph_def.node: + for inp in n.input: + if not inp.startswith('^'): + child_map[inp].append(n) + variables = {} + for (v_name, v_node) in variable_name_map.items(): + queue = list(child_map[v_name]) + processed = set([]) + while queue: + n_current = queue.pop() + if n_current.name in processed: + continue + processed.add(n_current.name) + if n_current.op in _PASS_THROUGH_VARIABLE_OPS: + children = child_map.get(n_current.name, []) + queue.extend(children) + elif n_current.op not in _READ_ONLY_VARIABLE_OPS: + variables[v_name] = (v_node, True) + queue = [] + if v_name not in variables: + variables[v_name] = (v_node, False) + + return variables + + +def _prune_removed_feed_nodes(signature_def, graph_def): + """Identify the inputs in the signature no longer in graph_def, prune them. + + Args: + signature_def: A `SignatureDef` instance. + graph_def: A `GraphDef` instance. + + Returns: + A new pruned `SignatureDef`. + """ + node_names = set([n.name for n in graph_def.node]) + new_signature_def = meta_graph_pb2.SignatureDef() + new_signature_def.CopyFrom(signature_def) + for (k, v) in signature_def.inputs.items(): + tensor_name, _ = _parse_tensor_name(v.name) + if tensor_name not in node_names: + logging.warn( + 'Signature input key \'{}\', tensor name \'{}\', has been pruned ' + 'while freezing the graph. Removing it from the compiled signatures.' + .format(k, tensor_name)) + del new_signature_def.inputs[k] + return new_signature_def + + +def aot_compile_cpu_meta_graph_def(checkpoint_path, + meta_graph_def, + output_prefix, + signature_def_key, + cpp_class, + target_triple, + variables_to_feed=()): + """Compile a `MetaGraphDef` to header+object files in `output_prefix`. + + Use XLA AOT (`tfcompile`) to convert the given meta graph and + signature into a header + object files. 
Also create an include makefile + that helps identify the appropriate necessary include and library paths + to incorporate these files into your C++ program. + + The graph is always optimized with grappler, and optionally (by default) + variables are frozen as constants, before compilation happens. + + If `variables_to_feed` is empty (the default), all variables are embedded + as constants into the graph and binary objects. Variables listed in + `variables_to_feed` instead become inputs and outputs of the compiled + class and the C++ caller must set these values manually. + + Args: + checkpoint_path: Python string. Path to checkpoints/variables. + meta_graph_def: Instance of `MetaGraphDef`. + output_prefix: Python string. Path prefix for outputs. + signature_def_key: String, the signature_def to use in the SavedModel. + cpp_class: String, Name of output C++ class. + target_triple: String, LLVM target triple. + variables_to_feed: A list of strings, the variables that will be fed by the + user; these won't be frozen. If `None`, then we will extract all the + variables in the graph and mark them as to-feed. The default behavior is + an empty tuple: all variables must be frozen. + + Raises: + RuntimeError: If tensorflow was not built with XLA. + ImportError: If tensorflow was built with XLA but there was another + issue importing the tfcompile python wrapper. + ValueError: If `meta_graph_def.signature_def[signature_def_key]` is + missing or has empty outputs. + """ + if _pywrap_tfcompile_import_error: + raise _pywrap_tfcompile_import_error + + signature_def_map = meta_graph_def.signature_def + if signature_def_key not in signature_def_map: + raise ValueError( + 'Unable to find signature_def key \'{}\' in signature def map. 
' + 'Available keys: {}'.format( + signature_def_key, + list(signature_def_map.keys()))) + signature_def = signature_def_map[signature_def_key] + if not signature_def.outputs: + raise ValueError( + 'Signature key {} must have outputs, but saw none:\n{}'.format( + signature_def_key, str(signature_def))) + + temp_dir = test.get_temp_dir() + file_io.recursive_create_dir(temp_dir) + if logging.get_verbosity() >= logging.INFO: + original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb') + with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer: + graph_writer.write(meta_graph_def.graph_def.SerializeToString()) + + # This updates graph_def in place. + _replace_input_placeholders_with_default_values( + meta_graph_def.graph_def, signature_def) + + graph_def = _optimize_graph(meta_graph_def, signature_def) + + all_variables = _get_variable_nodes_from_graph_def(graph_def) + if variables_to_feed is None: + variable_nodes_to_feed = list(all_variables.values()) + else: + not_in_graph = set(variables_to_feed).difference(list(all_variables)) + if not_in_graph: + raise ValueError( + 'Asked to feed variables that were not found in graph: {}. ' + 'Variables contained in the graph: {}'.format( + not_in_graph, list(all_variables))) + variable_nodes_to_feed = [ + all_variables[name] for name in variables_to_feed + ] + + if logging.get_verbosity() >= logging.INFO: + prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb') + with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer: + graph_writer.write(graph_def.SerializeToString()) + + # Load the Variables so that we can freeze the graph. 
+ with session.Session(graph=ops_lib.Graph()) as sess: + restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True) + restorer.restore(sess, checkpoint_path) + graph_def.CopyFrom( + graph_util.convert_variables_to_constants( + sess, + graph_def, + output_node_names=[ + _parse_tensor_name(n.name)[0] + for n in signature_def.outputs.values() + ], + variable_names_blacklist=[ + name for (name, node_modified) in all_variables.items() + if node_modified[1] + ], + )) + + signature_def = _prune_removed_feed_nodes(signature_def, graph_def) + + frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb') + config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt') + logging.info('Writing graph def to: {}'.format(frozen_graph_def_location)) + with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer: + graph_writer.write(graph_def.SerializeToString()) + config = _signature_to_tf2xla_config( + signature_def, variable_nodes_to_feed=variable_nodes_to_feed) + logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location)) + with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer: + config_writer.write(str(config)) + + output_dir = os.path.dirname(output_prefix) + file_io.recursive_create_dir(output_dir) + + entry_digest = hashlib.md5() + entry_digest.update(str(config).encode()) + entry_digest.update(str(graph_def).encode()) + entry_digest = entry_digest.hexdigest() + + logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir)) + + makefile_inc_location = '{}_makefile.inc'.format(output_prefix) + with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer: + makefile_writer.write(_xla_makefile_string(output_prefix)) + + output_prefix = _shlex_quote(output_prefix) + + _pywrap_tfcompile.Compile( + graph=frozen_graph_def_location, + config=config_pbtxt_location, + cpp_class=cpp_class, + target_triple=target_triple, + entry_point='entry_{}'.format(entry_digest), + 
out_function_object='{}.o'.format(output_prefix), + out_header='{}.h'.format(output_prefix), + out_metadata_object='{}_metadata.o'.format(output_prefix), + gen_name_to_index=True, + # ProgramShape isn't uniquefied by entry_point. + gen_program_shape=False) + + +def _optimize_graph(meta_graph_def, signature_def): + """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`.""" + # We need to add a collection called 'train_op' so that grappler + # knows what the outputs are. + new_meta_graph_def = copy.deepcopy(meta_graph_def) + fetch_collection = meta_graph_pb2.CollectionDef() + for tensor_info in ( + list(signature_def.inputs.values()) + + list(signature_def.outputs.values())): + fetch_collection.node_list.value.append(tensor_info.name) + + new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection) + + config = config_pb2.ConfigProto() + return tf_optimizer.OptimizeGraph(config, new_meta_graph_def) + + +def _replace_input_placeholders_with_default_values(graph_def, signature_def): + """Replace graphdef's `tf.placeholder` input ops with all-zero constants.""" + name_to_node_map = dict((n.name, n) for n in graph_def.node) + processed_nodes = set([]) + for name, input_ in signature_def.inputs.items(): + tensor_name, _ = _parse_tensor_name(input_.name) + if tensor_name in processed_nodes: + continue + processed_nodes.add(tensor_name) + if tensor_name not in name_to_node_map: + raise RuntimeError( + 'Unable to find input signature tensor \'{}\' in optimized GraphDef. 
' + 'Graph nodes are: {}'.format(tensor_name, + list(name_to_node_map.keys()))) + node = name_to_node_map[tensor_name] + if node.op not in ('Placeholder', 'PlaceholderV2'): + logging.info( + 'Tried to convert SavedModel input node \'{}\' from a placeholder, ' + 'but it doesn\'t look like a placeholder: {}'.format(tensor_name, + node)) + continue + shape = tensor_shape.TensorShape(input_.tensor_shape) + if not shape.is_fully_defined(): + raise ValueError( + 'Expected fully defined input shape for signature_def \'{}\', ' + 'tensor name: \'{}\'; but shape is: {}.' + .format(name, tensor_name, shape)) + temp_graph = ops_lib.Graph() + with temp_graph.as_default(): + const = array_ops.zeros( + shape, dtype=input_.dtype, name=tensor_name) + node.CopyFrom(const.op.node_def) + # Sometimes zeros() also creates additional nodes + for op in temp_graph.get_operations(): + if op.name == const.op.name: # We just inserted this one. + continue + graph_def.node.append(op.node_def) + name_to_node_map[op.node_def.name] = op.node_def + + +def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed): + """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto. + + Args: + signature_def: Instance of `SignatureDef`. + variable_nodes_to_feed: List of tuples of form `(node_def, modified)` + corresponding to VarHandleOp, and a boolean `modified` that describes + whether the variable was modified during execution. + + Returns: + An instance of `tf2xla.Config` proto. + + Raises: + RuntimeError: If TensorFlow was not compiled with XLA. 
+ """ + from tensorflow.compiler.tf2xla import tf2xla_pb2 # pylint: disable=g-import-not-at-top + + config = tf2xla_pb2.Config() + tensor_id = tf2xla_pb2.TensorId + + for name, input_ in signature_def.inputs.items(): + name = name.replace('/', '_') + name = 'feed_{}'.format(name) + (node_name, output_index) = _parse_tensor_name(input_.name) + output_index = int(output_index) + config.feed.append( + tf2xla_pb2.Feed( + id=tensor_id(node_name=node_name, output_index=output_index), + name=name, + type=input_.dtype, + shape=input_.tensor_shape)) + for name, output_ in signature_def.outputs.items(): + name = name.replace('/', '_') + name = 'fetch_{}'.format(name) + (node_name, output_index) = _parse_tensor_name(output_.name) + output_index = int(output_index) + config.fetch.append( + tf2xla_pb2.Fetch( + id=tensor_id(node_name=node_name, output_index=output_index), + name=name, + type=output_.dtype, + shape=output_.tensor_shape)) + for (node, modified) in variable_nodes_to_feed: + name = node.name.replace('/', '_') + name = 'param_{}'.format(name) + config.variable.append( + tf2xla_pb2.Variable( + node_name=node.name, + name=name, + type=node.attr['dtype'].type, + shape=node.attr['shape'].shape, + readonly=not modified)) + + return config diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 3b266dc7482..01494bbf04f 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -25,12 +25,8 @@ from __future__ import print_function import argparse import collections -import copy -import hashlib import os -import pipes import re -import shlex import sys import numpy as np @@ -38,31 +34,22 @@ import six from tensorflow.core.example import example_pb2 from tensorflow.core.framework import types_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session from tensorflow.python.debug.wrappers 
import local_cli_wrapper from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun -from tensorflow.python.framework import graph_util from tensorflow.python.framework import meta_graph as meta_graph_lib from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import versions -from tensorflow.python.grappler import tf_optimizer from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import array_ops from tensorflow.python.platform import app # pylint: disable=unused-import -from tensorflow.python.platform import sysconfig as sysconfig_lib -from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import load from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import save from tensorflow.python.saved_model import signature_constants +from tensorflow.python.tools import saved_model_aot_compile from tensorflow.python.tools import saved_model_utils -from tensorflow.python.training import saver as saver_lib _XLA_DEBUG_OPTIONS_URL = ( @@ -70,100 +57,10 @@ _XLA_DEBUG_OPTIONS_URL = ( 'tensorflow/compiler/xla/debug_options_flags.cc') -try: - from tensorflow.python import _pywrap_tfcompile # pylint: disable=g-import-not-at-top -except ImportError as e: - _pywrap_tfcompile_import_error = ImportError( - 'Unable to import _pywrap_tfcompile; you must build TensorFlow ' - 'with XLA. You may need to build tensorflow with flag ' - '--define=with_xla_support=true. Original error: {}'.format(str(e))) -else: - _pywrap_tfcompile_import_error = None - - # Set of ops to blacklist. 
_OP_BLACKLIST = set(['WriteFile', 'ReadFile', 'PrintV2']) -def _shlex_quote(s): - if six.PY2: - return pipes.quote(s) - else: - return shlex.quote(s) - - -def _sysconfig_module(): - """Load tf.sysconfig if available and working (i.e., inside a pip package).""" - try: - _ = sysconfig_lib.get_include() - except ImportError: - return None - return sysconfig_lib - - -def _parse_tensor_name(name): - """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0).""" - if ':' in name and not name.endswith(':'): - node_name = name[:name.rfind(':')] - output_slot = int(name[name.rfind(':') + 1:]) - return node_name, output_slot - else: - return name, None - - -_XLA_MAKEFILE_TEMPLATE = """ -INC = -I{tensorflow_includes} -LIB = -L{compiled_dir} -CXXFLAGS = {cxx_flags} -""" - - -def _xla_makefile_string(output_prefix): - """Returns a Makefile string with variables for using XLA binary object files. - - Attempts to identify the right include header paths when run from either - an installed TensorFlow pip package, or from bazel run. - - Args: - output_prefix: A string containing the output prefix for the XLA AOT - compiled header + object files. - - Returns: - A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`. - """ - sysconfig = _sysconfig_module() - output_dir, _ = os.path.split(output_prefix) - if sysconfig: - tensorflow_includes = _shlex_quote(sysconfig.get_include()) - else: - # Try hard to find the real source directory if this is a local bazel run. - if os.path.islink(__file__): - this_file = __file__ - while os.path.islink(this_file): - this_file = os.readlink(this_file) - base = os.path.realpath( - os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3))) - else: - try: - base = test.test_src_dir_path('') - except KeyError: # Can't find TEST_SRCDIR in environment path. 
- base = os.path.realpath( - os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3))) - expected_header = os.path.join( - base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h') - if not os.path.exists(expected_header): - logging.error( - 'Could not find includes path. Missing file: {}' - .format(expected_header)) - tensorflow_includes = base - - return _XLA_MAKEFILE_TEMPLATE.format( - tensorflow_includes=tensorflow_includes, - compiled_dir=_shlex_quote(output_dir), - cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format( - versions.CXX11_ABI_FLAG)) - - def _show_tag_sets(saved_model_dir): """Prints the tag-sets stored in SavedModel directory. @@ -178,47 +75,6 @@ def _show_tag_sets(saved_model_dir): print('%r' % ', '.join(sorted(tag_set))) -def _get_variable_nodes_from_graph_def(graph_def): - """Get the list of Variable nodes from `graph_def`. - - Args: - graph_def: An instance of `GraphDef`. - - Returns: - A list of `NodeDef` corresponding to variables in the graph. - """ - variables = [n for n in graph_def.node if n.op == 'VarHandleOp'] - - for f in graph_def.library.function: - variables += [n for n in f.node_def if n.op == 'VarHandleOp'] - - return variables - - -def _prune_removed_feed_nodes(signature_def, graph_def): - """Identify the inputs in the signature no longer in graph_def, prune them. - - Args: - signature_def: A `SignatureDef` instance. - graph_def: A `GraphDef` instance. - - Returns: - A new pruned `SignatureDef`. - """ - node_names = set([n.name for n in graph_def.node]) - new_signature_def = meta_graph_pb2.SignatureDef() - new_signature_def.CopyFrom(signature_def) - for (k, v) in signature_def.inputs.items(): - tensor_name, _ = _parse_tensor_name(v.name) - if tensor_name not in node_names: - logging.warn( - 'Signature input key \'{}\', tensor name \'{}\', has been pruned ' - 'while freezing the graph. Removing it from the compiled signatures.' 
- .format(k, tensor_name)) - del new_signature_def.inputs[k] - return new_signature_def - - def _show_signature_def_map_keys(saved_model_dir, tag_set): """Prints the keys for each SignatureDef in the SignatureDef map. @@ -943,7 +799,7 @@ def aot_compile_cpu(args): variables_to_feed = None # We will identify them after. else: variables_to_feed = args.variables_to_feed.split(',') - aot_compile_cpu_meta_graph_def( + saved_model_aot_compile.aot_compile_cpu_meta_graph_def( checkpoint_path=checkpoint_path, meta_graph_def=saved_model_utils.get_meta_graph_def( args.dir, args.tag_set), @@ -954,210 +810,6 @@ def aot_compile_cpu(args): cpp_class=args.cpp_class) -def aot_compile_cpu_meta_graph_def(checkpoint_path, - meta_graph_def, - output_prefix, - signature_def_key, - cpp_class, - target_triple, - variables_to_feed=()): - """Compile a `MetaGraphDef` to header+object files in `output_prefix`. - - Use XLA AOT (`tfcompile`) to convert the given meta graph and - signature into a header + object files. Also create an include makefile - that helps identify the appropriate necessary include and library paths - to incorporate these files into your C++ program. - - The graph is always optimized with grappler, and optionally (by default) - variables are frozen as constants, before compilation happens. - - If the `freeze_graph` is `True`, all variables are embedded as constants - into the graph and binary objects. If it is `False`, then the variable - values become inputs and outputs of the compiled class and the C++ - caller must set these values manually. - - Args: - checkpoint_path: Python string. Path to checkpoints/variables. - meta_graph_def: Instance of `MetaGraphDef`. - output_prefix: Python string. Path prefix for outputs. - signature_def_key: String, the signature_def to use in the SavedModel. - cpp_class: String, Name of output C++ class. - target_triple: String, LLVM target triple. 
- variables_to_feed: A list of strings, the variables that will be fed by the - user; these won't be frozen. If `None`, then we will extract all the - variables in the graph and mark them as to-feed. The default behavior is - an empty tuple: all variables must be frozen. - - Raises: - RuntimeError: If tensorflow was not built with XLA. - ImportError: If tensorflow was built with XLA but there was another - issue importing the tfcompile python wrapper. - ValueError: If `meta_graph_def.signature_def[signature_def_key]` is - missing or has empty outputs. - """ - if _pywrap_tfcompile_import_error: - raise _pywrap_tfcompile_import_error - - signature_def_map = meta_graph_def.signature_def - if signature_def_key not in signature_def_map: - raise ValueError( - 'Unable to find signature_def key \'{}\' in signature def map. ' - 'Available keys: {}'.format( - signature_def_key, - list(signature_def_map.keys()))) - signature_def = signature_def_map[signature_def_key] - if not signature_def.outputs: - raise ValueError( - 'Signature key {} must have outputs, but saw none:\n{}'.format( - signature_def_key, str(signature_def))) - - temp_dir = test.get_temp_dir() - file_io.recursive_create_dir(temp_dir) - if logging.get_verbosity() >= logging.INFO: - original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb') - with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer: - graph_writer.write(meta_graph_def.graph_def.SerializeToString()) - - # This updates graph_def in place. 
- _replace_input_placeholders_with_default_values( - meta_graph_def.graph_def, signature_def) - graph_def = _optimize_graph(meta_graph_def, signature_def) - - all_variables = _get_variable_nodes_from_graph_def(graph_def) - if variables_to_feed is None: - variable_nodes_to_feed = list(all_variables) - else: - not_in_graph = ( - set(variables_to_feed).difference([x.name for x in all_variables])) - if not_in_graph: - raise ValueError( - 'Asked to feed variables that were not found in graph: {}. ' - 'Variables contained in the graph: {}'.format( - not_in_graph, set([x.name for x in all_variables]))) - all_variables_map = dict((x.name, x) for x in all_variables) - variable_nodes_to_feed = [ - all_variables_map[name] for name in variables_to_feed - ] - - if logging.get_verbosity() >= logging.INFO: - prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb') - with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer: - graph_writer.write(meta_graph_def.graph_def.SerializeToString()) - - # Load the Variables so that we can freeze the graph. 
- with session.Session(graph=ops_lib.Graph()) as sess: - restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True) - restorer.restore(sess, checkpoint_path) - graph_def.CopyFrom( - graph_util.convert_variables_to_constants( - sess, - graph_def, - output_node_names=[ - _parse_tensor_name(n.name)[0] - for n in signature_def.outputs.values() - ], - )) - - signature_def = _prune_removed_feed_nodes(signature_def, graph_def) - - frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb') - config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt') - logging.info('Writing graph def to: {}'.format(frozen_graph_def_location)) - with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer: - graph_writer.write(graph_def.SerializeToString()) - config = _signature_to_tf2xla_config( - signature_def, variable_nodes_to_feed=variable_nodes_to_feed) - logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location)) - with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer: - config_writer.write(str(config)) - - output_dir = os.path.dirname(output_prefix) - file_io.recursive_create_dir(output_dir) - - entry_digest = hashlib.md5() - entry_digest.update(str(config).encode()) - entry_digest.update(str(graph_def).encode()) - entry_digest = entry_digest.hexdigest() - - logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir)) - - makefile_inc_location = '{}_makefile.inc'.format(output_prefix) - with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer: - makefile_writer.write(_xla_makefile_string(output_prefix)) - - output_prefix = _shlex_quote(output_prefix) - - _pywrap_tfcompile.Compile( - graph=frozen_graph_def_location, - config=config_pbtxt_location, - cpp_class=cpp_class, - target_triple=target_triple, - entry_point='entry_{}'.format(entry_digest), - out_function_object='{}.o'.format(output_prefix), - out_header='{}.h'.format(output_prefix), - 
out_metadata_object='{}_metadata.o'.format(output_prefix), - gen_name_to_index=True, - # ProgramShape isn't uniquefied by entry_point. - gen_program_shape=False) - - -def _optimize_graph(meta_graph_def, signature_def): - """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`.""" - # We need to add a collection called 'train_op' so that grappler - # knows what the outputs are. - new_meta_graph_def = copy.deepcopy(meta_graph_def) - fetch_collection = meta_graph_pb2.CollectionDef() - for tensor_info in ( - list(signature_def.inputs.values()) + - list(signature_def.outputs.values())): - fetch_collection.node_list.value.append(tensor_info.name) - - new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection) - - config = config_pb2.ConfigProto() - return tf_optimizer.OptimizeGraph(config, new_meta_graph_def) - - -def _replace_input_placeholders_with_default_values(graph_def, signature_def): - """Replace graphdef's `tf.placeholder` input ops with all-zero constants.""" - name_to_node_map = dict((n.name, n) for n in graph_def.node) - processed_nodes = set([]) - for name, input_ in signature_def.inputs.items(): - tensor_name, _ = _parse_tensor_name(input_.name) - if tensor_name in processed_nodes: - continue - processed_nodes.add(tensor_name) - if tensor_name not in name_to_node_map: - raise RuntimeError( - 'Unable to find input signature tensor \'{}\' in optimized GraphDef. ' - 'Graph nodes are: {}'.format(tensor_name, - list(name_to_node_map.keys()))) - node = name_to_node_map[tensor_name] - if node.op not in ('Placeholder', 'PlaceholderV2'): - logging.info( - 'Tried to convert SavedModel input node \'{}\' from a placeholder, ' - 'but it doesn\'t look like a placeholder: {}'.format(tensor_name, - node)) - continue - shape = tensor_shape.TensorShape(input_.tensor_shape) - if not shape.is_fully_defined(): - raise ValueError( - 'Expected fully defined input shape for signature_def \'{}\', ' - 'tensor name: \'{}\'; but shape is: {}.' 
- .format(name, tensor_name, shape)) - temp_graph = ops_lib.Graph() - with temp_graph.as_default(): - const = array_ops.zeros( - shape, dtype=input_.dtype, name=tensor_name) - node.CopyFrom(const.op.node_def) - # Sometimes zeros() also creates additional nodes - for op in temp_graph.get_operations(): - if op.name == const.op.name: # We just inserted this one. - continue - graph_def.node.append(op.node_def) - name_to_node_map[op.node_def.name] = op.node_def - - def add_show_subparser(subparsers): """Add parser for `show`.""" show_msg = ( @@ -1482,61 +1134,6 @@ def create_parser(): return parser -def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed): - """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto. - - Args: - signature_def: Instance of `SignatureDef`. - variable_nodes_to_feed: List NodeDefs corresponding to VarHandleOp, - the list of variables to feed. - - Returns: - An instance of `tf2xla.Config` proto. - - Raises: - RuntimeError: If TensorFlow was not compiled with XLA. 
- """ - from tensorflow.compiler.tf2xla import tf2xla_pb2 # pylint: disable=g-import-not-at-top - - config = tf2xla_pb2.Config() - tensor_id = tf2xla_pb2.TensorId - - for name, input_ in signature_def.inputs.items(): - name = name.replace('/', '_') - name = 'feed_{}'.format(name) - (node_name, output_index) = _parse_tensor_name(input_.name) - output_index = int(output_index) - config.feed.append( - tf2xla_pb2.Feed( - id=tensor_id(node_name=node_name, output_index=output_index), - name=name, - type=input_.dtype, - shape=input_.tensor_shape)) - for name, output_ in signature_def.outputs.items(): - name = name.replace('/', '_') - name = 'fetch_{}'.format(name) - (node_name, output_index) = _parse_tensor_name(output_.name) - output_index = int(output_index) - config.fetch.append( - tf2xla_pb2.Fetch( - id=tensor_id(node_name=node_name, output_index=output_index), - name=name, - type=output_.dtype, - shape=output_.tensor_shape)) - for node in variable_nodes_to_feed: - name = node.name.replace('/', '_') - name = 'param_{}'.format(name) - config.variable.append( - tf2xla_pb2.Variable( - node_name=node.name, - name=name, - type=node.attr['dtype'].type, - shape=node.attr['shape'].shape, - readonly=True)) - - return config - - def main(): logging.set_verbosity(logging.INFO) parser = create_parser() diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index 12799394602..cc9e2f21ddc 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -733,6 +733,7 @@ Defined Functions: def __init__(self): self.var = variables.Variable(1.0, name='my_var') + self.write_var = variables.Variable(1.0, name='write_var') @def_function.function(input_signature=[ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), @@ -752,20 +753,32 @@ Defined Functions: del y return {'res': x + self.var} + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=(), 
dtype=dtypes.float32), + tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + ]) + def func_write(self, x, y): + del y + self.write_var.assign(x + self.var) + return {'res': self.write_var} + @parameterized.named_parameters( ('VariablesToFeedNone', '', 'func2'), ('VariablesToFeedAll', 'all', 'func2'), ('VariablesToFeedMyVar', 'my_var', 'func2'), - ('VariablesToFeedNoneLargeConstant', '', 'func3')) + ('VariablesToFeedNoneLargeConstant', '', 'func3'), + ('WriteToWriteVar', 'all', 'func_write'), + ) def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed, func): if not test.is_built_with_xla(): self.skipTest('Skipping test because XLA is not compiled in.') saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model') dummy_model = self.AOTCompileDummyModel() - func = dummy_model.func2 if func == 'func2' else dummy_model.func3 + func = getattr(dummy_model, func) with self.cached_session(): self.evaluate(dummy_model.var.initializer) + self.evaluate(dummy_model.write_var.initializer) save.save(dummy_model, saved_model_dir, signatures={'func': func}) self.parser = saved_model_cli.create_parser() @@ -793,7 +806,15 @@ Defined Functions: # arg_y got filtered out as it's not used by the output. self.assertNotIn('arg_feed_y_data', header_contents) if variables_to_feed: - self.assertIn('var_param_my_var', header_contents) + # Read-only-variables' setters preserve constness. + self.assertIn('set_var_param_my_var_data(const float', header_contents) + self.assertNotIn('set_var_param_my_var_data(float', header_contents) + if func == dummy_model.func_write: + # Writeable variables setters do not preserve constness. + self.assertIn('set_var_param_write_var_data(float', header_contents) + self.assertNotIn( + 'set_var_param_write_var_data(const float', header_contents) + makefile_contents = file_io.read_file_to_string( '{}_makefile.inc'.format(output_prefix)) self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)