parent 62f3c16a9a
commit 438ff85035

@@ -547,12 +547,11 @@ class Layer(module.Module):
     inputs = nest.map_structure(_convert_non_tensor, inputs)
     input_list = nest.flatten(inputs)
 
-    # We will attempt to build a TF graph if we are not in a `tf.function` and
-    # all inputs are symbolic. This is always the case in graph mode. It can
-    # also be the case in eager mode when all inputs can be traced back to
-    # `keras.Input()` (when building models using the functional API).
-    build_graph = (not base_layer_utils.is_in_tf_function() and
-                   tf_utils.are_all_symbolic_tensors(input_list))
+    # We will attempt to build a TF graph if & only if all inputs are symbolic.
+    # This is always the case in graph mode. It can also be the case in eager
+    # mode when all inputs can be traced back to `keras.Input()` (when building
+    # models using the functional API).
+    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
 
     if build_graph:
       # Only create Keras history if at least one tensor originates from a
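
Note on the hunk above: with this change, whether `Layer.__call__` takes the
graph-building path depends only on whether every input is a symbolic tensor,
no longer on whether the call happens inside a `tf.function`. A minimal sketch
of the two paths using only the public `tf.keras` API (the layer and tensors
below are invented for illustration, they are not part of the diff):

import tensorflow as tf

dense = tf.keras.layers.Dense(4)

# Symbolic input from `keras.Input()`: all inputs are symbolic, so the layer
# defers execution and records Keras history (the `build_graph` branch).
symbolic_in = tf.keras.Input(shape=(8,))
symbolic_out = dense(symbolic_in)

# Concrete eager tensor: not symbolic, so the layer executes immediately
# (the `else` branch, "Eager execution on data tensors").
eager_in = tf.random.normal([2, 8])
eager_out = dense(eager_in)
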
@@ -664,6 +663,13 @@ class Layer(module.Module):
               inputs, outputs, args, kwargs)
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, previous_mask)
+        if hasattr(self, '_set_inputs') and not self.inputs:
+          # Subclassed network: explicitly set metadata normally set by
+          # a call to self._set_inputs().
+          # TODO(b/120997007): This should be done in Eager as well, but
+          # causes garbage collection issues because of the placeholders
+          # created on the default Keras graph.
+          self._set_inputs(inputs, outputs)
     else:
       # Eager execution on data tensors.
       with backend.name_scope(self._name_scope()):
@@ -674,15 +680,6 @@ class Layer(module.Module):
         self._handle_activity_regularization(inputs, outputs)
         self._set_mask_metadata(inputs, outputs, previous_mask)
 
-    if build_graph or base_layer_utils.is_in_tf_function():
-      # When symbolic inputs are passed, track them for saving purposes.
-      if hasattr(self, '_set_inputs') and not self.inputs:
-        # Subclassed network: explicitly set metadata normally set by
-        # a call to `_set_inputs` in `compile`.
-        # TODO(b/120997007): Some of this functionality should be done
-        # in Eager as well, but placeholders are created which causes
-        # garbage collection and other issues.
-        self._set_inputs(inputs, outputs)
     return outputs
 
   @property
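
The two hunks above move the subclassed-network bookkeeping back inside the
`build_graph` branch: `_set_inputs` records the input/output metadata that a
subclassed model cannot know until it has been called, and that saving relies
on. A small illustration of why that first call matters, assuming the public
`tf.keras` API (the `TwoLayerNet` class is invented for this example):

import tensorflow as tf

class TwoLayerNet(tf.keras.Model):
  """Subclassed model: its inputs/outputs are unknown until the first call."""

  def __init__(self):
    super().__init__()
    self.hidden = tf.keras.layers.Dense(16, activation='relu')
    self.out = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.out(self.hidden(inputs))

model = TwoLayerNet()
assert not model.built           # no shapes or input metadata recorded yet
_ = model(tf.zeros([2, 8]))      # first call builds the layers
assert model.built
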
@@ -104,7 +104,7 @@ def trace_model_call(model, input_signature=None):
     # When given a single input, Keras models will call the model on the tensor
     # rather than a list consisting of the single tensor.
     inputs = args[0] if len(input_signature) == 1 else list(args)
-    outputs_list = nest.flatten(model(inputs))
+    outputs_list = nest.flatten(model(inputs=inputs))
     try:
       output_names = model.output_names
     except AttributeError:
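
For context on the `trace_model_call` hunk: the wrapped model is now called
with the tensor passed by keyword, `model(inputs=inputs)`. A rough sketch of
the same pattern using public APIs only; `serving_fn` and `toy_model` are
names made up for this example, not part of the diff:

import tensorflow as tf

toy_model = tf.keras.Sequential([tf.keras.layers.Dense(3)])

@tf.function(input_signature=[tf.TensorSpec([None, 5], tf.float32)])
def serving_fn(inputs):
  # Passing the tensor by keyword mirrors the `model(inputs=inputs)` call
  # in the hunk above.
  return toy_model(inputs=inputs)

concrete_fn = serving_fn.get_concrete_function()
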
@@ -171,7 +171,7 @@ class TraceModelCallTest(keras_parameterized.TestCase):
     fn = saving_utils.trace_model_call(
         model, [tensor_spec.TensorSpec(shape=[None, 5], dtype=dtypes.float32)])
     signature_outputs = fn(inputs)
-    expected_outputs = {'output_1': model(inputs)}
+    expected_outputs = {model.output_names[0]: model(inputs)}
     self._assert_all_close(expected_outputs, signature_outputs)
 
   @test_util.run_in_graph_and_eager_modes
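
The test hunk swaps the hard-coded 'output_1' key for `model.output_names[0]`,
so the expected signature key follows whatever output name the model actually
reports. A quick, version-dependent way to inspect those names (the exact
values depend on the TF/Keras release and on layer names):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3, name='head')])
model.build((None, 5))
# Graph-network models expose one name per output tensor, e.g. ['head'];
# subclassed models may not define the attribute at all, which is why
# `trace_model_call` guards the lookup with `except AttributeError`.
print(getattr(model, 'output_names', None))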