Collect node debug information for frozen graphs

This CL added the debug information support for the nodes in the frozen graphs
which are GraphDefs and will be sent to the new tf-tflite converter. A GraphDef
only serializes the node name from the original Graph object, but the whole
stack track defining the node will miss. So to collect the stack trace (debug
information) for the nodes in the GraphDef, a few changes made in this CL:

- For TFLiteConverter (v1), an experimental function, which create Graph Debug
  info from the original graph object, is passed to the converter constructor
  in addition to the GraphDef, so we can retrive the stack trace for the nodes
  from the GraphDef. (TFLiteConverterV2 isn't an issue because function object
  has passed to the constructor.)

- Propagate the original node name in the Grappler function inlining pass, so
  the original node name is stored in the GraphDef when a node is inlined. And
  we can use the stored name to look up the stack trace in the original graph.

- When a node name is looked up in the original graph, We need to consider the
  function library as well. For function libraries created by `@tf.function`
  and `@defun`, we use the sub-graphs in the original graph. However, function
  created by `@Defun` only has FunctionDef for the sub-graphs, so it isn't
  supported by this CL.

PiperOrigin-RevId: 253932770
This commit is contained in:
Feng Liu 2019-06-18 22:24:02 -07:00 committed by TensorFlower Gardener
parent 3c1287bffa
commit 2c171cdb26
15 changed files with 378 additions and 91 deletions

View File

@ -1494,6 +1494,27 @@ Status ValidateNoInline(const FunctionBody* fbody) {
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
// Propagate the debug info of `nodes` in function `func` to the `target` node.
// If the debug info of any node is missing, its node name and function name
// is used.
void PropagateDebugInfoToNode(const string& func,
const std::vector<const Node*>& nodes,
NodeDef* target) {
if (nodes.empty() || target->has_experimental_debug_info()) {
return;
}
for (const Node* node : nodes) {
const auto& node_def = node->def();
if (node_def.has_experimental_debug_info()) {
target->mutable_experimental_debug_info()->MergeFrom(
node_def.experimental_debug_info());
} else {
target->mutable_experimental_debug_info()->add_original_node_names(
node_def.name());
target->mutable_experimental_debug_info()->add_original_func_names(func);
}
}
}
} // namespace } // namespace
string InlineFunctionBodyOptions::DebugString() const { string InlineFunctionBodyOptions::DebugString() const {
@ -1719,6 +1740,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
if (options.initialize_empty_device && ndef.device().empty()) { if (options.initialize_empty_device && ndef.device().empty()) {
ndef.set_device(caller->def().device()); ndef.set_device(caller->def().device());
} }
PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);
// Add the function node name as a prefix: // Add the function node name as a prefix:
// 1) to node name to avoid collisions // 1) to node name to avoid collisions

View File

@ -70,6 +70,15 @@ message NodeDef {
// be {A, B}. This information can be used to map errors originating at the // be {A, B}. This information can be used to map errors originating at the
// current node to some top level source code. // current node to some top level source code.
repeated string original_node_names = 1; repeated string original_node_names = 1;
// This is intended to store the list of names of the functions from the
// original graph that this node was derived. For example if this node, say
// C, was result of a fusion of node A in function FA and node B in function
// FB, then `original_funcs` would be {FA, FB}. If the node is in the top
// level graph, the `original_func` is empty. This information, with the
// `original_node_names` can be used to map errors originating at the
// current ndoe to some top level source code.
repeated string original_func_names = 2;
}; };
// This stores debug information associated with the node. // This stores debug information associated with the node.

View File

