Add SaveOptions.save_debug_info and associated plumbing for SavedModel V2.

(second try: the original was rolled back due to a path separator mismatch in a unit test on Windows)

The main behavior change is that I chose to mangle the trace key to be:
  op_name "@" func_name

(global ops use func_name = '')
Originally it was a simple concatenation: func_name + op_name

I don't have a strong opinion on the specific mangling to be used, but I do believe the original concatenation is collision-prone and best corrected now. Specifically, I chose this form because:
a) func_name does not seem to have strong validation on its naming (and in practice can be quite varied in the presence of lambdas, etc.)
b) op_name does have strong validation on its syntax and specifically excludes '@' (it matches the regex [A-Za-z0-9.][A-Za-z0-9_.\-/]*)

Given these points, what I propose should be collision free. I'm not sure I should be making this decision, though. Please advise.
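For illustration, a minimal sketch of the proposed scheme (make_trace_key is a hypothetical helper for exposition only, not part of this change; the regex is the op-name pattern quoted above):

  import re

  # Op names must match this pattern, which excludes '@', so a key of the
  # form op_name + "@" + func_name cannot collide across functions.
  _OP_NAME_RE = re.compile(r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*$")

  def make_trace_key(op_name, func_name=""):
    assert _OP_NAME_RE.match(op_name), "invalid op name"
    return op_name + "@" + func_name  # global ops use func_name = ""

  make_trace_key("add", "func1")  # -> "add@func1"
  make_trace_key("Global")       # -> "Global@"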

The test coverage for all of this was pretty sparse, so I buffed it up around my modifications.

PiperOrigin-RevId: 267674315
A. Unique TensorFlower 2019-09-06 14:33:35 -07:00 committed by TensorFlower Gardener
parent 09279ed3cd
commit c974f04d09
17 changed files with 331 additions and 58 deletions

View File

@@ -1139,9 +1139,10 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
for (int i = 0, e = original_nodes.size(); i != e; ++i) {
auto node_name = original_nodes[i];
auto func_name = (i < original_funcs.size()) ? original_funcs[i] : "";
- // Use the catenation of function and node names as the lookup key. This
- // is to match the utility of generating the GraphDebugInfo.
- node_call_sites.push_back(node_name_to_call_site(func_name + node_name));
+ // Use the catenation of function and node names as the lookup key.
+ // This matches the way that the key is formed on the python side.
+ std::string key = node_name + "@" + func_name;
+ node_call_sites.push_back(node_name_to_call_site(key));
}
return mlir::FusedLoc::get(node_call_sites, context_);
}

View File

@@ -1,6 +1,7 @@
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphDebugInfoProtos";
option java_multiple_files = true;
@@ -37,5 +38,14 @@ message GraphDebugInfo {
repeated string files = 1;
// This maps a node name to a stack trace in the source code.
// The map key is a mangling of the containing function and op name with
// syntax:
// op.name '@' func_name
// For ops in the top-level graph, the func_name is the empty string.
// Note that op names are restricted to a limited character set which
// excludes '@', making it impossible for keys of this form to collide.
// Function names accept a much wider set of characters.
// It would be preferable to avoid mangling and use a tuple key of (op.name,
// func_name), but this is not supported with protocol buffers.
map<string, StackTrace> traces = 2;
}
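Because '@' cannot appear in an op name, a consumer can unambiguously recover the parts by splitting a key on its first '@'; a minimal sketch (parse_trace_key is hypothetical, for exposition only):

  def parse_trace_key(key):
    # Splitting on the first '@' is safe because op names exclude '@';
    # everything after it (possibly empty, for global ops) is the func name.
    op_name, func_name = key.split("@", 1)
    return op_name, func_name

  parse_trace_key("One@func1")  # -> ("One", "func1")
  parse_trace_key("Global@")    # -> ("Global", "")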

View File

@@ -1077,7 +1077,7 @@ class FromSessionTest(TestModels, parameterized.TestCase):
# Check the add node in the inlined function is included.
func = sess.graph.as_graph_def().library.function[0].signature.name
- self.assertIn((func + 'add'), converter._debug_info.traces)
+ self.assertIn(('add@' + func), converter._debug_info.traces)
class FromFrozenGraphFile(test_util.TensorFlowTestCase):

View File

@@ -38,12 +38,31 @@ _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
- _BAD_FILE_SUBSTRINGS = [
- os.path.join("tensorflow", "python"),
- os.path.join("tensorflow", "contrib"),
- os.path.join("tensorflow_estimator", "python"),
- os.path.join("tensorflow_estimator", "contrib"),
- "<embedded",
- ]
# Remove the last three path components from this module's file (i.e.
# python/framework/error_interpolation.py) so that we have an absolute path
# prefix to the root of the installation.
_FRAMEWORK_COMMON_PREFIX = os.path.dirname(
os.path.dirname(os.path.dirname(__file__)))
# Sub-directories under the common prefix that are considered part of the
# framework.
_FRAMEWORK_PATH_PREFIXES = [
os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
]
# Filename patterns that should be considered internal to the
# TensorFlow framework.
_FRAMEWORK_FILENAME_PATTERNS = [
re.compile(r"<embedded"),
]
# Filename patterns that should be considered external to TensorFlow
# regardless of framework prefix match.
_EXTERNAL_FILENAME_PATTERNS = [
# Explicitly treat test frames as not part of the framework.
re.compile(r"_test\.py$"),
]
@@ -178,13 +197,39 @@ def _compute_colocation_summary_from_op(op, prefix=""):
# pylint: enable=protected-access
def _is_framework_filename(filename):
"""Returns whether a filename should be considered a part of the framework.
A file is part of the framework if it does not match a pattern in
_EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in
_FRAMEWORK_FILENAME_PATTERNS or starts with one of the
_FRAMEWORK_PATH_PREFIXES.
Args:
filename: A filename string.
Returns:
Whether the filename should be considered to be internal to the
TensorFlow framework for the purposes of reporting errors.
"""
for pattern in _EXTERNAL_FILENAME_PATTERNS:
if pattern.search(filename):
return False
for pattern in _FRAMEWORK_FILENAME_PATTERNS:
if pattern.search(filename):
return True
for prefix in _FRAMEWORK_PATH_PREFIXES:
if filename.startswith(prefix):
return True
return False
def _find_index_of_defining_frame_for_op(op):
"""Return index in op.traceback with first 'useful' frame.
This method reads through the stack stored in op.traceback looking for the
innermost frame which (hopefully) belongs to the caller. It accomplishes this
- by rejecting frames whose filename appears to come from TensorFlow (see
- error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings).
+ by rejecting frames deemed to be part of the TensorFlow framework (by
+ pattern matching the filename).
Args:
op: the Operation object for which we would like to find the defining
@@ -201,8 +246,9 @@ def _find_index_of_defining_frame_for_op(op):
filenames = [frame.filename for frame in tf_traceback]
# We process the filenames from the innermost frame to outermost.
for idx, filename in enumerate(reversed(filenames)):
- contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS]
- if not any(contains_bad_substrings):
+ is_framework = _is_framework_filename(filename)
+ if not is_framework:
# Consider this to be the defining frame.
return size - idx - 1
return 0
@@ -237,11 +283,13 @@ def _compute_useful_frames(op, num):
return op.traceback[outermost_included:innermost_excluded]
- def create_graph_debug_info_def(operations):
+ def create_graph_debug_info_def(func_named_operations):
"""Construct and returns a `GraphDebugInfo` protocol buffer.
Args:
- operations: An iterable of op.Operation objects having _traceback members.
+ func_named_operations: An iterable of (func_name, op.Operation) tuples
+ where the Operation instances have a _traceback member. The func_name
+ should be the empty string for operations in the top-level Graph.
Returns:
GraphDebugInfo protocol buffer.
@@ -256,9 +304,9 @@ def create_graph_debug_info_def(operations):
# collects the unique file names.
all_file_names = set()
node_to_trace = {}
- for func, op in operations:
+ for func_name, op in func_named_operations:
# Gets the stack trace of the operation and then the file location.
- node_name = func + op.name
+ node_name = op.name + "@" + func_name
node_to_trace[node_name] = _compute_useful_frames(op, 10)
for frame in node_to_trace[node_name]:
all_file_names.add(frame.filename)
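To make the new calling convention concrete, a usage sketch (assumes global_op and inner_op are tf.Operation objects already in hand, inner_op taken from a function graph named "my_func", and error_interpolation imported from tensorflow.python.framework):

  func_named_operations = [("", global_op), ("my_func", inner_op)]
  debug_info = error_interpolation.create_graph_debug_info_def(
      func_named_operations)
  # debug_info.traces is keyed by "Global@" and "<inner_op.name>@my_func";
  # each trace's file_line_cols entries index into debug_info.files.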

View File

@@ -48,7 +48,7 @@ def _make_frame_with_filename(op, idx, filename):
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
num_inner_tf_frames):
"""Replace op._traceback with a new traceback using special filenames."""
tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0]
tf_filename = error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "%d.py"
user_filename = os.path.join("%d", "my_favorite_file.py")
num_requested_frames = num_user_frames + num_inner_tf_frames
@@ -122,10 +122,70 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
self.assertIn("No node-device colocations", summary)
# Note that create_graph_debug_info_def needs to run on graph-mode ops, so
# it is excluded from eager tests. Even when used in eager mode, it
# operates via FunctionGraphs, and verifying directly in graph mode is the
# narrowest way to unit test the functionality.
@test_util.run_deprecated_v1
class CreateGraphDebugInfoDefTest(test.TestCase):
def setUp(self):
super(CreateGraphDebugInfoDefTest, self).setUp()
ops.reset_default_graph()
def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
self.assertIn(key, graph_debug_info.traces)
stack_trace = graph_debug_info.traces[key]
found_flc = None
for flc in stack_trace.file_line_cols:
if flc.file_index == file_index:
found_flc = flc
break
self.assertIsNotNone(found_flc,
"Could not find a stack trace entry for file")
return found_flc
def testStackTraceExtraction(self):
# Since the create_graph_debug_info_def() function does not do anything
# special with functions except name mangling, we just verify it with
# loose ops and manually provided function names.
# The following ops *must* be on consecutive lines (it will be verified
# in the resulting trace).
# pyformat: disable
global_op = constant_op.constant(0, name="Global").op
op1 = constant_op.constant(1, name="One").op
op2 = constant_op.constant(2, name="Two").op
# pyformat: enable
export_ops = [("", global_op), ("func1", op1), ("func2", op2)]
graph_debug_info = error_interpolation.create_graph_debug_info_def(
export_ops)
this_file_index = -1
for file_index, file_name in enumerate(graph_debug_info.files):
if "{}error_interpolation_test.py".format(os.sep) in file_name:
this_file_index = file_index
self.assertGreaterEqual(
this_file_index, 0,
"Could not find this file in trace:" + repr(graph_debug_info))
# Verify the traces exist for each op.
global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
this_file_index)
op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
this_file_index)
op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
this_file_index)
global_line = global_flc.line
self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")
@test_util.run_deprecated_v1
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def setUp(self):
super(InterpolateFilenamesAndLineNumbersTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
constant_op.constant(1, name="One")
@@ -133,17 +193,6 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
three = constant_op.constant(3, name="Three")
self.graph = three.graph
- # Change the list of bad file substrings so that constant_op.py is chosen
- # as the defining stack frame for constant_op.constant ops.
- self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
- error_interpolation._BAD_FILE_SUBSTRINGS = [
- "%sops.py" % os.sep,
- "%sutil" % os.sep,
- ]
- def tearDown(self):
- error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings
def testFindIndexOfDefiningFrameForOp(self):
local_op = constant_op.constant(42).op
user_filename = "hope.py"
@@ -187,58 +236,50 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
two_tags_no_seps = "{{node One}}{{node Three}}"
interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, self.graph)
- self.assertRegexpMatches(interpolated_string,
- "constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
+ self.assertRegexpMatches(
+ interpolated_string, r"error_interpolation_test\.py:[0-9]+."
+ r"*error_interpolation_test\.py:[0-9]+")
def testTwoTagsWithSeps(self):
two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
- expected_regex = (
- r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
+ expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
+ r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegexpMatches(interpolated_string, expected_regex)
def testNewLine(self):
newline = "\n\n{{node One}}"
interpolated_string = error_interpolation.interpolate(newline, self.graph)
- self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
+ self.assertRegexpMatches(interpolated_string,
+ r"error_interpolation_test\.py:[0-9]+.*")
@test_util.run_deprecated_v1
class InputNodesTest(test.TestCase):
def setUp(self):
super(InputNodesTest, self).setUp()
# Add nodes to the graph for retrieval by name later.
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
three = math_ops.add(one, two, name="Three")
self.graph = three.graph
- # Change the list of bad file substrings so that constant_op.py is chosen
- # as the defining stack frame for constant_op.constant ops.
- self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
- error_interpolation._BAD_FILE_SUBSTRINGS = [
- "%sops.py" % os.sep,
- "%sutil" % os.sep,
- ]
- def tearDown(self):
- error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings
def testNoInputs(self):
two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
- expected_regex = (
- r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
+ expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
+ r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegexpMatches(interpolated_string, expected_regex)
def testBasicInputs(self):
tag = ";;;{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(tag, self.graph)
expected_regex = re.compile(
r"^;;;.*op_def_library.py:[0-9]+\) ;;;.*Input.*constant_op.py:[0-9]+\)",
re.DOTALL)
r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
self.assertRegexpMatches(interpolated_string, expected_regex)
@@ -249,6 +290,7 @@ class InterpolateDeviceSummaryTest(test.TestCase):
return "/cpu:*"
def setUp(self):
super(InterpolateDeviceSummaryTest, self).setUp()
ops.reset_default_graph()
self.zero = constant_op.constant([0.0], name="zero")
with ops.device("/cpu"):
@@ -290,6 +332,7 @@ class InterpolateDeviceSummaryTest(test.TestCase):
class InterpolateColocationSummaryTest(test.TestCase):
def setUp(self):
super(InterpolateColocationSummaryTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
node_one = constant_op.constant(1, name="One")
@@ -337,5 +380,23 @@ class InterpolateColocationSummaryTest(test.TestCase):
self.assertNotIn("Two", result)
class IsFrameworkFilenameTest(test.TestCase):
def testAllowsUnitTests(self):
self.assertFalse(
error_interpolation._is_framework_filename(
error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "foobar_test.py"))
def testFrameworkPythonFile(self):
self.assertTrue(
error_interpolation._is_framework_filename(
error_interpolation.__file__))
def testEmbedded(self):
self.assertTrue(
error_interpolation._is_framework_filename(
"<embedded stdlib>/context_lib.py"))
if __name__ == "__main__":
test.main()

View File

@@ -741,11 +741,11 @@ class ScopedMetaGraphTest(test.TestCase):
biases1 = resource_variable_ops.ResourceVariable(
[0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
- operations = []
+ func_named_operations = []
for op in graph1.get_operations():
- operations.append(("", op))
+ func_named_operations.append(("", op))
debug_info_def = error_interpolation.create_graph_debug_info_def(
- operations=operations)
+ func_named_operations)
# The number of unique file names across all the stack traces should be
# at least 1.

View File

@@ -323,9 +323,12 @@ tf_py_test(
additional_deps = [
":loader",
":save",
":save_options",
":signature_constants",
":tag_constants",
"@absl_py//absl/testing:parameterized",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:error_interpolation",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
],

View File

@@ -86,6 +86,25 @@ tf_export(
"saved_model.constants.SAVED_MODEL_FILENAME_PBTXT"
]).export_constant(__name__, "SAVED_MODEL_FILENAME_PBTXT")
# Subdirectory where debugging related files are written.
DEBUG_DIRECTORY = "debug"
tf_export(
"saved_model.DEBUG_DIRECTORY",
v1=[
"saved_model.DEBUG_DIRECTORY",
"saved_model.constants.DEBUG_DIRECTORY",
]).export_constant(__name__, "DEBUG_DIRECTORY")
# File name for the GraphDebugInfo protocol buffer that corresponds to the
# SavedModel.
DEBUG_INFO_FILENAME_PB = "saved_model_debug_info.pb"
tf_export(
"saved_model.DEBUG_INFO_FILENAME_PB",
v1=[
"saved_model.DEBUG_INFO_FILENAME_PB",
"saved_model.constants.DEBUG_INFO_FILENAME_PB"
]).export_constant(__name__, "DEBUG_INFO_FILENAME_PB")
# File name for json format of SavedModel.
# Not exported while keras_saved_model is in contrib.
SAVED_MODEL_FILENAME_JSON = "saved_model.json"

View File

@@ -31,6 +31,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -546,7 +547,8 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
namespace_whitelist: List of strings containing whitelisted op namespaces.
Returns:
- An _AssetInfo, which contains information to help creating the SavedModel.
+ A tuple of (_AssetInfo, Graph) containing the captured assets and
+ exported Graph generated from tracing the saveable_view.
"""
# List objects from the eager context to make sure Optimizers give us the
# right Graph-dependent variables.
@@ -700,6 +702,28 @@ def _write_object_proto(obj, proto, asset_file_def_index):
proto.user_object.CopyFrom(registered_type_proto)
def _export_debug_info(exported_graph):
"""Exports debug information from a graph.
Args:
exported_graph: A Graph that has been created by tracing a saveable view.
Returns:
Corresponding GraphDebugInfo with traces for ops in all functions of the
exported_graph.
"""
exported_operations = []
for fn_name in exported_graph._functions: # pylint: disable=protected-access
fn = exported_graph._get_function(fn_name) # pylint: disable=protected-access
if not isinstance(fn, defun._EagerDefinedFunction): # pylint: disable=protected-access
continue
fn_graph = fn.graph
for fn_op in fn_graph.get_operations():
exported_operations.append((fn_name, fn_op))
return error_interpolation.create_graph_debug_info_def(exported_operations)
@tf_export("saved_model.save",
v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None, options=None):
@@ -907,6 +931,16 @@ def save(obj, export_dir, signatures=None, options=None):
saveable_view, asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
file_io.atomic_write_string_to_file(path, saved_model.SerializeToString())
# Save debug info, if requested.
if options.save_debug_info:
graph_debug_info = _export_debug_info(exported_graph)
file_io.atomic_write_string_to_file(
os.path.join(
utils_impl.get_or_create_debug_dir(export_dir),
constants.DEBUG_INFO_FILENAME_PB),
graph_debug_info.SerializeToString())
# Clean reference cycles so repeated export()s don't make work for the garbage
# collector. Before this point we need to keep references to captured
# constants in the saved graph.
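With save_debug_info=True, the export directory gains a debug/ subdirectory alongside the usual SavedModel contents, roughly as below (assets/ and variables/ appear only when the saved object uses them):

  export_dir/
    saved_model.pb
    variables/
    assets/
    debug/
      saved_model_debug_info.pb   # GraphDebugInfo binary proto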

View File

@@ -33,9 +33,9 @@ class SaveOptions(object):
"""
# Define object attributes in __slots__ for improved memory and performance.
__slots__ = ("namespace_whitelist",)
__slots__ = ("namespace_whitelist", "save_debug_info")
def __init__(self, namespace_whitelist=None):
def __init__(self, namespace_whitelist=None, save_debug_info=False):
"""Creates an object that stores options for SavedModel saving.
Args:
@@ -43,9 +43,14 @@
when saving a model. Saving an object that uses namespaced ops must
explicitly add all namespaces to the whitelist. The namespaced ops must
be registered into the framework when loading the SavedModel.
save_debug_info: Boolean indicating whether debug information is saved.
If True, then a debug/saved_model_debug_info.pb file will be written
with the contents of a GraphDebugInfo binary protocol buffer containing
stack trace information for all ops and functions that are saved.
"""
self.namespace_whitelist = _validate_namespace_whitelist(
namespace_whitelist)
self.save_debug_info = save_debug_info
def _validate_namespace_whitelist(namespace_whitelist):
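End to end, the option is threaded through tf.saved_model.save; a minimal usage sketch (model is assumed to be any trackable object, e.g. a tf.Module):

  import tensorflow as tf

  options = tf.saved_model.SaveOptions(save_debug_info=True)
  tf.saved_model.save(model, "/tmp/exported_model", options=options)
  # In addition to the usual SavedModel files, this writes
  # /tmp/exported_model/debug/saved_model_debug_info.pb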

View File

@@ -24,6 +24,7 @@ import sys
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
@@ -48,6 +49,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training.tracking import tracking
@@ -415,6 +417,48 @@ class SavingOptionsTest(test.TestCase):
save._verify_ops(graph_def, [])
save._verify_ops(graph_def, ["Test"])
def test_save_debug_info_enabled(self):
root = tracking.AutoTrackable()
root.f = def_function.function(
lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(
root,
save_dir,
root.f,
options=save_options.SaveOptions(save_debug_info=True))
debug_info_file_name = os.path.join(save_dir, "debug",
"saved_model_debug_info.pb")
self.assertTrue(os.path.exists(debug_info_file_name))
debug_info = graph_debug_info_pb2.GraphDebugInfo()
with open(debug_info_file_name, "rb") as f:
debug_info.ParseFromString(f.read())
# Verify that there is a trace for DEBUG_INFO_OP just to ensure that
# function debug info tracing is nominally functioning.
found_op = False
for key in debug_info.traces.keys():
if key.startswith("DEBUG_INFO_OP@"):
found_op = True
break
self.assertTrue(found_op, "Did not find DEBUG_INFO_OP in trace")
def test_save_debug_info_disabled(self):
root = tracking.AutoTrackable()
root.f = def_function.function(
lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
save_dir = os.path.join(self.get_temp_dir(), "saved_model")
save.save(
root,
save_dir,
root.f,
options=save_options.SaveOptions(save_debug_info=False))
debug_info_file_name = os.path.join(save_dir, "debug",
"saved_model_debug_info.pb")
self.assertFalse(os.path.exists(debug_info_file_name))
class AssetTests(test.TestCase):

View File

@@ -244,3 +244,19 @@ def get_assets_dir(export_dir):
return os.path.join(
compat.as_text(export_dir),
compat.as_text(constants.ASSETS_DIRECTORY))
def get_or_create_debug_dir(export_dir):
"""Returns path to the debug sub-directory, creating if it does not exist."""
debug_dir = get_debug_dir(export_dir)
if not file_io.file_exists(debug_dir):
file_io.recursive_create_dir(debug_dir)
return debug_dir
def get_debug_dir(export_dir):
"""Returns path to the debug sub-directory in the SavedModel."""
return os.path.join(
compat.as_text(export_dir), compat.as_text(constants.DEBUG_DIRECTORY))
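A sketch of reading the file back with the helpers added here (export_dir is assumed to point at a SavedModel written with save_debug_info=True):

  import os

  from tensorflow.core.protobuf import graph_debug_info_pb2
  from tensorflow.python.saved_model import constants
  from tensorflow.python.saved_model import utils_impl

  debug_path = os.path.join(utils_impl.get_debug_dir(export_dir),
                            constants.DEBUG_INFO_FILENAME_PB)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  with open(debug_path, "rb") as f:
    debug_info.ParseFromString(f.read())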

View File

@@ -6,8 +6,12 @@ tf_class {
name: "namespace_whitelist"
mtype: "<type \'member_descriptor\'>"
}
member {
name: "save_debug_info"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'namespace_whitelist\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
}

View File

@@ -8,6 +8,14 @@ tf_module {
name: "ASSETS_KEY"
mtype: "<type \'str\'>"
}
member {
name: "DEBUG_DIRECTORY"
mtype: "<type \'str\'>"
}
member {
name: "DEBUG_INFO_FILENAME_PB"
mtype: "<type \'str\'>"
}
member {
name: "LEGACY_INIT_OP_KEY"
mtype: "<type \'str\'>"

View File

@@ -32,6 +32,14 @@ tf_module {
name: "CLASSIFY_OUTPUT_SCORES"
mtype: "<type \'str\'>"
}
member {
name: "DEBUG_DIRECTORY"
mtype: "<type \'str\'>"
}
member {
name: "DEBUG_INFO_FILENAME_PB"
mtype: "<type \'str\'>"
}
member {
name: "DEFAULT_SERVING_SIGNATURE_DEF_KEY"
mtype: "<type \'str\'>"

View File

@@ -6,8 +6,12 @@ tf_class {
name: "namespace_whitelist"
mtype: "<type \'member_descriptor\'>"
}
member {
name: "save_debug_info"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'namespace_whitelist\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\', \'namespace_whitelist\', \'save_debug_info\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
}

View File

@@ -28,6 +28,14 @@ tf_module {
name: "CLASSIFY_OUTPUT_SCORES"
mtype: "<type \'str\'>"
}
member {
name: "DEBUG_DIRECTORY"
mtype: "<type \'str\'>"
}
member {
name: "DEBUG_INFO_FILENAME_PB"
mtype: "<type \'str\'>"
}
member {
name: "DEFAULT_SERVING_SIGNATURE_DEF_KEY"
mtype: "<type \'str\'>"