Automated rollback of commit 48c906cfdc

PiperOrigin-RevId: 249327590
This commit is contained in:
Thomas O'Malley 2019-05-21 14:36:50 -07:00 committed by TensorFlower Gardener
parent 62f3c16a9a
commit 438ff85035
3 changed files with 14 additions and 17 deletions
tensorflow/python/keras

View File

@ -547,12 +547,11 @@ class Layer(module.Module):
inputs = nest.map_structure(_convert_non_tensor, inputs)
input_list = nest.flatten(inputs)
# We will attempt to build a TF graph if we are not in a `tf.function` and
# all inputs are symbolic. This is always the case in graph mode. It can
# also be the case in eager mode when all inputs can be traced back to
# `keras.Input()` (when building models using the functional API).
build_graph = (not base_layer_utils.is_in_tf_function() and
tf_utils.are_all_symbolic_tensors(input_list))
# We will attempt to build a TF graph if & only if all inputs are symbolic.
# This is always the case in graph mode. It can also be the case in eager
# mode when all inputs can be traced back to `keras.Input()` (when building
# models using the functional API).
build_graph = tf_utils.are_all_symbolic_tensors(input_list)
if build_graph:
# Only create Keras history if at least one tensor originates from a
@ -664,6 +663,13 @@ class Layer(module.Module):
inputs, outputs, args, kwargs)
self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, previous_mask)
if hasattr(self, '_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by
# a call to self._set_inputs().
# TODO(b/120997007): This should be done in Eager as well, but
# causes garbage collection issues because of the placeholders
# created on the default Keras graph.
self._set_inputs(inputs, outputs)
else:
# Eager execution on data tensors.
with backend.name_scope(self._name_scope()):
@ -674,15 +680,6 @@ class Layer(module.Module):
self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, previous_mask)
if build_graph or base_layer_utils.is_in_tf_function():
# When symbolic inputs are passed, track them for saving purposes.
if hasattr(self, '_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by
# a call to `_set_inputs` in `compile`.
# TODO(b/120997007): Some of this functionality should be done
# in Eager as well, but placeholders are created which causes
# garbage collection and other issues.
self._set_inputs(inputs, outputs)
return outputs
@property

View File

@ -104,7 +104,7 @@ def trace_model_call(model, input_signature=None):
# When given a single input, Keras models will call the model on the tensor
# rather than a list consisting of the single tensor.
inputs = args[0] if len(input_signature) == 1 else list(args)
outputs_list = nest.flatten(model(inputs))
outputs_list = nest.flatten(model(inputs=inputs))
try:
output_names = model.output_names
except AttributeError:

View File

@ -171,7 +171,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
fn = saving_utils.trace_model_call(
model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)])
signature_outputs = fn(inputs)
expected_outputs = {'output_1': model(inputs)}
expected_outputs = {model.output_names[0]: model(inputs)}
self._assert_all_close(expected_outputs, signature_outputs)
@test_util.run_in_graph_and_eager_modes