@ -197,7 +197,8 @@ def build_toco_convert_protos(input_tensors,
dump_graphviz_dir=None, dump_graphviz_dir=None,
dump_graphviz_video=False, dump_graphviz_video=False,
target_ops=None, target_ops=None,
allow_nonexistent_arrays=False): allow_nonexistent_arrays=False,
debug_info=None):
"""Builds protocol buffers describing a conversion of a model using TOCO. """Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@ -257,10 +258,12 @@ def build_toco_convert_protos(input_tensors,
(default set([OpsSet.TFLITE_BUILTINS])) (default set([OpsSet.TFLITE_BUILTINS]))
allow_nonexistent_arrays: Allow specifying array names that don't exist allow_nonexistent_arrays: Allow specifying array names that don't exist
or are unused in the final graph. (default False) or are unused in the final graph. (default False)
debug_info: `GraphDebugInfo` proto containing the stack traces for the
original nodes referred by the converted graph.
Returns: Returns:
model_flags, toco_flags: two protocol buffers describing the conversion model_flags, toco_flags, debug_info: three protocol buffers describing the
process. conversion process and debug information.
Raises: Raises:
ValueError: ValueError:
@ -319,7 +322,7 @@ def build_toco_convert_protos(input_tensors,
model.allow_nonexistent_arrays = allow_nonexistent_arrays model.allow_nonexistent_arrays = allow_nonexistent_arrays
return model, toco return model, toco, debug_info
def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays, def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
@ -350,7 +353,7 @@ def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
Raises: Raises:
Defined in `build_toco_convert_protos`. Defined in `build_toco_convert_protos`.
""" """
model_flags, toco_flags = build_toco_convert_protos( model_flags, toco_flags, _ = build_toco_convert_protos(
input_tensors=[], output_tensors=[], *args, **kwargs) input_tensors=[], output_tensors=[], *args, **kwargs)
for idx, (name, shape) in enumerate(input_arrays_with_shape): for idx, (name, shape) in enumerate(input_arrays_with_shape):
@ -397,7 +400,7 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
Raises: Raises:
Defined in `build_toco_convert_protos`. Defined in `build_toco_convert_protos`.
""" """
model_flags, toco_flags = build_toco_convert_protos( model_flags, toco_flags, _ = build_toco_convert_protos(
input_tensors, output_tensors, *args, **kwargs) input_tensors, output_tensors, *args, **kwargs)
data = toco_convert_protos(model_flags.SerializeToString(), data = toco_convert_protos(model_flags.SerializeToString(),
toco_flags.SerializeToString(), toco_flags.SerializeToString(),

View File

@ -173,6 +173,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
frozen_graph_def: Frozen GraphDef. frozen_graph_def: Frozen GraphDef.
in_tensors: List of input tensors for the graph. in_tensors: List of input tensors for the graph.
out_tensors: List of output tensors for the graph. out_tensors: List of output tensors for the graph.
graph: `Graph` object.
Raises: Raises:
ValueError: ValueError:
@ -203,4 +204,4 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
util.set_tensor_shapes(in_tensors, input_shapes) util.set_tensor_shapes(in_tensors, input_shapes)
frozen_graph_def = util.freeze_graph(sess, in_tensors, out_tensors) frozen_graph_def = util.freeze_graph(sess, in_tensors, out_tensors)
return frozen_graph_def, in_tensors, out_tensors return frozen_graph_def, in_tensors, out_tensors, sess.graph

View File

@ -90,13 +90,14 @@ class FreezeSavedModelTest(test_util.TensorFlowTestCase):
tag_set = set([tag_constants.SERVING]) tag_set = set([tag_constants.SERVING])
if signature_key is None: if signature_key is None:
signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( graph_def, in_tensors, out_tensors, _ = (
saved_model_dir=saved_model_dir, convert_saved_model.freeze_saved_model(
input_arrays=input_arrays, saved_model_dir=saved_model_dir,
input_shapes=input_shapes, input_arrays=input_arrays,
output_arrays=output_arrays, input_shapes=input_shapes,
tag_set=tag_set, output_arrays=output_arrays,
signature_key=signature_key) tag_set=tag_set,
signature_key=signature_key))
return graph_def, in_tensors, out_tensors return graph_def, in_tensors, out_tensors
def testSimpleSavedModel(self): def testSimpleSavedModel(self):

View File

@ -43,7 +43,9 @@ from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.lite.python.optimize import calibrator as _calibrator from tensorflow.lite.python.optimize import calibrator as _calibrator
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import freeze_graph as _freeze_graph from tensorflow.lite.python.util import freeze_graph as _freeze_graph
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config from tensorflow.lite.python.util import get_grappler_config as _get_grappler_config
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
@ -253,6 +255,7 @@ class TFLiteConverterV2(TFLiteConverterBase):
self._trackable_obj = trackable_obj self._trackable_obj = trackable_obj
self.allow_custom_ops = False self.allow_custom_ops = False
self.target_spec = TargetSpec() self.target_spec = TargetSpec()
self._debug_info = None
@classmethod @classmethod
def from_concrete_functions(cls, funcs): def from_concrete_functions(cls, funcs):
@ -377,12 +380,15 @@ class TFLiteConverterV2(TFLiteConverterBase):
tensor.set_shape(shape) tensor.set_shape(shape)
self._validate_representative_dataset() self._validate_representative_dataset()
self._debug_info = _get_debug_info(
_build_debug_info_func(self._funcs[0].graph), graph_def)
converter_kwargs = { converter_kwargs = {
"input_format": constants.TENSORFLOW_GRAPHDEF, "input_format": constants.TENSORFLOW_GRAPHDEF,
"allow_custom_ops": self.allow_custom_ops, "allow_custom_ops": self.allow_custom_ops,
"post_training_quantize": self._is_weight_only_quantize(), "post_training_quantize": self._is_weight_only_quantize(),
"target_ops": self.target_spec.supported_ops, "target_ops": self.target_spec.supported_ops,
"debug_info": self._debug_info
} }
# Converts model. # Converts model.
@ -507,7 +513,8 @@ class TFLiteConverter(TFLiteConverterBase):
input_tensors, input_tensors,
output_tensors, output_tensors,
input_arrays_with_shape=None, input_arrays_with_shape=None,
output_arrays=None): output_arrays=None,
experimental_debug_info_func=None):
"""Constructor for TFLiteConverter. """Constructor for TFLiteConverter.
Args: Args:
@ -523,6 +530,8 @@ class TFLiteConverter(TFLiteConverterBase):
output_arrays: List of output tensors to freeze graph with. Use only when output_arrays: List of output tensors to freeze graph with. Use only when
graph cannot be loaded into TensorFlow and when `input_tensors` and graph cannot be loaded into TensorFlow and when `input_tensors` and
`output_tensors` are None. (default None) `output_tensors` are None. (default None)
experimental_debug_info_func: An experimental function to retrieve the
graph debug info for a set of nodes from the `graph_def`.
Raises: Raises:
ValueError: Invalid arguments. ValueError: Invalid arguments.
@ -545,6 +554,8 @@ class TFLiteConverter(TFLiteConverterBase):
self.dump_graphviz_dir = None self.dump_graphviz_dir = None
self.dump_graphviz_video = False self.dump_graphviz_video = False
self.target_spec = TargetSpec() self.target_spec = TargetSpec()
self._debug_info_func = experimental_debug_info_func
self._debug_info = None
# Attributes are used by models that cannot be loaded into TensorFlow. # Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors(): if not self._has_valid_tensors():
@ -569,7 +580,11 @@ class TFLiteConverter(TFLiteConverterBase):
TFLiteConverter class. TFLiteConverter class.
""" """
graph_def = _freeze_graph(sess, input_tensors, output_tensors) graph_def = _freeze_graph(sess, input_tensors, output_tensors)
return cls(graph_def, input_tensors, output_tensors) return cls(
graph_def,
input_tensors,
output_tensors,
experimental_debug_info_func=_build_debug_info_func(sess.graph))
@classmethod @classmethod
def from_frozen_graph(cls, def from_frozen_graph(cls,
@ -700,7 +715,10 @@ class TFLiteConverter(TFLiteConverterBase):
result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key) output_arrays, tag_set, signature_key)
return cls( return cls(
graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) graph_def=result[0],
input_tensors=result[1],
output_tensors=result[2],
experimental_debug_info_func=_build_debug_info_func(result[3]))
@classmethod @classmethod
def from_keras_model_file(cls, def from_keras_model_file(cls,
@ -743,8 +761,12 @@ class TFLiteConverter(TFLiteConverterBase):
frozen_func = _convert_to_constants.convert_variables_to_constants_v2( frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
concrete_func) concrete_func)
_set_tensor_shapes(frozen_func.inputs, input_shapes) _set_tensor_shapes(frozen_func.inputs, input_shapes)
return cls(frozen_func.graph.as_graph_def(), frozen_func.inputs, return cls(
frozen_func.outputs) frozen_func.graph.as_graph_def(),
frozen_func.inputs,
frozen_func.outputs,
experimental_debug_info_func=_build_debug_info_func(
frozen_func.graph))
# Handles Keras when Eager mode is disabled. # Handles Keras when Eager mode is disabled.
_keras.backend.clear_session() _keras.backend.clear_session()
@ -765,7 +787,11 @@ class TFLiteConverter(TFLiteConverterBase):
_set_tensor_shapes(input_tensors, input_shapes) _set_tensor_shapes(input_tensors, input_shapes)
graph_def = _freeze_graph(sess, input_tensors, output_tensors) graph_def = _freeze_graph(sess, input_tensors, output_tensors)
return cls(graph_def, input_tensors, output_tensors) return cls(
graph_def,
input_tensors,
output_tensors,
experimental_debug_info_func=_build_debug_info_func(sess.graph))
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name == "post_training_quantize": if name == "post_training_quantize":
@ -904,12 +930,15 @@ class TFLiteConverter(TFLiteConverterBase):
except Exception: except Exception:
optimized_graph = self._graph_def optimized_graph = self._graph_def
self._debug_info = _get_debug_info(self._debug_info_func, optimized_graph)
# Converts model. # Converts model.
if self._has_valid_tensors(): if self._has_valid_tensors():
result = _toco_convert_impl( result = _toco_convert_impl(
input_data=optimized_graph, input_data=optimized_graph,
input_tensors=self._input_tensors, input_tensors=self._input_tensors,
output_tensors=self._output_tensors, output_tensors=self._output_tensors,
debug_info=self._debug_info,
**converter_kwargs) **converter_kwargs)
else: else:
result = _toco_convert_graph_def( result = _toco_convert_graph_def(

View File

@ -49,7 +49,20 @@ from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph from tensorflow.python.training.training_util import write_graph
class FromConstructor(test_util.TensorFlowTestCase): class TestModels(test_util.TensorFlowTestCase):
def assertValidDebugInfo(self, debug_info):
"""Verify the DebugInfo is valid."""
file_names = set()
for file_path in debug_info.files:
file_names.add(os.path.basename(file_path))
# To make the test independent on how the nodes are created, we only assert
# the name of this test file.
self.assertIn('lite_test.py', file_names)
self.assertNotIn('lite_v2_test.py', file_names)
class FromConstructor(TestModels):
# Tests invalid constructors using a dummy value for the GraphDef. # Tests invalid constructors using a dummy value for the GraphDef.
def testInvalidConstructor(self): def testInvalidConstructor(self):
@ -89,7 +102,7 @@ class FromConstructor(test_util.TensorFlowTestCase):
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class FromSessionTest(test_util.TensorFlowTestCase): class FromSessionTest(TestModels):
def testFloat(self): def testFloat(self):
in_tensor = array_ops.placeholder( in_tensor = array_ops.placeholder(
@ -160,8 +173,9 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session() sess = session.Session()
# Convert model and ensure model is not None. # Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session( converter = lite.TFLiteConverter.from_session(sess,
sess, [in_tensor_1, in_tensor_2], [out_tensor]) [in_tensor_1, in_tensor_2],
[out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8 converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = { converter.quantized_input_stats = {
'inputA': (0., 1.), 'inputA': (0., 1.),
@ -205,8 +219,9 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session() sess = session.Session()
# Convert model and ensure model is not None. # Convert model and ensure model is not None.
converter = lite.TFLiteConverter.from_session( converter = lite.TFLiteConverter.from_session(sess,
sess, [in_tensor_1, in_tensor_2], [out_tensor]) [in_tensor_1, in_tensor_2],
[out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8 converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
@ -851,6 +866,33 @@ class FromSessionTest(test_util.TensorFlowTestCase):
np.array([[2, 2], [2, 2]], dtype=np.int32)) np.array([[2, 2], [2, 2]], dtype=np.int32))
interpreter.invoke() interpreter.invoke()
def testGraphDebugInfo(self):
"""Test a session has debug info captured."""
@def_function.function
def plus_placeholder(x, placeholder):
return x + placeholder
with ops.Graph().as_default():
placeholder = array_ops.placeholder(
dtype=dtypes.float32, shape=[1], name='input')
variable_node = variables.Variable(1.0, name='variable_node')
defun_node = plus_placeholder(variable_node, placeholder)
output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
# Initialize variables in the model.
sess = session.Session()
sess.run(variables.variables_initializer([variable_node]))
converter = lite.TFLiteConverter.from_session(sess, [placeholder],
[output_node])
converter.convert()
self.assertValidDebugInfo(converter._debug_info)
# Check the add node in the inlined function is included.
func = sess.graph.as_graph_def().library.function[0].signature.name
self.assertIn((func + 'add'), converter._debug_info.traces)
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class FromFrozenGraphFile(test_util.TensorFlowTestCase): class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@ -1013,6 +1055,25 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
def testGraphDebugInfo(self):
"""Test a frozen graph doesn't have debug info captured."""
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
_ = in_tensor + in_tensor
sess = session.Session()
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
['Placeholder'], ['add'])
converter.convert()
# GraphDebugInfo should be none for frozen graph.
self.assertTrue(not converter._debug_info)
class FromFrozenGraphObjectDetection(test_util.TensorFlowTestCase): class FromFrozenGraphObjectDetection(test_util.TensorFlowTestCase):
@ -1040,9 +1101,10 @@ class FromFrozenGraphObjectDetection(test_util.TensorFlowTestCase):
# Tests the object detection model that cannot be loaded in TensorFlow. # Tests the object detection model that cannot be loaded in TensorFlow.
self._initObjectDetectionArgs() self._initObjectDetectionArgs()
converter = lite.TFLiteConverter.from_frozen_graph( converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
self._graph_def_file, self._input_arrays, self._output_arrays, self._input_arrays,
self._input_shapes) self._output_arrays,
self._input_shapes)
converter.allow_custom_ops = True converter.allow_custom_ops = True
tflite_model = converter.convert() tflite_model = converter.convert()
self.assertTrue(tflite_model) self.assertTrue(tflite_model)
@ -1081,8 +1143,9 @@ class FromFrozenGraphObjectDetection(test_util.TensorFlowTestCase):
# Missing `input_shapes`. # Missing `input_shapes`.
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
lite.TFLiteConverter.from_frozen_graph( lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
self._graph_def_file, self._input_arrays, self._output_arrays) self._input_arrays,
self._output_arrays)
self.assertEqual('input_shapes must be defined for this model.', self.assertEqual('input_shapes must be defined for this model.',
str(error.exception)) str(error.exception))
@ -1103,7 +1166,7 @@ class FromFrozenGraphObjectDetection(test_util.TensorFlowTestCase):
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class FromSavedModelTest(test_util.TensorFlowTestCase): class FromSavedModelTest(TestModels):
def _createSavedModel(self, shape): def _createSavedModel(self, shape):
"""Create a simple SavedModel.""" """Create a simple SavedModel."""
@ -1248,6 +1311,13 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
def testGraphDebugInfo(self):
"""Test a SavedModel has debug info captured."""
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.convert()
self.assertValidDebugInfo(converter._debug_info)
class MyAddLayer(keras.layers.Layer): class MyAddLayer(keras.layers.Layer):
@ -1265,7 +1335,7 @@ class MyAddLayer(keras.layers.Layer):
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class FromKerasFile(test_util.TensorFlowTestCase, parameterized.TestCase): class FromKerasFile(TestModels, parameterized.TestCase):
def setUp(self): def setUp(self):
super(FromKerasFile, self).setUp() super(FromKerasFile, self).setUp()
@ -1627,9 +1697,19 @@ class FromKerasFile(test_util.TensorFlowTestCase, parameterized.TestCase):
interpreter = Interpreter(model_content=tflite_model) interpreter = Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() interpreter.allocate_tensors()
@parameterized.named_parameters(('_graph', context.graph_mode),
('_eager', context.eager_mode))
def testGraphDebugInfo(self, test_context):
"""Test a Sequential tf.keras model has debug info captured."""
with test_context():
self._getSequentialModel()
converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
converter.convert()
self.assertValidDebugInfo(converter._debug_info)
@test_util.run_v1_only('Incompatible with 2.0.') @test_util.run_v1_only('Incompatible with 2.0.')
class GrapplerTest(test_util.TensorFlowTestCase): class GrapplerTest(TestModels):
def testConstantFolding(self): def testConstantFolding(self):
# Constant folding handles the tf.broadcast_to operation which was not # Constant folding handles the tf.broadcast_to operation which was not

View File

@ -83,6 +83,16 @@ class TestModels(test_util.TensorFlowTestCase):
return BasicModel() return BasicModel()
def _assertValidDebugInfo(self, debug_info):
"""Verify the DebugInfo is valid."""
file_names = set()
for file_path in debug_info.files:
file_names.add(os.path.basename(file_path))
# To make the test independent on how the nodes are created, we only assert
# the name of this test file.
self.assertIn('lite_v2_test.py', file_names)
self.assertNotIn('lite_test.py', file_names)
class FromConcreteFunctionTest(TestModels): class FromConcreteFunctionTest(TestModels):
@ -239,6 +249,20 @@ class FromConcreteFunctionTest(TestModels):
# Ensure that the quantized weights tflite model is smaller. # Ensure that the quantized weights tflite model is smaller.
self.assertLess(len(quantized_tflite), len(float_tflite)) self.assertLess(len(quantized_tflite), len(float_tflite))
@test_util.run_v2_only
def testGraphDebugInfo(self):
"""Test a concrete function has debug info captured."""
root = tracking.AutoTrackable()
root.v1 = variables.Variable(3.)
root.f = def_function.function(lambda x: root.v1 * x)
input_data = constant_op.constant(1., shape=[1])
concrete_func = root.f.get_concrete_function(input_data)
# Convert model.
converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
converter.convert()
self._assertValidDebugInfo(converter._debug_info)
class FromSavedModelTest(TestModels): class FromSavedModelTest(TestModels):
@ -355,6 +379,22 @@ class FromSavedModelTest(TestModels):
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
self.assertEqual(expected_value, actual_value) self.assertEqual(expected_value, actual_value)
@test_util.run_v2_only
def testGraphDebugInfo(self):
"""Test a SavedModel has debug info captured."""
input_data = constant_op.constant(1., shape=[1])
root = tracking.AutoTrackable()
root.f = def_function.function(lambda x: 2. * x)
to_save = root.f.get_concrete_function(input_data)
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(root, save_dir, to_save)
# Convert model and ensure model is not None.
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
converter.convert()
self._assertValidDebugInfo(converter._debug_info)
class FromKerasModelTest(TestModels): class FromKerasModelTest(TestModels):
@ -426,6 +466,20 @@ class FromKerasModelTest(TestModels):
for tf_result, tflite_result in zip(expected_value, actual_value): for tf_result, tflite_result in zip(expected_value, actual_value):
np.testing.assert_almost_equal(tf_result[0], tflite_result, 5) np.testing.assert_almost_equal(tf_result[0], tflite_result, 5)
@test_util.run_v2_only
def testGraphDebugInfo(self):
"""Test a tf.Keras model has debug info captured."""
# Create a simple Keras model.
x = [-1, 0, 1, 2, 3, 4]
y = [-3, -1, 1, 3, 5, 7]
model = keras.models.Sequential(
[keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=1)
converter = lite.TFLiteConverterV2.from_keras_model(model)
converter.convert()
self._assertValidDebugInfo(converter._debug_info)
class GrapplerTest(TestModels): class GrapplerTest(TestModels):

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import copy import copy
import sys
from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import config_pb2 as _config_pb2
@ -26,7 +27,9 @@ from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation as _error_interpolation
from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
@ -285,3 +288,71 @@ def is_frozen_graph(sess):
if op.type.startswith("Variable") or op.type.endswith("VariableOp"): if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
return False return False
return True return True
def build_debug_info_func(original_graph):
"""Returns a method to retrieve the `GraphDebugInfo` from the original graph.
Args:
original_graph: The original `Graph` containing all the op stack traces.
Returns:
A function which retrieves the stack traces from the original graph and
converts them to a `GraphDebugInfo` for a given set of nodes.
"""
def f(original_nodes):
"""Function to create `GraphDebugInfo` for the given `original_nodes`."""
if not original_graph:
return None
# For the given nodes, gets all the op definitions in the original graph.
useful_ops = []
for func, name in original_nodes:
try:
if not func:
useful_ops.append((func, original_graph.get_operation_by_name(name)))
else:
sub_func = original_graph._get_function(func) # pylint: disable=protected-access
if isinstance(sub_func, function._EagerDefinedFunction): # pylint: disable=protected-access
useful_ops.append(
(func, sub_func.graph.get_operation_by_name(name)))
else:
sys.stderr.write(
"Use '@tf.function' or '@defun' to decorate the function.")
continue
except KeyError:
# New node created by graph optimizer. No stack trace from source code.
continue
# Convert all the op definitions to stack traces in terms of GraphDebugInfo.
return _error_interpolation.create_graph_debug_info_def(useful_ops)
return f
def get_debug_info(nodes_to_debug_info_func, converted_graph):
"""Returns the debug info for the original nodes in the `converted_graph`.
Args:
nodes_to_debug_info_func: The method to collect the op debug info for the
nodes.
converted_graph: A `GraphDef` after optimization and transfermation.
Returns:
`GraphDebugInfo` for all the original nodes in `converted_graph`.
"""
if not nodes_to_debug_info_func:
return None
# Collect all the debug info nodes from the converted_graph
original_nodes = set()
for node in converted_graph.node:
debug_nodes = node.experimental_debug_info.original_node_names
debug_funcs = node.experimental_debug_info.original_func_names
# If the `original_node_names` are empty, uses the node name directly.
if not debug_nodes:
original_nodes.add(("", node.name))
else:
for i in range(len(debug_nodes)):
original_nodes.add((debug_funcs[i], debug_nodes[i]))
# Convert the nodes to the debug info proto object.
return nodes_to_debug_info_func(original_nodes)

View File

@ -29,6 +29,7 @@ import re
import six import six
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.util import tf_stack from tensorflow.python.util import tf_stack
_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?" _NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?"
@ -212,7 +213,8 @@ def _get_defining_frame_from_op(op):
frame_index = _find_index_of_defining_frame_for_op(op) frame_index = _find_index_of_defining_frame_for_op(op)
return op.traceback[frame_index] return op.traceback[frame_index]
def compute_useful_frames(op, num):
def _compute_useful_frames(op, num):
"""Return a list of frames, which form a 'useful' stack. """Return a list of frames, which form a 'useful' stack.
Starting from the defining frame to the outermost one, this method computes Starting from the defining frame to the outermost one, this method computes
@ -235,6 +237,54 @@ def compute_useful_frames(op, num):
outermost_included = max(innermost_excluded - num, 0) outermost_included = max(innermost_excluded - num, 0)
return op.traceback[outermost_included:innermost_excluded] return op.traceback[outermost_included:innermost_excluded]
def create_graph_debug_info_def(operations):
"""Construct and returns a `GraphDebugInfo` protocol buffer.
Args:
operations: An iterable of op.Operation objects having _traceback members.
Returns:
GraphDebugInfo protocol buffer.
Raises:
TypeError: If the arguments are not of the correct proto buffer type.
"""
# Creates an empty GraphDebugInfoDef proto.
graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo()
# Gets the file names and line numbers for the exported node names. Also
# collects the unique file names.
all_file_names = set()
node_to_trace = {}
for func, op in operations:
# Gets the stack trace of the operation and then the file location.
node_name = func + op.name
node_to_trace[node_name] = _compute_useful_frames(op, 10)
for frame in node_to_trace[node_name]:
all_file_names.add(frame[tf_stack.TB_FILENAME])
# Sets the `files` field in the GraphDebugInfo proto
graph_debug_info_def.files.extend(all_file_names)
# Builds a mapping between file names and index of the `files` field, so we
# only store the indexes for the nodes in the GraphDebugInfo.
file_to_index = dict(
[(y, x) for x, y in enumerate(graph_debug_info_def.files)])
# Creates the FileLineCol proto for each node and sets the value in the
# GraphDebugInfo proto. We only store the file name index for each node to
# save the storage space.
for node_name, frames in node_to_trace.items():
trace_def = graph_debug_info_def.traces[node_name]
for frame in reversed(frames):
trace_def.file_line_cols.add(
file_index=file_to_index[frame[tf_stack.TB_FILENAME]],
line=frame[tf_stack.TB_LINENO])
return graph_debug_info_def
def compute_field_dict(op, strip_file_prefix=""): def compute_field_dict(op, strip_file_prefix=""):
"""Return a dictionary mapping interpolation tokens to values. """Return a dictionary mapping interpolation tokens to values.

View File

@ -30,7 +30,6 @@ from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
@ -44,7 +43,6 @@ from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import tf_stack
# Prefix to be added to unbound input names so they are easily identifiable. # Prefix to be added to unbound input names so they are easily identifiable.
@ -514,55 +512,6 @@ def strip_graph_default_valued_attrs(meta_graph_def):
meta_graph_def.meta_info_def.stripped_default_attrs = True meta_graph_def.meta_info_def.stripped_default_attrs = True
def create_graph_debug_info_def(operations):
"""Construct and returns a `GraphDebugInfo` protocol buffer.
Args:
operations: An iterable of op.Operation objects having _traceback members.
Returns:
GraphDebugInfo protocol buffer.
Raises:
TypeError: If the arguments are not of the correct proto buffer type.
"""
# Creates an empty GraphDebugInfoDef proto.
graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo()
# Gets the file names and line numbers for the exported node names. Also
# collects the unique file names.
all_file_names = set()
node_to_trace = {}
for op in operations:
# Gets the stack trace of the operation and then the file location.
node_name = op.name
node_to_trace[node_name] = error_interpolation.compute_useful_frames(op, 10)
for frame in node_to_trace[node_name]:
all_file_names.add(frame[tf_stack.TB_FILENAME])
# Sets the `files` field in the GraphDebugInfo proto
graph_debug_info_def.files.extend(all_file_names)
# Builds a mapping between file names and index of the `files` field, so we
# only store the indexes for the nodes in the GraphDebugInfo.
file_to_index = dict(
[(y, x) for x, y in enumerate(graph_debug_info_def.files)])
# Creates the FileLineCol proto for each node and sets the value in the
# GraphDebugInfo proto. We only store the file name index for each node to
# save the storage space.
for node_name, frames in node_to_trace.items():
trace_def = graph_debug_info_def.traces[node_name]
for frame in reversed(frames):
trace_def.file_line_cols.add(
file_index=file_to_index[frame[tf_stack.TB_FILENAME]],
line=frame[tf_stack.TB_LINENO],
func=frame[tf_stack.TB_FUNCNAME],
code=frame[tf_stack.TB_CODEDICT])
return graph_debug_info_def
def create_meta_graph_def(meta_info_def=None, def create_meta_graph_def(meta_info_def=None,
graph_def=None, graph_def=None,
saver_def=None, saver_def=None,
@ -1108,12 +1057,14 @@ def export_scoped_meta_graph(filename=None,
# Gets the operation from the graph by the name. Exludes variable nodes, # Gets the operation from the graph by the name. Exludes variable nodes,
# so only the nodes in the frozen models are included. # so only the nodes in the frozen models are included.
# TODO(liufengdb): fix this for functions.
ops_to_export = [] ops_to_export = []
for node in scoped_meta_graph_def.graph_def.node: for node in scoped_meta_graph_def.graph_def.node:
scoped_op_name = ops.prepend_name_scope(node.name, export_scope) scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
ops_to_export.append(graph.get_operation_by_name(scoped_op_name)) ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))
graph_debug_info = create_graph_debug_info_def(ops_to_export) graph_debug_info = error_interpolation.create_graph_debug_info_def(
ops_to_export)
graph_io.write_graph( graph_io.write_graph(
graph_debug_info, graph_debug_info,

View File

@ -28,6 +28,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import meta_graph from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
@ -740,8 +741,11 @@ class ScopedMetaGraphTest(test.TestCase):
biases1 = resource_variable_ops.ResourceVariable( biases1 = resource_variable_ops.ResourceVariable(
[0.1] * 3, name="biases") [0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
debug_info_def = meta_graph.create_graph_debug_info_def( operations = []
operations=graph1.get_operations()) for op in graph1.get_operations():
operations.append(("", op))
debug_info_def = error_interpolation.create_graph_debug_info_def(
operations=operations)
# The unique file names in all the stack traces should be larger or equal # The unique file names in all the stack traces should be larger or equal
# than 1. # than 1.

View File

@ -8,5 +8,11 @@ tf_proto {
label: LABEL_REPEATED label: LABEL_REPEATED
type: TYPE_STRING type: TYPE_STRING
} }
field {
name: "original_func_names"
number: 2
label: LABEL_REPEATED
type: TYPE_STRING
}
} }
} }

View File

@ -67,6 +67,12 @@ tf_proto {
label: LABEL_REPEATED label: LABEL_REPEATED
type: TYPE_STRING type: TYPE_STRING
} }
field {
name: "original_func_names"
number: 2
label: LABEL_REPEATED
type: TYPE_STRING
}
} }
} }
} }

View File

@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member_method { member_method {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " argspec: "args=[\'self\', \'graph_def\', \'input_tensors\', \'output_tensors\', \'input_arrays_with_shape\', \'output_arrays\', \'experimental_debug_info_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
} }
member_method { member_method {
name: "convert" name: "convert"