Generalizes the first argument in keras layers further. Now, functional models get constructed if *any* tensor in the arguments or keyword arguments has a keras history, rather than if *all* of the elements in the first argument to the layer do.
PiperOrigin-RevId: 313718130 Change-Id: I77f65f49decf45f6a2b53ab0519d6d2ac38232d3
This commit is contained in:
parent
618ff4c618
commit
a1ae008076
|
@ -830,14 +830,13 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
in_call = call_context.in_call
|
||||
input_list = nest.flatten(inputs)
|
||||
|
||||
# We will attempt to build a TF graph if & only if all inputs are symbolic.
|
||||
# This is always the case in graph mode. It can also be the case in eager
|
||||
# mode when all inputs can be traced back to `keras.Input()` (when building
|
||||
# models using the functional API).
|
||||
# TODO(kaftan): make this not special case inputs. Instead
|
||||
# build a functional api model if *any* *arg or **kwarg is symbolic,
|
||||
# even if part of the data structure in that arg is not symbolic.
|
||||
build_graph = tf_utils.are_all_symbolic_tensors(input_list)
|
||||
# We will attempt to trace in a graph if & only if inputs are symbolic.
|
||||
# This is always the case when tracing a function. It can also be the case
|
||||
# when running eagerly if any input can be traced back to `keras.Input()`
|
||||
# (when building models using the functional API).
|
||||
build_graph = tf_utils.are_all_symbolic_tensors(input_list) or (
|
||||
any(map(tf_utils.is_symbolic_tensor, nest.flatten(
|
||||
[input_list, args, kwargs]))) and context.executing_eagerly())
|
||||
|
||||
# Accept NumPy and scalar inputs by converting to Tensors.
|
||||
if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
|
||||
|
@ -890,11 +889,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
'training', training_value, args, kwargs)
|
||||
training_arg_passed_by_framework = True
|
||||
|
||||
# Only create Keras history if at least one tensor originates from a
|
||||
# `keras.Input`. Otherwise this Layer may be being used outside the Keras
|
||||
# framework.
|
||||
# TODO(kaftan): make this not special case inputs
|
||||
if build_graph and base_layer_utils.needs_keras_history(inputs):
|
||||
# Turn inputs into TF op layers if necessary.
|
||||
# This process is fragile and prone to bad interactions with inputs
|
||||
# when calling nested layers with tf.functions floating around,
|
||||
# and with nonsymbolic tensors.
|
||||
# So, we limit it to the
|
||||
# case where *all* inputs in the first arg are symbolic.
|
||||
if (tf_utils.are_all_symbolic_tensors(input_list)
|
||||
and base_layer_utils.needs_keras_history(inputs)):
|
||||
base_layer_utils.create_keras_history(inputs)
|
||||
|
||||
with call_context.enter(self, inputs, build_graph, training_value):
|
||||
|
@ -968,8 +970,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
raise ValueError('A layer\'s `call` method should return a '
|
||||
'Tensor or a list of Tensors, not None '
|
||||
'(layer: ' + self.name + ').')
|
||||
# TODO(kaftan): This should be 'any' and check all args
|
||||
if base_layer_utils.have_all_keras_metadata(inputs):
|
||||
# We configure connectivity metadata if all inputs in the first
|
||||
# arg have keras history, or if we're actively building the
|
||||
# functional api outside of any outer keras model.
|
||||
if base_layer_utils.have_all_keras_metadata(inputs) or (
|
||||
context.executing_eagerly() and
|
||||
base_layer_utils.have_any_keras_metadata(inputs, args, kwargs)):
|
||||
if training_arg_passed_by_framework:
|
||||
args, kwargs = self._set_call_arg_value(
|
||||
'training', None, args, kwargs, pop_kwarg_if_none=True)
|
||||
|
|
|
@ -165,6 +165,10 @@ def have_all_keras_metadata(tensors):
|
|||
return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
|
||||
|
||||
|
||||
def have_any_keras_metadata(*tensors):
|
||||
return any(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
|
||||
|
||||
|
||||
def generate_placeholders_from_shape(shape):
|
||||
return array_ops.placeholder(shape=shape, dtype=backend.floatx())
|
||||
|
||||
|
@ -214,7 +218,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
|
|||
for tensor in tensor_list:
|
||||
if getattr(tensor, '_keras_history', None) is not None:
|
||||
continue
|
||||
op = tensor.op # The Op that created this Tensor.
|
||||
try:
|
||||
op = tensor.op # The Op that created this Tensor.
|
||||
except AttributeError:
|
||||
continue
|
||||
if op not in processed_ops:
|
||||
if op.type.startswith('Sparse'):
|
||||
lambda_example = """
|
||||
|
@ -392,7 +399,10 @@ def mark_checked(tensors):
|
|||
"""
|
||||
|
||||
def _mark_checked(tensor):
|
||||
tensor._keras_history_checked = True # pylint: disable=protected-access
|
||||
try:
|
||||
tensor._keras_history_checked = True # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
nest.map_structure(_mark_checked, tensors)
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from tensorflow.python.keras import backend
|
|||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_layer as input_layer_module
|
||||
from tensorflow.python.keras.engine import node as node_module
|
||||
from tensorflow.python.keras.engine import training as training_lib
|
||||
from tensorflow.python.keras.engine import training_utils
|
||||
from tensorflow.python.keras.saving.saved_model import network_serialization
|
||||
|
@ -1111,19 +1112,28 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
|
|||
kwargs = {}
|
||||
elif len(input_data) == 4:
|
||||
kwargs = input_data[3]
|
||||
kwargs = _deserialize_keras_tensors(kwargs, created_layers)
|
||||
try:
|
||||
kwargs = _deserialize_keras_tensors(kwargs, created_layers)
|
||||
except IndexError:
|
||||
# Happens if keras tensors in kwargs are still unprocessed
|
||||
add_unprocessed_node(layer, node_data)
|
||||
return
|
||||
else:
|
||||
raise ValueError('Improperly formatted model config.')
|
||||
|
||||
inbound_layer = created_layers[inbound_layer_name]
|
||||
inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
|
||||
if inbound_layer_name != node_module._CONSTANT_VALUE:
|
||||
inbound_layer = created_layers[inbound_layer_name]
|
||||
inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
|
||||
|
||||
if inbound_node_index is None:
|
||||
add_unprocessed_node(layer, node_data)
|
||||
return
|
||||
inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
|
||||
input_tensors.append(
|
||||
nest.flatten(inbound_node.outputs)[inbound_tensor_index])
|
||||
if inbound_node_index is None:
|
||||
add_unprocessed_node(layer, node_data)
|
||||
return
|
||||
inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
|
||||
input_tensors.append(
|
||||
nest.flatten(inbound_node.outputs)[inbound_tensor_index])
|
||||
else:
|
||||
# We received a constant w/ no Keras history attached
|
||||
input_tensors.append(inbound_tensor_index)
|
||||
input_tensors = nest.pack_sequence_as(node_data, input_tensors)
|
||||
# Call layer on its inputs, thus creating the node
|
||||
# and building the layer if needed.
|
||||
|
|
|
@ -964,6 +964,43 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
|
||||
|
||||
class MaybeAdd(layers.Layer):
|
||||
|
||||
def call(self, x1, x2=None):
|
||||
if x2 is not None:
|
||||
return x1 + x2
|
||||
return x1
|
||||
|
||||
input2 = input_layer_lib.Input(10)
|
||||
outputs = MaybeAdd()(3., x2=input2)
|
||||
model = training_lib.Model([input2], outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=7 * np.ones((10, 10)),
|
||||
y=10 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=7 * np.ones((10, 10)),
|
||||
y=10 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_composite_call_kwarg_derived_from_keras_layer(self):
|
||||
|
||||
|
@ -1005,6 +1042,56 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations(mode='eager'))
|
||||
def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
|
||||
# This functionality is unsupported in v1 graphs
|
||||
|
||||
class AddAll(layers.Layer):
|
||||
|
||||
def call(self, x1_x2, x3):
|
||||
x1, x2 = x1_x2
|
||||
out = x1 + x2
|
||||
if x3 is not None:
|
||||
for t in x3.values():
|
||||
out += t
|
||||
return out
|
||||
|
||||
input1 = input_layer_lib.Input(10)
|
||||
input2 = input_layer_lib.Input(10)
|
||||
input3 = input_layer_lib.Input(10)
|
||||
|
||||
outputs = AddAll()(
|
||||
[input1, 4 * array_ops.ones((1, 10))],
|
||||
x3={
|
||||
'a': input2,
|
||||
'b': input3,
|
||||
'c': 5 * array_ops.ones((1, 10))
|
||||
})
|
||||
model = training_lib.Model([input1, input2, input3], outputs)
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
|
||||
y=15 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that all inputs were correctly added.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'AddAll': AddAll})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
|
||||
y=15 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that all inputs were correctly added.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_call_nested_arg_derived_from_keras_layer(self):
|
||||
|
||||
|
|
|
@ -32,6 +32,8 @@ from tensorflow.python.platform import tf_logging as logging
|
|||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
|
||||
_CONSTANT_VALUE = '_CONSTANT_VALUE'
|
||||
|
||||
|
||||
class Node(object):
|
||||
"""A `Node` describes the connectivity between two layers.
|
||||
|
@ -181,11 +183,14 @@ class Node(object):
|
|||
# `kwargs` is added to each Tensor in the first arg. This should be
|
||||
# changed in a future version of the serialization format.
|
||||
def serialize_first_arg_tensor(t):
|
||||
kh = t._keras_history
|
||||
node_index = kh.node_index
|
||||
node_key = make_node_key(kh.layer.name, node_index)
|
||||
new_node_index = node_conversion_map.get(node_key, 0)
|
||||
data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
|
||||
if is_keras_tensor(t):
|
||||
kh = t._keras_history
|
||||
node_index = kh.node_index
|
||||
node_key = make_node_key(kh.layer.name, node_index)
|
||||
new_node_index = node_conversion_map.get(node_key, 0)
|
||||
data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
|
||||
else:
|
||||
data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
|
||||
return tf_utils.ListWrapper(data)
|
||||
|
||||
data = nest.map_structure(serialize_first_arg_tensor, inputs)
|
||||
|
|
Loading…
Reference in New Issue