parent
62f3c16a9a
commit
438ff85035
tensorflow/python/keras
@ -547,12 +547,11 @@ class Layer(module.Module):
|
||||
inputs = nest.map_structure(_convert_non_tensor, inputs)
|
||||
input_list = nest.flatten(inputs)
|
||||
|
||||
# We will attempt to build a TF graph if we are not in a `tf.function` and
|
||||
# all inputs are symbolic. This is always the case in graph mode. It can
|
||||
# also be the case in eager mode when all inputs can be traced back to
|
||||
# `keras.Input()` (when building models using the functional API).
|
||||
build_graph = (not base_layer_utils.is_in_tf_function() and
|
||||
tf_utils.are_all_symbolic_tensors(input_list))
|
||||
# We will attempt to build a TF graph if & only if all inputs are symbolic.
|
||||
# This is always the case in graph mode. It can also be the case in eager
|
||||
# mode when all inputs can be traced back to `keras.Input()` (when building
|
||||
# models using the functional API).
|
||||
build_graph = tf_utils.are_all_symbolic_tensors(input_list)
|
||||
|
||||
if build_graph:
|
||||
# Only create Keras history if at least one tensor originates from a
|
||||
@ -664,6 +663,13 @@ class Layer(module.Module):
|
||||
inputs, outputs, args, kwargs)
|
||||
self._handle_activity_regularization(inputs, outputs)
|
||||
self._set_mask_metadata(inputs, outputs, previous_mask)
|
||||
if hasattr(self, '_set_inputs') and not self.inputs:
|
||||
# Subclassed network: explicitly set metadata normally set by
|
||||
# a call to self._set_inputs().
|
||||
# TODO(b/120997007): This should be done in Eager as well, but
|
||||
# causes garbage collection issues because of the placeholders
|
||||
# created on the default Keras graph.
|
||||
self._set_inputs(inputs, outputs)
|
||||
else:
|
||||
# Eager execution on data tensors.
|
||||
with backend.name_scope(self._name_scope()):
|
||||
@ -674,15 +680,6 @@ class Layer(module.Module):
|
||||
self._handle_activity_regularization(inputs, outputs)
|
||||
self._set_mask_metadata(inputs, outputs, previous_mask)
|
||||
|
||||
if build_graph or base_layer_utils.is_in_tf_function():
|
||||
# When symbolic inputs are passed, track them for saving purposes.
|
||||
if hasattr(self, '_set_inputs') and not self.inputs:
|
||||
# Subclassed network: explicitly set metadata normally set by
|
||||
# a call to `_set_inputs` in `compile`.
|
||||
# TODO(b/120997007): Some of this functionality should be done
|
||||
# in Eager as well, but placeholders are created which causes
|
||||
# garbage collection and other issues.
|
||||
self._set_inputs(inputs, outputs)
|
||||
return outputs
|
||||
|
||||
@property
|
||||
|
@ -104,7 +104,7 @@ def trace_model_call(model, input_signature=None):
|
||||
# When given a single input, Keras models will call the model on the tensor
|
||||
# rather than a list consisting of the single tensor.
|
||||
inputs = args[0] if len(input_signature) == 1 else list(args)
|
||||
outputs_list = nest.flatten(model(inputs))
|
||||
outputs_list = nest.flatten(model(inputs=inputs))
|
||||
try:
|
||||
output_names = model.output_names
|
||||
except AttributeError:
|
||||
|
@ -171,7 +171,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
|
||||
fn = saving_utils.trace_model_call(
|
||||
model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)])
|
||||
signature_outputs = fn(inputs)
|
||||
expected_outputs = {'output_1': model(inputs)}
|
||||
expected_outputs = {model.output_names[0]: model(inputs)}
|
||||
self._assert_all_close(expected_outputs, signature_outputs)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
|
Loading…
Reference in New Issue
Block a user