Add SaveOptions.save_debug_info and associated plumbing for SavedModel V2.
(second try: original rolled back due to path sep mismatch in a unit test for Windows) The main behavior change is that I chose to mangle the trace key to be: op_name "@" func_name (global ops use func_name = '') Originally it was just a simple concat: func_name op_name I don't have a strong opinion on the specific mangling to be used, but I do believe that how it was done is collision prone and best to correct now. Specifically, I chose this form because: a) func_name does not seem to have strong validation on its naming (and in practice can be quite varied in the presence of lambdas, etc) b) op_name does have strong validation on its syntax and specifically excludes '@' (matches regex: [A-Za-z0-9.][A-Za-z0-9_.\\-/]*) Given these points, what I propose should be collision free. I'm not sure I should be making this decision, though. Please advise. The test coverage was pretty sparse for all of this and I tried to buff it up around my modifications. PiperOrigin-RevId: 267674315
This commit is contained in:
parent
09279ed3cd
commit
c974f04d09
@ -1139,9 +1139,10 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
||||
for (int i = 0, e = original_nodes.size(); i != e; ++i) {
|
||||
auto node_name = original_nodes[i];
|
||||
auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
|
||||
// Use the catenation of function and node names as the lookup key. This
|
||||
// is to match the utility of generating the GraphDebugInfo.
|
||||
node_call_sites.push_back(node_name_to_call_site(func_name + node_name));
|
||||
// Use the catenation of function and node names as the lookup key.
|
||||
// This matches the way that the key is formed on the python side.
|
||||
std::string key = node_name + "@" + func_name;
|
||||
node_call_sites.push_back(node_name_to_call_site(key));
|
||||
}
|
||||
return mlir::FusedLoc::get(node_call_sites, context_);
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "GraphDebugInfoProtos";
|
||||
option java_multiple_files = true;
|
||||
@ -37,5 +38,14 @@ message GraphDebugInfo {
|
||||
repeated string files = 1;
|
||||
|
||||
// This maps a node name to a stack trace in the source code.
|
||||
// The map key is a mangling of the containing function and op name with
|
||||
// syntax:
|
||||
// op.name '@' func_name
|
||||
// For ops in the top-level graph, the func_name is the empty string.
|
||||
// Note that op names are restricted to a small number of characters which
|
||||
// exclude '@', making it impossible to collide keys of this form. Function
|
||||
// names accept a much wider set of characters.
|
||||
// It would be preferable to avoid mangling and use a tuple key of (op.name,
|
||||
// func_name), but this is not supported with protocol buffers.
|
||||
map<string, StackTrace> traces = 2;
|
||||
}
|
||||
|
@ -1077,7 +1077,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
|
||||
|
||||
# Check the add node in the inlined function is included.
|
||||
func = sess.graph.as_graph_def().library.function[0].signature.name
|
||||
self.assertIn((func + 'add'), converter._debug_info.traces)
|
||||
self.assertIn(('add@' + func), converter._debug_info.traces)
|
||||
|
||||
|
||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||
|
@ -38,12 +38,31 @@ _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
|
||||
|
||||
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
|
||||
|
||||
_BAD_FILE_SUBSTRINGS = [
|
||||
os.path.join("tensorflow", "python"),
|
||||
os.path.join("tensorflow", "contrib"),
|
||||
os.path.join("tensorflow_estimator", "python"),
|
||||
os.path.join("tensorflow_estimator", "contrib"),
|
||||
"<embedded",
|
||||
|
||||
# Remove the last three path components from this module's file (i.e.
|
||||
# python/framework/error_interpolation.py) so that we have an absolute path
|
||||
# prefix to the root of the installation.
|
||||
_FRAMEWORK_COMMON_PREFIX = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# Sub-directories under the common prefix that are considered part of the
|
||||
# framework.
|
||||
_FRAMEWORK_PATH_PREFIXES = [
|
||||
os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
|
||||
os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
|
||||
]
|
||||
|
||||
# Patterns of filename patterns that should be considered internal to
|
||||
# the TensorFlow framework.
|
||||
_FRAMEWORK_FILENAME_PATTERNS = [
|
||||
re.compile(r"<embedded"),
|
||||
]
|
||||
|
||||
# Patterns of filename patterns that should be considered external to
|
||||
# TensorFlow regardless of framework prefix match.
|
||||
_EXTERNAL_FILENAME_PATTERNS = [
|
||||
# Explicitly treat test frames as not part of the framework.
|
||||
re.compile(r"_test\.py$"),
|
||||
]
|
||||
|
||||
|
||||
@ -178,13 +197,39 @@ def _compute_colocation_summary_from_op(op, prefix=""):
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _is_framework_filename(filename):
|
||||
"""Returns whether a filename should be considered a part of the framework.
|
||||
|
||||
A file is part of the framework if it does not match a pattern in
|
||||
_EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in
|
||||
_FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES prefix.
|
||||
|
||||
Args:
|
||||
filename: A filename string.
|
||||
|
||||
Returns:
|
||||
Whether the filename should be considered to be internal to the
|
||||
TensorFlow framework for the purposes of reporting errors.
|
||||
"""
|
||||
for pattern in _EXTERNAL_FILENAME_PATTERNS:
|
||||
if pattern.search(filename):
|
||||
return False
|
||||
for pattern in _FRAMEWORK_FILENAME_PATTERNS:
|
||||
if pattern.search(filename):
|
||||
return True
|
||||
for prefix in _FRAMEWORK_PATH_PREFIXES:
|
||||
if filename.startswith(prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _find_index_of_defining_frame_for_op(op):
|
||||
"""Return index in op.traceback with first 'useful' frame.
|
||||
|
||||
This method reads through the stack stored in op.traceback looking for the
|
||||
innermost frame which (hopefully) belongs to the caller. It accomplishes this
|
||||
by rejecting frames whose filename appears to come from TensorFlow (see
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings).
|
||||
by rejecting frames deemed to be part of the TensorFlow framework (by
|
||||
pattern matching the filename).
|
||||
|
||||
Args:
|
||||
op: the Operation object for which we would like to find the defining
|
||||
@ -201,8 +246,9 @@ def _find_index_of_defining_frame_for_op(op):
|
||||
filenames = [frame.filename for frame in tf_traceback]
|
||||
# We process the filenames from the innermost frame to outermost.
|
||||
for idx, filename in enumerate(reversed(filenames)):
|
||||
contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS]
|
||||
if not any(contains_bad_substrings):
|
||||
is_framework = _is_framework_filename(filename)
|
||||
if not is_framework:
|
||||
# Consider this to be the defining frame.
|
||||
return size - idx - 1
|
||||
return 0
|
||||
|
||||
@ -237,11 +283,13 @@ def _compute_useful_frames(op, num):
|
||||
return op.traceback[outermost_included:innermost_excluded]
|
||||
|
||||
|
||||
def create_graph_debug_info_def(operations):
|
||||
def create_graph_debug_info_def(func_named_operations):
|
||||
"""Construct and returns a `GraphDebugInfo` protocol buffer.
|
||||
|
||||
Args:
|
||||
operations: An iterable of op.Operation objects having _traceback members.
|
||||
func_named_operations: An iterable of (func_name, op.Operation) tuples
|
||||
where the Operation instances have a _traceback members. The func_name
|
||||
should be the empty string for operations in the top-level Graph.
|
||||
|
||||
Returns:
|
||||
GraphDebugInfo protocol buffer.
|
||||
@ -256,9 +304,9 @@ def create_graph_debug_info_def(operations):
|
||||
# collects the unique file names.
|
||||
all_file_names = set()
|
||||
node_to_trace = {}
|
||||
for func, op in operations:
|
||||
for func_name, op in func_named_operations:
|
||||
# Gets the stack trace of the operation and then the file location.
|
||||
node_name = func + op.name
|
||||
node_name = op.name + "@" + func_name
|
||||
node_to_trace[node_name] = _compute_useful_frames(op, 10)
|
||||
for frame in node_to_trace[node_name]:
|
||||
all_file_names.add(frame.filename)
|
||||
|
@ -48,7 +48,7 @@ def _make_frame_with_filename(op, idx, filename):
|
||||
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
|
||||
num_inner_tf_frames):
|
||||
"""Replace op._traceback with a new traceback using special filenames."""
|
||||
tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0]
|
||||
tf_filename = error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "%d.py"
|
||||
user_filename = os.path.join("%d", "my_favorite_file.py")
|
||||
|
||||
num_requested_frames = num_user_frames + num_inner_tf_frames
|
||||
@ -122,10 +122,70 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
|
||||
self.assertIn("No node-device colocations", summary)
|
||||
|
||||
|
||||
# Note that the create_graph_debug_info_def needs to run on graph mode ops,
|
||||
# so it is excluded from eager tests. Even when used in eager mode, it is
|
||||
# via FunctionGraphs, and directly verifying in graph mode is the narrowest
|
||||
# way to unit test the functionality.
|
||||
@test_util.run_deprecated_v1
|
||||
class CreateGraphDebugInfoDefTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(CreateGraphDebugInfoDefTest, self).setUp()
|
||||
ops.reset_default_graph()
|
||||
|
||||
def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
|
||||
self.assertIn(key, graph_debug_info.traces)
|
||||
stack_trace = graph_debug_info.traces[key]
|
||||
found_flc = None
|
||||
for flc in stack_trace.file_line_cols:
|
||||
if flc.file_index == file_index:
|
||||
found_flc = flc
|
||||
break
|
||||
self.assertIsNotNone(found_flc,
|
||||
"Could not find a stack trace entry for file")
|
||||
return found_flc
|
||||
|
||||
def testStackTraceExtraction(self):
|
||||
# Since the create_graph_debug_info_def() function does not actually
|
||||
# do anything special with functions except name mangling, just verify
|
||||
# it with a loose op and manually provided function name.
|
||||
# The following ops *must* be on consecutive lines (it will be verified
|
||||
# in the resulting trace).
|
||||
# pyformat: disable
|
||||
global_op = constant_op.constant(0, name="Global").op
|
||||
op1 = constant_op.constant(1, name="One").op
|
||||
op2 = constant_op.constant(2, name="Two").op
|
||||
# pyformat: enable
|
||||
|
||||
export_ops = [("", global_op), ("func1", op1), ("func2", op2)]
|
||||
graph_debug_info = error_interpolation.create_graph_debug_info_def(
|
||||
export_ops)
|
||||
this_file_index = -1
|
||||
for file_index, file_name in enumerate(graph_debug_info.files):
|
||||
if "{}error_interpolation_test.py".format(os.sep) in file_name:
|
||||
this_file_index = file_index
|
||||
self.assertGreaterEqual(
|
||||
this_file_index, 0,
|
||||
"Could not find this file in trace:" + repr(graph_debug_info))
|
||||
|
||||
# Verify the traces exist for each op.
|
||||
global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
|
||||
this_file_index)
|
||||
op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
|
||||
this_file_index)
|
||||
op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
|
||||
this_file_index)
|
||||
|
||||
global_line = global_flc.line
|
||||
self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
|
||||
self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(InterpolateFilenamesAndLineNumbersTest, self).setUp()
|
||||
ops.reset_default_graph()
|
||||
# Add nodes to the graph for retrieval by name later.
|
||||
constant_op.constant(1, name="One")
|
||||
@ -133,17 +193,6 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
three = constant_op.constant(3, name="Three")
|
||||
self.graph = three.graph
|
||||
|
||||
# Change the list of bad file substrings so that constant_op.py is chosen
|
||||
# as the defining stack frame for constant_op.constant ops.
|
||||
self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS = [
|
||||
"%sops.py" % os.sep,
|
||||
"%sutil" % os.sep,
|
||||
]
|
||||
|
||||
def tearDown(self):
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings
|
||||
|
||||
def testFindIndexOfDefiningFrameForOp(self):
|
||||
local_op = constant_op.constant(42).op
|
||||
user_filename = "hope.py"
|
||||
@ -187,58 +236,50 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
|
||||
two_tags_no_seps = "{{node One}}{{node Three}}"
|
||||
interpolated_string = error_interpolation.interpolate(
|
||||
two_tags_no_seps, self.graph)
|
||||
self.assertRegexpMatches(interpolated_string,
|
||||
"constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
|
||||
self.assertRegexpMatches(
|
||||
interpolated_string, r"error_interpolation_test\.py:[0-9]+."
|
||||
r"*error_interpolation_test\.py:[0-9]+")
|
||||
|
||||
def testTwoTagsWithSeps(self):
|
||||
two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
|
||||
interpolated_string = error_interpolation.interpolate(
|
||||
two_tags_with_seps, self.graph)
|
||||
expected_regex = (
|
||||
r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
|
||||
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
|
||||
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
|
||||
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||
|
||||
def testNewLine(self):
|
||||
newline = "\n\n{{node One}}"
|
||||
interpolated_string = error_interpolation.interpolate(newline, self.graph)
|
||||
self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
|
||||
self.assertRegexpMatches(interpolated_string,
|
||||
r"error_interpolation_test\.py:[0-9]+.*")
|
||||
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
class InputNodesTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(InputNodesTest, self).setUp()
|
||||
# Add nodes to the graph for retrieval by name later.
|
||||
one = constant_op.constant(1, name="One")
|
||||
two = constant_op.constant(2, name="Two")
|
||||
three = math_ops.add(one, two, name="Three")
|
||||
self.graph = three.graph
|
||||
|
||||
# Change the list of bad file substrings so that constant_op.py is chosen
|
||||
# as the defining stack frame for constant_op.constant ops.
|
||||
self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS = [
|
||||
"%sops.py" % os.sep,
|
||||
"%sutil" % os.sep,
|
||||
]
|
||||
|
||||
def tearDown(self):
|
||||
error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings
|
||||
|
||||
def testNoInputs(self):
|
||||
two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
|
||||
interpolated_string = error_interpolation.interpolate(
|
||||
two_tags_with_seps, self.graph)
|
||||
expected_regex = (
|
||||
r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
|
||||
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
|
||||
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
|
||||
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||
|
||||
def testBasicInputs(self):
|
||||
tag = ";;;{{node Three}};;;"
|
||||
interpolated_string = error_interpolation.interpolate(tag, self.graph)
|
||||
expected_regex = re.compile(
|
||||
r"^;;;.*op_def_library.py:[0-9]+\) ;;;.*Input.*constant_op.py:[0-9]+\)",
|
||||
re.DOTALL)
|
||||
r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
|
||||
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
|
||||
self.assertRegexpMatches(interpolated_string, expected_regex)
|
||||
|
||||
|
||||
@ -249,6 +290,7 @@ class InterpolateDeviceSummaryTest(test.TestCase):
|
||||
return "/cpu:*"
|
||||
|
||||
def setUp(self):
|
||||
super(InterpolateDeviceSummaryTest, self).setUp()
|
||||
ops.reset_default_graph()
|
||||
self.zero = constant_op.constant([0.0], name="zero")
|
||||
with ops.device("/cpu"):
|
||||
@ -290,6 +332,7 @@ class InterpolateDeviceSummaryTest(test.TestCase):
|
||||
class InterpolateColocationSummaryTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(InterpolateColocationSummaryTest, self).setUp()
|
||||
ops.reset_default_graph()
|
||||
# Add nodes to the graph for retrieval by name later.
|
||||
node_one = constant_op.constant(1, name="One")
|
||||
@ -337,5 +380,23 @@ class InterpolateColocationSummaryTest(test.TestCase):
|
||||
self.assertNotIn("Two", result)
|
||||
|
||||
|
||||
class IsFrameworkFilenameTest(test.TestCase):
|
||||
|
||||
def testAllowsUnitTests(self):
|
||||
self.assertFalse(
|
||||
error_interpolation._is_framework_filename(
|
||||
error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "foobar_test.py"))
|
||||
|
||||
def testFrameworkPythonFile(self):
|
||||
self.assertTrue(
|
||||
error_interpolation._is_framework_filename(
|
||||
error_interpolation.__file__))
|
||||
|
||||
def testEmbedded(self):
|
||||
self.assertTrue(
|
||||
error_interpolation._is_framework_filename(
|
||||
"<embedded stdlib>/context_lib.py"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -741,11 +741,11 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
biases1 = resource_variable_ops.ResourceVariable(
|
||||
[0.1] * 3, name="biases")
|
||||
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
|
||||
operations = []
|
||||
func_named_operations = []
|
||||
for op in graph1.get_operations():
|
||||
operations.append(("", op))
|
||||
func_named_operations.append(("", op))
|
||||
debug_info_def = error_interpolation.create_graph_debug_info_def(
|
||||
operations=operations)
|
||||
func_named_operations)
|
||||
|
||||
# The unique file names in all the stack traces should be larger or equal
|
||||
# than 1.
|
||||
|
@ -323,9 +323,12 @@ tf_py_test(
|
||||
additional_deps = [
|
||||
":loader",
|
||||
":save",
|
||||
":save_options",
|
||||
":signature_constants",
|
||||
":tag_constants",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:error_interpolation",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
|
@ -86,6 +86,25 @@ tf_export(
|
||||
"saved_model.constants.SAVED_MODEL_FILENAME_PBTXT"
|
||||
]).export_constant(__name__, "SAVED_MODEL_FILENAME_PBTXT")
|
||||
|
||||
# Subdirectory where debugging related files are written.
|
||||
DEBUG_DIRECTORY = "debug"
|
||||
tf_export(
|
||||
"saved_model.DEBUG_DIRECTORY",
|
||||
v1=[
|
||||
"saved_model.DEBUG_DIRECTORY",
|
||||
"saved_model.constants.DEBUG_DIRECTORY",
|
||||
]).export_constant(__name__, "DEBUG_DIRECTORY")
|
||||
|
||||
# File name for GraphDebugInfo protocol buffer which corresponds to the
|
||||
# SavedModel.
|
||||
DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"
|
||||
tf_export(
|
||||
"saved_model.DEBUG_INFO_FILENAME_PB",
|
||||
v1=[
|
||||
"saved_model.DEBUG_INFO_FILENAME_PB",
|
||||
"saved_model.constants.DEBUG_INFO_FILENAME_PB"
|
||||
]).export_constant(__name__, "DEBUG_INFO_FILENAME_PB")
|
||||
|
||||
# File name for json format of SavedModel.
|
||||
# Not exported while keras_saved_model is in contrib.
|
||||
SAVED_MODEL_FILENAME_JSON = "saved_model.json"
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import error_interpolation
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
@ -546,7 +547,8 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
|
||||
namespace_whitelist: List of strings containing whitelisted op namespaces.
|
||||
|
||||
Returns:
|
||||
An _AssetInfo, which contains information to help creating the SavedModel.
|
||||
A tuple of (_AssetInfo, Graph) containing the captured assets and
|
||||
exported Graph generated from tracing the saveable_view.
|
||||
"""
|
||||
# List objects from the eager context to make sure Optimizers give us the
|
||||
# right Graph-dependent variables.
|
||||
@ -700,6 +702,28 @@ def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
proto.user_object.CopyFrom(registered_type_proto)
|
||||
|
||||
|
||||
def _export_debug_info(exported_graph):
|
||||
"""Exports debug information from a graph.
|
||||
|
||||
Args:
|
||||
exported_graph: A Graph that has been created by tracing a saveable view.
|
||||
|
||||
Returns:
|
||||
Corresponding GraphDebugInfo with traces for ops in all functions of the
|
||||
exported_graph.
|
||||
"""
|
||||
exported_operations = []
|
||||
for fn_name in exported_graph._functions: # pylint: disable=protected-access
|
||||
fn = exported_graph._get_function(fn_name) # pylint: disable=protected-access
|
||||
if not isinstance(fn, defun._EagerDefinedFunction): # pylint: disable=protected-access
|
||||
continue
|
||||
|
||||
fn_graph = fn.graph
|
||||
for fn_op in fn_graph.get_operations():
|
||||
exported_operations.append((fn_name, fn_op))
|
||||
return error_interpolation.create_graph_debug_info_def(exported_operations)
|
||||
|
||||
|
||||
@tf_export("saved_model.save",
|
||||
v1=["saved_model.save", "saved_model.experimental.save"])
|
||||
def save(obj, export_dir, signatures=None, options=None):
|
||||
@ -907,6 +931,16 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
saveable_view, asset_info.asset_index)
|
||||
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
||||
file_io.atomic_write_string_to_file(path, saved_model.SerializeToString())
|
||||
|
||||
# Save debug info, if requested.
|
||||
if options.save_debug_info:
|
||||
graph_debug_info = _export_debug_info(exported_graph)
|
||||
file_io.atomic_write_string_to_file(
|
||||
os.path.join(
|
||||
utils_impl.get_or_create_debug_dir(export_dir),
|
||||
constants.DEBUG_INFO_FILENAME_PB),
|
||||
graph_debug_info.SerializeToString())
|
||||
|
||||
# Clean reference cycles so repeated export()s don't make work for the garbage
|
||||
# collector. Before this point we need to keep references to captured
|
||||
# constants in the saved graph.
|
||||
|
@ -33,9 +33,9 @@ class SaveOptions(object):
|
||||
"""
|
||||
|
||||
# Define object attributes in __slots__ for improved memory and performance.
|
||||
__slots__ = ("namespace_whitelist",)
|
||||
__slots__ = ("namespace_whitelist", "save_debug_info")
|
||||
|
||||
def __init__(self, namespace_whitelist=None):
|
||||
def __init__(self, namespace_whitelist=None, save_debug_info=False):
|
||||
"""Creates an object that stores options for SavedModel saving.
|
||||
|
||||
Args:
|
||||
@ -43,9 +43,14 @@ class SaveOptions(object):
|
||||
when saving a model. Saving an object that uses namespaced ops must
|
||||
explicitly add all namespaces to the whitelist. The namespaced ops must
|
||||
be registered into the framework when loading the SavedModel.
|
||||
save_debug_info: Boolean indicating whether debug information is saved.
|
||||
If True, then a debug/saved_model_debug_info.pb file will be written
|
||||
with the contents of a GraphDebugInfo binary protocol buffer containing
|
||||
stack trace information for all ops and functions that are saved.
|
||||
"""
|
||||
self.namespace_whitelist = _validate_namespace_whitelist(
|
||||
namespace_whitelist)
|
||||
self.save_debug_info = save_debug_info
|
||||
|
||||
|
||||
def _validate_namespace_whitelist(namespace_whitelist):
|
||||
|
@ -24,6 +24,7 @@ import sys
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import graph_debug_info_pb2
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import backprop
|
||||
@ -48,6 +49,7 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.saved_model import loader
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
@ -415,6 +417,48 @@ class SavingOptionsTest(test.TestCase):
|
||||
save._verify_ops(graph_def, [])
|
||||
save._verify_ops(graph_def, ["Test"])
|
||||
|
||||
def test_save_debug_info_enabled(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = def_function.function(
|
||||
lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(
|
||||
root,
|
||||
save_dir,
|
||||
root.f,
|
||||
options=save_options.SaveOptions(save_debug_info=True))
|
||||
debug_info_file_name = os.path.join(save_dir, "debug",
|
||||
"saved_model_debug_info.pb")
|
||||
self.assertTrue(os.path.exists(debug_info_file_name))
|
||||
debug_info = graph_debug_info_pb2.GraphDebugInfo()
|
||||
with open(debug_info_file_name, "rb") as f:
|
||||
debug_info.ParseFromString(f.read())
|
||||
|
||||
# Verify that there is a trace for DEBUG_INFO_OP just to ensure that
|
||||
# function debug info tracing is nominally functioning.
|
||||
found_op = False
|
||||
for key in debug_info.traces.keys():
|
||||
if key.startswith("DEBUG_INFO_OP@"):
|
||||
found_op = True
|
||||
break
|
||||
self.assertTrue(found_op, "Did not find DEBUG_INFO_OP in trace")
|
||||
|
||||
def test_save_debug_info_disabled(self):
|
||||
root = tracking.AutoTrackable()
|
||||
root.f = def_function.function(
|
||||
lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
save.save(
|
||||
root,
|
||||
save_dir,
|
||||
root.f,
|
||||
options=save_options.SaveOptions(save_debug_info=False))
|
||||
debug_info_file_name = os.path.join(save_dir, "debug",
|
||||
"saved_model_debug_info.pb")
|
||||
self.assertFalse(os.path.exists(debug_info_file_name))
|
||||
|
||||
|
||||
class AssetTests(test.TestCase):
|
||||
|
||||
|
@ -244,3 +244,19 @@ def get_assets_dir(export_dir):
|
||||
return os.path.join(
|
||||
compat.as_text(export_dir),
|
||||
compat.as_text(constants.ASSETS_DIRECTORY))
|
||||
|
||||
|
||||
def get_or_create_debug_dir(export_dir):
|
||||
"""Returns path to the debug sub-directory, creating if it does not exist."""
|
||||
debug_dir = get_debug_dir(export_dir)
|
||||
|
||||
if not file_io.file_exists(debug_dir):
|
||||
file_io.recursive_create_dir(debug_dir)
|
||||
|
||||
return debug_dir
|
||||
|
||||
|
||||
def get_debug_dir(export_dir):
|
||||
"""Returns path to the debug sub-directory in the SavedModel."""
|
||||
return os.path.join(
|
||||
compat.as_text(export_dir), compat.as_text(constants.DEBUG_DIRECTORY))
|
||||
|
@ -6,8 +6,12 @@ tf_class {
|
||||
name: "namespace_whitelist"
|
||||
mtype: "<type \'member_descriptor\'>"
|
||||
}
|
||||
member {
|
||||
name: "save_debug_info"
|
||||
mtype: "<type \'member_descriptor\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'namespace_whitelist\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,14 @@ tf_module {
|
||||
name: "ASSETS_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEBUG_DIRECTORY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEBUG_INFO_FILENAME_PB"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "LEGACY_INIT_OP_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
|
@ -32,6 +32,14 @@ tf_module {
|
||||
name: "CLASSIFY_OUTPUT_SCORES"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEBUG_DIRECTORY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEBUG_INFO_FILENAME_PB"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEFAULT_SERVING_SIGNATURE_DEF_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
|
@ -6,8 +6,12 @@ tf_class {
|
||||
name: "namespace_whitelist"
|
||||
mtype: "<type \'member_descriptor\'>"
|
||||
}
|
||||
member {
|
||||
name: "save_debug_info"
|
||||
mtype: "<type \'member_descriptor\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'namespace_whitelist\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
|
||||
}
|
||||
}
|
||||
|
@ -28,6 +28,14 @@ tf_module {
|
||||
name: "CLASSIFY_OUTPUT_SCORES"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEBUG_DIRECTORY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEBUG_INFO_FILENAME_PB"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "DEFAULT_SERVING_SIGNATURE_DEF_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user