diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 13d1d1f6f8e..0b7f4c0bcc7 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -328,9 +328,9 @@ class FromSavedModelTest(TestModels): self.assertIn('This converter can only convert a single ConcreteFunction', str(error.exception)) + @test_util.run_v2_only def testKerasSequentialModel(self): """Test a simple sequential tf.Keras model.""" - self.skipTest('b/134660903') input_data = constant_op.constant(1., shape=[1, 1]) x = np.array([[1.], [2.]]) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 24377614031..3d03a45335e 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -164,7 +164,7 @@ def _compatible_shapes(flat_relaxed, flat_to_check): for relaxed, to_check in zip(flat_relaxed, flat_to_check)) -def _common_shape(x, y): +def common_shape(x, y): """Find a `TensorShape` that is compatible with both `x` and `y`.""" if x is None != y is None: raise RuntimeError( @@ -1577,7 +1577,7 @@ class Function(object): "relaxed_arg_shapes len: %d vs. %d" % (len(arg_shapes), len(relaxed_arg_shapes))) relaxed_arg_shapes = [ - _common_shape(x, y) for (x, y) in zip( + common_shape(x, y) for (x, y) in zip( arg_shapes, relaxed_arg_shapes)] self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = ( relaxed_arg_shapes) @@ -1679,8 +1679,9 @@ def register(func, *args, **kwargs): def validate_signature(signature): if any(not isinstance(arg, tensor_spec.TensorSpec) for arg in nest.flatten(signature, expand_composites=True)): - raise TypeError("Invalid input_signature %s; input_signature must be " - "a possibly nested sequence of TensorSpec objects.") + raise TypeError("Invalid input_signature {}; input_signature must be " + "a possibly nested sequence of TensorSpec objects." + .format(signature)) def defun(func=None, diff --git a/tensorflow/python/keras/saving/saved_model.py b/tensorflow/python/keras/saving/saved_model.py index 8d9fc094dad..ea387eedbbb 100644 --- a/tensorflow/python/keras/saving/saved_model.py +++ b/tensorflow/python/keras/saving/saved_model.py @@ -20,12 +20,12 @@ from __future__ import print_function import functools import json import os +import weakref import six from tensorflow.python.client import session from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec @@ -38,10 +38,10 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.saving import model_from_json from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.utils import mode_keys +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder @@ -713,10 +713,20 @@ def serialize_all_attributes(layer, serialization_cache): if _should_skip_serialization(layer): return serialized_attr + function_dict = {} + if save_model_default_signature: + # For compatibility with the tf.Lite Converter, the default save signature + # should be traced without nested calls to other wrapped functions. + # TODO(kathywu): Investigate why having nested calls results in a stateful + # function. Perhaps something to do with losses, which are traced in nested + # calls but not in the flat call. + function_dict['_default_save_signature'] = _default_save_signature(layer) + else: + function_dict['_default_save_signature'] = None + object_dict = _wrap_layer_objects(layer, serialization_cache) try: - function_dict = _wrap_layer_functions(layer, serialization_cache, - save_model_default_signature) + function_dict.update(_wrap_layer_functions(layer, serialization_cache)) except (ValueError, TypeError) as e: logging.warning('Skipping full serialization of object {}, because an ' 'error occurred while tracing layer functions. Error ' @@ -799,44 +809,53 @@ def _wrap_layer_objects(layer, serialization_cache): wrapped_loss_functions)) -def _wrap_layer_functions(layer, serialization_cache, - save_model_default_signature=False): +def _wrap_layer_functions(layer, serialization_cache): """Returns dict of wrapped layer call function and losses in tf.functions. Args: layer: Keras Layer object. serialization_cache: Dictionary shared between all objects during serialization. - save_model_default_signature: Whether to save traced model call function. Returns: A dictionary containing all keras tf.functions to serialize. See LayerAttributes and ModelAttributes for the list of all attributes. """ + # Since Sequential models may be modified in place using model.add() or + # model.pop(), don't use saved functions. + if (isinstance(layer, RevivedLayer) and + not isinstance(layer, RevivedSequential)): + return {fn_name: getattr(layer.keras_api, fn_name, None) + for fn_name in LayerAttributes.all_functions} + # Reset the losses of the layer and its children. The call function in each # child layer is replaced with tf.functions. - original_attrs = _replace_child_layer_functions(layer, serialization_cache) - original_layer_losses = layer._losses[:] # pylint: disable=protected-access - with trackable.no_automatic_dependency_tracking_scope(layer): - layer._losses = [] # pylint: disable=protected-access - # Note that eager losses do not need to be saved since these functions - # create symbolic losses. + original_fns = _replace_child_layer_functions(layer, serialization_cache) + original_losses = _reset_layer_losses(layer) # Wrap all the layer call and activity regularizer functions. - call_fn_with_losses = _wrap_call_and_conditional_losses(layer) - fns = {'call_and_return_conditional_losses': call_fn_with_losses, - '__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)} - if save_model_default_signature: - fns['_default_save_signature'] = saving_utils.trace_model_call(layer) - else: - fns['_default_save_signature'] = None + # Use LayerCallCollection to ensure that all layer call functions (__call__, + # call with losses) are traced with the same inputs. + call_collection = LayerCallCollection(layer) + call_fn_with_losses = call_collection.add_function( + _wrap_call_and_conditional_losses(layer), + '{}_layer_call_and_return_conditional_losses'.format(layer.name)) + call_fn = call_collection.add_function( + _extract_outputs_from_fn(layer, call_fn_with_losses), + '{}_layer_call_fn'.format(layer.name)) + + fns = {'call_and_return_conditional_losses': call_fn_with_losses, + '__call__': call_fn} if layer.activity_regularizer is not None: fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer) fns['call_and_return_all_conditional_losses'] = ( - _append_activity_regularizer_loss( - layer, call_fn_with_losses, fns['activity_regularizer_fn'])) + call_collection.add_function( + _append_activity_regularizer_loss(call_fn_with_losses, + fns['activity_regularizer_fn']), + '{}_layer_call_and_return_all_conditional_losses'.format(layer.name) + )) else: fns['activity_regularizer_fn'] = None fns['call_and_return_all_conditional_losses'] = call_fn_with_losses @@ -849,14 +868,21 @@ def _wrap_layer_functions(layer, serialization_cache, if fn is not None and fn.input_signature is not None: fn.get_concrete_function() - # Restore overwritten functions/losses - with trackable.no_automatic_dependency_tracking_scope(layer): - layer._losses = original_layer_losses # pylint: disable=protected-access - _restore_child_layer_functions(original_attrs) + # Restore overwritten functions and losses + _restore_child_layer_functions(original_fns) + _restore_layer_losses(original_losses) return fns +def _default_save_signature(layer): + original_losses = _reset_layer_losses(layer) + fn = saving_utils.trace_model_call(layer) + fn.get_concrete_function() + _restore_layer_losses(original_losses) + return fn + + def _list_all_layers(obj): if isinstance(obj, training_lib.Model): return obj.layers @@ -888,11 +914,9 @@ def _replace_child_layer_functions(layer, serialization_cache): Child layer 2: ... } """ - original_attrs = {} + # pylint: disable=protected-access + original_fns = {} for child_layer in _list_all_layers(layer): - # Save symbolic layer losses, which will be restored to maintain the same - # state. - original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access if child_layer not in serialization_cache[_KERAS_CACHE_KEY]: layer_fns = (serialize_all_attributes(child_layer, serialization_cache) .functions) @@ -906,27 +930,46 @@ def _replace_child_layer_functions(layer, serialization_cache): # wrapped. In this case, no replacement is necessary so move on to the # next child. continue - - original_attrs[child_layer]['call'] = child_layer.call - original_attrs[child_layer]['activity_regularizer'] = ( - child_layer.activity_regularizer) + original_fns[child_layer] = { + 'call': child_layer.call, + 'activity_regularizer': child_layer.activity_regularizer + } with trackable.no_automatic_dependency_tracking_scope(child_layer): child_layer.activity_regularizer = layer_fns.get( 'activity_regularizer_fn') child_layer.call = _use_wrapped_call( child_layer, layer_fns['call_and_return_conditional_losses']) - child_layer._losses = [] # pylint: disable=protected-access - return original_attrs + return original_fns + # pylint: enable=protected-access -def _restore_child_layer_functions(original_attrs): +def _restore_child_layer_functions(original_fns): """Restores attributes replaced with `_replace_child_layer_functions`.""" - for child_layer, attrs in original_attrs.items(): + for child_layer, fns in original_fns.items(): with trackable.no_automatic_dependency_tracking_scope(child_layer): - child_layer._losses = attrs['losses'] # pylint: disable=protected-access - if 'call' in attrs: - child_layer.call = attrs['call'] - child_layer.activity_regularizer = attrs['activity_regularizer'] + child_layer.call = fns['call'] + child_layer.activity_regularizer = fns['activity_regularizer'] + + +# pylint: disable=protected-access +def _reset_layer_losses(parent_layer): + """Resets losses of layer and its sublayers, and returns original losses.""" + losses_dict = {} + for layer in _list_all_layers(parent_layer) + [parent_layer]: + losses_dict[layer] = {'losses': layer._losses[:], + 'eager_losses': layer._eager_losses[:]} + with trackable.no_automatic_dependency_tracking_scope(layer): + layer._losses = [] + layer._eager_losses = [] + return losses_dict + + +def _restore_layer_losses(losses_dict): + for layer in losses_dict: + with trackable.no_automatic_dependency_tracking_scope(layer): + layer._losses = losses_dict[layer]['losses'] + layer._eager_losses = losses_dict[layer]['eager_losses'] +# pylint: enable=protected-access def _use_wrapped_call(layer, call_fn): @@ -947,8 +990,10 @@ def _use_wrapped_call(layer, call_fn): training = kwargs.pop('training', None) if training is None: training = K.learning_phase() - training = math_ops.cast(training, dtypes.bool) - outputs, losses = call_fn(inputs, training=training) + outputs, losses = tf_utils.smart_cond( + training, + lambda: call_fn(inputs, training=True), + lambda: call_fn(inputs, training=False)) else: outputs, losses = call_fn(inputs) layer.add_loss(losses, inputs) @@ -956,6 +1001,128 @@ def _use_wrapped_call(layer, call_fn): return wrapped_call +class LayerCallCollection(object): + """Groups wrapped layer call functions. + + This is used to ensure that all layer call functions are traced with the same + inputs- + - call + - call_and_return_conditional_losses + - call_and_return_all_conditional_losses + """ + + def __init__(self, layer): + self._layer = layer + self._expects_training_arg = layer._expects_training_arg # pylint: disable=protected-access + self._input_signature = self._generate_input_signature(layer) + self._functions = weakref.WeakValueDictionary() + # Bool indicating whether this object is currently tracing the layer call + # functions. + self.tracing = False + + def _generate_input_signature(self, layer): + """Inspects layer object and returns the inferred input signature. + + Args: + layer: Layer object. + + Returns: + List of possibly nested TensorSpecs of the layer call function inputs. + The list does not contain the `training` argument. + """ + if (isinstance(layer.call, def_function.Function) and + layer.call.input_signature is not None): + return layer.call.input_signature + else: + if isinstance(layer, training_lib.Model): + return saving_utils.model_input_signature(layer) + elif layer.input_spec is not None: + + def to_tensor_spec_or_none(x): + spec = input_spec.to_tensor_spec(x, layer.dtype) + # If the shape is too general (e.g. multiple dimensions are allowed), + # return None so that separate functions can be generated for each + # inferred input signature. + # TODO(b/134962016): currently partial signatures are not supported. + if spec.shape == tensor_shape.TensorShape(None): + return None + return spec + input_signature = [nest.map_structure( + to_tensor_spec_or_none, layer.input_spec)] + + return input_signature + else: + return None + + def add_trace(self, *args, **kwargs): + """Traces all functions with the same args and kwargs. + + Args: + *args: Positional args passed to the original function. + **kwargs: Keyword args passed to the original function. + """ + kwargs = kwargs.copy() + self.tracing = True + for fn in self._functions.values(): + # TODO(kathywu): Replace arguments with broader shapes defined in the + # input signature. + if self._expects_training_arg: + kwargs['training'] = False + fn.original_get_concrete_function(*args, **kwargs) + kwargs['training'] = True + fn.original_get_concrete_function(*args, **kwargs) + else: + fn.original_get_concrete_function(*args, **kwargs) + self.tracing = False + + @property + def fn_input_signature(self): + """Returns input signature for the wrapped layer call function.""" + if self._expects_training_arg: + # The training arg is left as a python boolean, so the call functions + # will not have an input signature (input signatures may only describe + # tensor arguments). + return None + if None in nest.flatten(self._input_signature): + # TODO(b/134962016): If input signature cannot be partially defined. + return None + return self._input_signature + + def add_function(self, python_function, name): + """Adds a layer call function to the collection.""" + self._functions[name] = fn = LayerCall( + self, python_function, name, + input_signature=self.fn_input_signature) + + if (None not in nest.flatten(self._input_signature) and + self._expects_training_arg): + # Manually add traces for layers that expect a training argument and have + # a fully defined input signature. + self.add_trace(*self._input_signature) + return fn + + +class LayerCall(def_function.Function): + """Function that triggers traces of other functions in the same collection.""" + + def __init__(self, call_collection, *args, **kwargs): + super(LayerCall, self).__init__(*args, **kwargs) + self.call_collection = call_collection + + def __call__(self, *args, **kwargs): + if not self.call_collection.tracing: + self.call_collection.add_trace(*args, **kwargs) + return super(LayerCall, self).__call__(*args, **kwargs) + + def get_concrete_function(self, *args, **kwargs): + if not self.call_collection.tracing: + self.call_collection.add_trace(*args, **kwargs) + return super(LayerCall, self).get_concrete_function(*args, **kwargs) + + def original_get_concrete_function(self, *args, **kwargs): + return super(LayerCall, self).get_concrete_function(*args, **kwargs) + + def _wrap_call_and_conditional_losses(layer): """Wraps call function that returns a tuple of (outputs, losses). @@ -966,51 +1133,19 @@ def _wrap_call_and_conditional_losses(layer): layer: a Keras layer object Returns: - call function that returns outputs and conditional losses -- excludes + python call function that returns outputs and conditional losses -- excludes activity regularizer """ - if isinstance(layer, RevivedLayer): - return layer.keras_api.call_and_return_conditional_losses - - if (isinstance(layer.call, def_function.Function) and - layer.call.input_signature is not None): - input_signature = layer.call.input_signature - else: - if (isinstance(layer, training_lib.Model) and - saving_utils.model_input_signature(layer) is not None): - input_signature = saving_utils.model_input_signature(layer) - elif layer.input_spec is not None: - input_signature = [nest.map_structure( - lambda x: input_spec.to_tensor_spec(x, layer.dtype), - layer.input_spec)] - # If input spec is too general, then don't define an input signature. - for spec in nest.flatten(input_signature): - if spec.shape == tensor_shape.TensorShape(None): - input_signature = None - break - else: - input_signature = None - - if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access - input_signature.append( - tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool)) - # Create function that generates both outputs and losses layer_call = layer.call if layer._expects_training_arg: # pylint: disable=protected-access - def call_and_return_conditional_losses(inputs, training): - _set_symbolic_learning_phase(training) + def call_and_return_conditional_losses(inputs, training=False): return layer_call(inputs, training=training), layer.get_losses_for(inputs) else: def call_and_return_conditional_losses(inputs): K.set_learning_phase(0) return layer_call(inputs), layer.get_losses_for(inputs) - return def_function.Function( - call_and_return_conditional_losses, - '{}_layer_call_and_return_conditional_losses'.format(layer.name), - input_signature=input_signature, - # TODO(kathywu): Investigate autograph error. - autograph=False) + return call_and_return_conditional_losses def _extract_outputs_from_fn(layer, call_and_return_conditional_losses): @@ -1018,50 +1153,22 @@ def _extract_outputs_from_fn(layer, call_and_return_conditional_losses): if isinstance(layer, RevivedLayer): return layer.keras_api.__call__ # pylint: disable=protected-access if layer._expects_training_arg: # pylint: disable=protected-access - def call(inputs, training): - return call_and_return_conditional_losses(inputs, training)[0] + def call(inputs, training=False): + return call_and_return_conditional_losses(inputs, training=training)[0] else: def call(inputs): return call_and_return_conditional_losses(inputs)[0] - return def_function.Function( - call, '{}_layer_call_fn'.format(layer.name), - input_signature=call_and_return_conditional_losses.input_signature, - # TODO(kathywu): Investigate autograph error. - autograph=False) - - -def _set_symbolic_learning_phase(value): - """Set learning phase to a tensor value (for internal use only). - - This is used when wrapping call functions as tf.functions that have training - as a tensor input. Thus, when `learning_phase()` is called, the training - tensor is returned. This function is called when saving a model to SavedModel. - - Args: - value: A Tensor object. - - Raises: - ValueError: If the input value is not a graph tensor - """ - graph = K.get_graph() - if not isinstance(value, ops.Tensor): - raise ValueError('Symbolic learning phase must be a graph tensor.') - K._GRAPH_LEARNING_PHASES[graph] = value # pylint: disable=protected-access + return call def _append_activity_regularizer_loss( - layer, call_fn_with_losses, activity_regularizer_fn): + call_fn_with_losses, activity_regularizer_fn): """Appends activity regularizer loss to losses returned by the wrapped fn.""" - def fn(*args): - outputs, losses = call_fn_with_losses(*args) + def fn(*args, **kwargs): + outputs, losses = call_fn_with_losses(*args, **kwargs) losses.append(activity_regularizer_fn(outputs)) return outputs, losses - return def_function.Function( - fn, - '{}_layer_call_and_return_all_conditional_losses'.format(layer.name), - input_signature=call_fn_with_losses.input_signature, - # TODO(kathywu): Investigate autograph error. - autograph=False) + return fn def _wrap_unconditional_loss(loss_fn, index): @@ -1135,9 +1242,11 @@ class KerasObjectLoader(load.Loader): # pylint: disable=protected-access for node in self._nodes: if isinstance(node, RevivedModel): - input_signature = ( - node.keras_api.call_and_return_conditional_losses.input_signature[0] - ) + call_fn = node.keras_api.call_and_return_conditional_losses + if call_fn.input_signature is None: + inputs = infer_inputs_from_restored_call_function(call_fn) + else: + inputs = call_fn.input_signature[0] if isinstance(node, RevivedSequential): with trackable.no_automatic_dependency_tracking_scope(node): node._layers = [] @@ -1147,7 +1256,7 @@ class KerasObjectLoader(load.Loader): if not node.inputs: # Since this revived object is technically a subclassed model (even if # the original model is functional/sequential), inputs should be set. - node._set_inputs(input_signature) + node._set_inputs(inputs) if isinstance(node, RevivedLayer): losses = node._serialized_attributes.get('regularization_losses', []) for loss in losses: @@ -1276,6 +1385,26 @@ def recursively_deserialize_keras_object(config, module_objects=None): raise ValueError('Unable to decode config: {}'.format(config)) +def infer_inputs_from_restored_call_function(fn): + """Returns TensorSpec of inputs from a restored call function. + + Args: + fn: Restored layer call function. It is assumed that the inputs are entirely + in the first argument. + + Returns: + TensorSpec of call function inputs. + """ + def common_spec(x, y): + return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape), + x.dtype, x.name) + spec = fn.concrete_functions[0].structured_input_signature[0][0] + for concrete in fn.concrete_functions[1:]: + spec2 = concrete.structured_input_signature[0][0] + spec = nest.map_structure(common_spec, spec, spec2) + return spec + + class RevivedNetwork(RevivedLayer): """Keras network of layers loaded from a SavedModel.""" diff --git a/tensorflow/python/keras/saving/saved_model_test.py b/tensorflow/python/keras/saving/saved_model_test.py index 919ae45972d..732bf820868 100644 --- a/tensorflow/python/keras/saving/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model_test.py @@ -705,14 +705,8 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): expected_layers = len(model.layers) self.assertEqual(expected_layers, len(loaded.keras_api.layers)) input_arr = array_ops.ones((4, 3)) - training_bool = constant_op.constant(False) - - if model._expects_training_arg: - call_args = [input_arr, training_bool] - else: - call_args = [input_arr] self.assertAllClose(self.evaluate(model(input_arr)), - self.evaluate(loaded(*call_args))) + self.evaluate(loaded(input_arr))) @keras_parameterized.run_with_all_model_types def test_compiled_model(self): @@ -765,6 +759,20 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape) self.assertEqual('float16', loaded.input_spec['b'].dtype) + def test_multi_input_model(self): + input_1 = keras.layers.Input(shape=(3,)) + input_2 = keras.layers.Input(shape=(5,)) + model = keras.Model([input_1, input_2], [input_1, input_2]) + saved_model_dir = self._save_model_dir() + + model.save(saved_model_dir, save_format='tf') + loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir) + input_arr_1 = np.random.random((1, 3)).astype('float32') + input_arr_2 = np.random.random((1, 5)).astype('float32') + + outputs = loaded([input_arr_1, input_arr_2]) + self.assertAllEqual(input_arr_1, outputs[0]) + self.assertAllEqual(input_arr_2, outputs[1]) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py index 866f596fca3..718c2ad5340 100644 --- a/tensorflow/python/keras/saving/saving_utils.py +++ b/tensorflow/python/keras/saving/saving_utils.py @@ -66,7 +66,7 @@ def model_input_signature(model): Returns: A list containing either a single TensorSpec or an object with nested - TensorSpecs. + TensorSpecs. This list does not contain the `training` argument. """ try: inputs = model.inputs diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 94618989e4f..4804e4d0e0d 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -177,11 +177,11 @@ class RestoredFunction(def_function.Function): # TODO(mdan): We may enable autograph once exceptions are supported. super(RestoredFunction, self).__init__( python_function, name, autograph=False) - self._concrete_functions = concrete_functions + self.concrete_functions = concrete_functions self._function_spec = function_spec def _list_all_concrete_functions_for_serialization(self): - return self._concrete_functions + return self.concrete_functions def _defun_with_scope(self, scope): func = super(RestoredFunction, self)._defun_with_scope(scope) diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 23ff7093a4b..984c9ab2cde 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -179,7 +179,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): # Calling get_concrete_function wraps in a second call operation; we want to # inspect the original function body for the control output; digging into # graph.as_graph_def() and its FunctionDefLibrary is another option. - imported_concrete, = imported.f._concrete_functions + imported_concrete, = imported.f.concrete_functions imported_graph = imported_concrete.graph self.assertIn( imported_graph.get_operation_by_name("should_be_control_output"),