Generalizes the first argument in keras layers further. Now, functional models get constructed if *any* tensor in the arguments or keyword arguments has a keras history, rather than if *all* of the elements in the first argument to the layer do.

PiperOrigin-RevId: 313718130
Change-Id: I77f65f49decf45f6a2b53ab0519d6d2ac38232d3
This commit is contained in:
Tomer Kaftan 2020-05-28 21:44:18 -07:00 committed by TensorFlower Gardener
parent 618ff4c618
commit a1ae008076
5 changed files with 149 additions and 31 deletions

View File

@ -830,14 +830,13 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
in_call = call_context.in_call
input_list = nest.flatten(inputs)
# We will attempt to build a TF graph if & only if all inputs are symbolic.
# This is always the case in graph mode. It can also be the case in eager
# mode when all inputs can be traced back to `keras.Input()` (when building
# models using the functional API).
# TODO(kaftan): make this not special case inputs. Instead
# build a functional api model if *any* *arg or **kwarg is symbolic,
# even if part of the data structure in that arg is not symbolic.
build_graph = tf_utils.are_all_symbolic_tensors(input_list)
# We will attempt to trace in a graph if & only if inputs are symbolic.
# This is always the case when tracing a function. It can also be the case
# when running eagerly if any input can be traced back to `keras.Input()`
# (when building models using the functional API).
build_graph = tf_utils.are_all_symbolic_tensors(input_list) or (
any(map(tf_utils.is_symbolic_tensor, nest.flatten(
[input_list, args, kwargs]))) and context.executing_eagerly())
# Accept NumPy and scalar inputs by converting to Tensors.
if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
@ -890,11 +889,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
'training', training_value, args, kwargs)
training_arg_passed_by_framework = True
# Only create Keras history if at least one tensor originates from a
# `keras.Input`. Otherwise this Layer may be being used outside the Keras
# framework.
# TODO(kaftan): make this not special case inputs
if build_graph and base_layer_utils.needs_keras_history(inputs):
# Turn inputs into TF op layers if necessary.
# This process is fragile and prone to bad interactions with inputs
# when calling nested layers with tf.functions floating around,
# and with nonsymbolic tensors.
# So, we limit it to the
# case where *all* inputs in the first arg are symbolic.
if (tf_utils.are_all_symbolic_tensors(input_list)
and base_layer_utils.needs_keras_history(inputs)):
base_layer_utils.create_keras_history(inputs)
with call_context.enter(self, inputs, build_graph, training_value):
@ -968,8 +970,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
raise ValueError('A layer\'s `call` method should return a '
'Tensor or a list of Tensors, not None '
'(layer: ' + self.name + ').')
# TODO(kaftan): This should be 'any' and check all args
if base_layer_utils.have_all_keras_metadata(inputs):
# We configure connectivity metadata if all inputs in the first
# arg have keras history, or if we're actively building the
# functional api outside of any outer keras model.
if base_layer_utils.have_all_keras_metadata(inputs) or (
context.executing_eagerly() and
base_layer_utils.have_any_keras_metadata(inputs, args, kwargs)):
if training_arg_passed_by_framework:
args, kwargs = self._set_call_arg_value(
'training', None, args, kwargs, pop_kwarg_if_none=True)

View File

@ -165,6 +165,10 @@ def have_all_keras_metadata(tensors):
return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
def have_any_keras_metadata(*tensors):
return any(hasattr(x, '_keras_history') for x in nest.flatten(tensors))
def generate_placeholders_from_shape(shape):
return array_ops.placeholder(shape=shape, dtype=backend.floatx())
@ -214,7 +218,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
for tensor in tensor_list:
if getattr(tensor, '_keras_history', None) is not None:
continue
try:
op = tensor.op # The Op that created this Tensor.
except AttributeError:
continue
if op not in processed_ops:
if op.type.startswith('Sparse'):
lambda_example = """
@ -392,7 +399,10 @@ def mark_checked(tensors):
"""
def _mark_checked(tensor):
try:
tensor._keras_history_checked = True # pylint: disable=protected-access
except AttributeError:
pass
nest.map_structure(_mark_checked, tensors)

View File

@ -32,6 +32,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
@ -1111,10 +1112,16 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
kwargs = {}
elif len(input_data) == 4:
kwargs = input_data[3]
try:
kwargs = _deserialize_keras_tensors(kwargs, created_layers)
except IndexError:
# Happens if keras tensors in kwargs are still unprocessed
add_unprocessed_node(layer, node_data)
return
else:
raise ValueError('Improperly formatted model config.')
if inbound_layer_name != node_module._CONSTANT_VALUE:
inbound_layer = created_layers[inbound_layer_name]
inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
@ -1124,6 +1131,9 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
input_tensors.append(
nest.flatten(inbound_node.outputs)[inbound_tensor_index])
else:
# We received a constant w/ no Keras history attached
input_tensors.append(inbound_tensor_index)
input_tensors = nest.pack_sequence_as(node_data, input_tensors)
# Call layer on its inputs, thus creating the node
# and building the layer if needed.

View File

@ -964,6 +964,43 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):
class MaybeAdd(layers.Layer):
def call(self, x1, x2=None):
if x2 is not None:
return x1 + x2
return x1
input2 = input_layer_lib.Input(10)
outputs = MaybeAdd()(3., x2=input2)
model = training_lib.Model([input2], outputs)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=7 * np.ones((10, 10)),
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=7 * np.ones((10, 10)),
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_composite_call_kwarg_derived_from_keras_layer(self):
@ -1005,6 +1042,56 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations(mode='eager'))
def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
# This functionality is unsupported in v1 graphs
class AddAll(layers.Layer):
def call(self, x1_x2, x3):
x1, x2 = x1_x2
out = x1 + x2
if x3 is not None:
for t in x3.values():
out += t
return out
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
input3 = input_layer_lib.Input(10)
outputs = AddAll()(
[input1, 4 * array_ops.ones((1, 10))],
x3={
'a': input2,
'b': input3,
'c': 5 * array_ops.ones((1, 10))
})
model = training_lib.Model([input1, input2, input3], outputs)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
y=15 * np.ones((10, 10)),
batch_size=2)
# Check that all inputs were correctly added.
self.assertEqual(history.history['loss'][0], 0.0)
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'AddAll': AddAll})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
y=15 * np.ones((10, 10)),
batch_size=2)
# Check that all inputs were correctly added.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_call_nested_arg_derived_from_keras_layer(self):

View File

@ -32,6 +32,8 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
_CONSTANT_VALUE = '_CONSTANT_VALUE'
class Node(object):
"""A `Node` describes the connectivity between two layers.
@ -181,11 +183,14 @@ class Node(object):
# `kwargs` is added to each Tensor in the first arg. This should be
# changed in a future version of the serialization format.
def serialize_first_arg_tensor(t):
if is_keras_tensor(t):
kh = t._keras_history
node_index = kh.node_index
node_key = make_node_key(kh.layer.name, node_index)
new_node_index = node_conversion_map.get(node_key, 0)
data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
else:
data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
return tf_utils.ListWrapper(data)
data = nest.map_structure(serialize_first_arg_tensor, inputs)