Make Network compatible with eager mode. For now, this only allows instantiating a Network in eager mode using the regular Keras API and calling it on eager tensors.
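A minimal usage sketch of what this enables, modeled on DeferredModeTest.testSimpleNetworkBuilding added below; the module import paths and the step of enabling eager execution are assumptions, not part of this change:

# Sketch only; assumes eager execution is already enabled, as in the
# run_in_graph_and_eager_modes test harness used in the tests below.
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers

# In eager mode, Input() now yields a _DeferredTensor instead of raising.
inputs = base_layers.Input(shape=(32,))
x = core_layers.Dense(2)(inputs)        # shape inference only, no computation
outputs = core_layers.Dense(4)(x)

# The resulting graph of layers can then be called on real EagerTensors.
network = base_layers.Network(inputs, outputs)
data = constant_op.constant(np.random.random((10, 32)).astype('float32'))
result = network(data)                  # EagerTensor with shape (10, 4)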

PiperOrigin-RevId: 172942569
Francois Chollet 2017-10-20 15:29:41 -07:00 committed by TensorFlower Gardener
parent 37fd951790
commit 66b1f43839
5 changed files with 228 additions and 110 deletions
tensorflow/python


@@ -776,7 +776,7 @@ class Network(tf_base_layers.Network, Layer):
if cache_key in self._output_mask_cache:
return self._output_mask_cache[cache_key]
else:
_, output_masks, _ = self._run_internal_graph(inputs, masks)
_, output_masks = self._run_internal_graph(inputs, masks)
return output_masks
def get_config(self):


@@ -192,10 +192,12 @@ class KerasIntegrationTest(test.TestCase):
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
self.assertEqual(len(model.losses), 2)
self.assertEqual(len(model.updates), 2)
history = model.fit(x_train, y_train, epochs=10, batch_size=16,
validation_data=(x_test, y_test),
verbose=2)
self.assertGreater(history.history['val_acc'][-1], 0.85)
self.assertGreater(history.history['val_acc'][-1], 0.84)
def test_vector_classification_shared_model(self):
# Test that functional models that feature internal updates


@@ -420,6 +420,8 @@ class Sequential(Model):
# Used by Layer base class.
self._dtype = None
self._activity_regularizer = None
self._per_input_losses = {}
self._per_input_updates = {}
# The following properties are not actually used by Keras;
# they exist for compatibility with TF's variable scoping mechanism.


@@ -508,6 +508,7 @@ class Layer(object):
input_list = nest.flatten(inputs)
in_graph_mode = context.in_graph_mode()
in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
# Ensure the Layer, if being reused, is working with inputs from
# the same graph as where it was created.
if in_graph_mode:
@@ -515,6 +516,7 @@ class Layer(object):
ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access
except ValueError as e:
raise ValueError('Input graph and Layer graph are not the same: %s' % e)
if in_graph_mode or in_deferred_mode:
user_kwargs = copy.copy(kwargs)
# Handle Keras mask propagation from previous layer to current layer.
@@ -553,6 +555,7 @@ class Layer(object):
raise ValueError('activity_regularizer currently unsupported in '
'Eager mode. Found an activity_regularizer in '
'%s(%s).' % (self.__class__.__name__, self))
if not in_graph_mode and not in_deferred_mode:
# TODO(agarwal): support _keras_history in Eager mode.
for x in input_list:
if hasattr(x, '_keras_history'):
@@ -581,13 +584,26 @@ class Layer(object):
if call_has_scope_arg:
kwargs['scope'] = scope
# Check input assumptions set after layer building, e.g. input shape.
if in_graph_mode:
if in_graph_mode or in_deferred_mode:
self._assert_input_compatibility(inputs)
outputs = self.call(inputs, *args, **kwargs)
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None.')
if not in_deferred_mode:
outputs = self.call(inputs, *args, **kwargs)
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None.')
else:
# Deferred mode behavior: use `_compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
output_shapes = self._compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
outputs = [
# TODO(fchollet): name the deferred tensors?
_DeferredTensor(shape=shape, dtype=self._dtype)
for shape in output_shapes
]
if len(outputs) == 1:
outputs = outputs[0]
if in_graph_mode:
# Apply activity regularization.
@@ -600,16 +616,18 @@ class Layer(object):
activity_regularization = self._activity_regularizer(output)
self.add_loss(activity_regularization)
# Handle mask computation and propagation to the next layer.
if hasattr(self, 'compute_mask'):
output_mask = self.compute_mask(inputs, previous_mask)
if isinstance(outputs, list):
if output_mask is None:
output_mask = [None for _ in range(len(outputs))]
for x, m in zip(outputs, output_mask):
x._keras_mask = m # pylint: disable=protected-access
else:
outputs._keras_mask = output_mask # pylint: disable=protected-access
if not in_deferred_mode:
# TODO(fchollet): consider how masking will work with deferred mode.
# Handle mask computation and propagation to the next layer.
if hasattr(self, 'compute_mask'):
output_mask = self.compute_mask(inputs, previous_mask)
if isinstance(outputs, list):
if output_mask is None:
output_mask = [None for _ in range(len(outputs))]
for x, m in zip(outputs, output_mask):
x._keras_mask = m # pylint: disable=protected-access
else:
outputs._keras_mask = output_mask # pylint: disable=protected-access
if in_graph_mode:
# If all input tensors have history metadata,
@@ -631,14 +649,16 @@ class Layer(object):
else:
outputs = output_ls_copy
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
if in_deferred_mode or in_graph_mode:
if _have_all_keras_metadata(inputs):
# Add an inbound node to the layer, so it can keep track of this call.
# This updates the layer history of the output tensor(s).
self._add_inbound_node(
input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs)
# Update global default collections.
_add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
self.built = True
return outputs
@@ -692,7 +712,6 @@ class Layer(object):
arguments: dictionary of keyword arguments that were passed to the
`call` method of the layer at the call that created the node.
"""
assert context.in_graph_mode()
input_tensors = nest.flatten(input_tensors)
output_tensors = nest.flatten(output_tensors)
@@ -1251,6 +1270,34 @@ class Node(object):
}
class _DeferredTensor(object):
"""Tensor-like object used to build graphs of layers in Eager mode.
When calling a layer on a DeferredTensor, the layer will not perform any
computation and will simply perform shape inference to return new
DeferredTensors with appropriate shape information. Thus DeferredTensor
behaves like a graph-mode Tensor when manipulated by layers.
"""
def __init__(self, shape, dtype, name=None):
self.shape = tensor_shape.TensorShape(shape)
self.dtype = dtypes.as_dtype(dtype)
self.name = name
def get_shape(self):
return self.shape
def __str__(self):
return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
self.get_shape(),
self.dtype.name)
def __repr__(self):
return "<_DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
self.get_shape(),
self.dtype.name)
class InputLayer(Layer):
"""Layer to be used as an entry point into a Network (a graph of layers).
@@ -1283,8 +1330,6 @@ class InputLayer(Layer):
input_tensor=None,
sparse=False,
name=None):
if context.in_eager_mode():
raise RuntimeError('InputLayer not supported in Eager mode.')
super(InputLayer, self).__init__(dtype=dtype, name=name)
self.built = True
self.sparse = sparse
@@ -1299,16 +1344,24 @@ class InputLayer(Layer):
else:
batch_input_shape = None
if sparse:
input_tensor = array_ops.sparse_placeholder(
if context.in_eager_mode():
# In eager mode, create a temporary placeholder to call the layer on.
input_tensor = _DeferredTensor(
shape=batch_input_shape,
dtype=dtype,
name=self.name)
else:
input_tensor = array_ops.placeholder(
shape=batch_input_shape,
dtype=dtype,
name=self.name)
# In graph mode, create a graph placeholder to call the layer on.
if sparse:
input_tensor = array_ops.sparse_placeholder(
shape=batch_input_shape,
dtype=dtype,
name=self.name)
else:
input_tensor = array_ops.placeholder(
shape=batch_input_shape,
dtype=dtype,
name=self.name)
# For compatibility with Keras API.
self.is_placeholder = True
@@ -1375,8 +1428,6 @@ def Input( # pylint: disable=invalid-name
Raises:
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
raise RuntimeError('Input not supported in Eager mode.')
input_layer = InputLayer(
input_shape=shape,
batch_size=batch_size,
@@ -1440,9 +1491,10 @@ class Network(Layer):
"""
def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called
# TODO(agarwal): Make Network work in Eager mode.
if context.in_eager_mode():
raise RuntimeError('Network not supported in Eager mode.')
# TODO(fchollet): check that all inputs and outputs are DeferredTensors.
pass
# Set layer name and scope
if isinstance(name, vs.VariableScope):
base_name = name.name
@@ -1919,16 +1971,17 @@ class Network(Layer):
masks = [None for _ in range(len(inputs))]
else:
masks = nest.flatten(mask)
# Try to retrieve cached outputs if the layer has already been called
# on these exact inputs.
cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
if cache_key in self._output_tensor_cache:
# Cache hit.
return self._output_tensor_cache[cache_key]
else:
# Cache miss: actually apply the network graph to the new inputs.
output_tensors, _, _ = self._run_internal_graph(inputs, masks)
return output_tensors
if context.in_graph_mode():
# Try to retrieve cached outputs if the layer has already been called
# on these exact inputs.
cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
if cache_key in self._output_tensor_cache:
# Cache hit.
return self._output_tensor_cache[cache_key]
# Actually apply the network graph to the new inputs.
outputs, _ = self._run_internal_graph(inputs, masks)
return outputs
def _compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
@@ -2091,6 +2144,7 @@ class Network(Layer):
if 'mask' in estimator_util.fn_args(layer.call):
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
output_tensors = nest.flatten(
layer.call(computed_tensor, **kwargs))
if hasattr(layer, 'compute_mask'):
@@ -2121,18 +2175,19 @@ class Network(Layer):
]
layer.add_loss(regularization_losses, computed_tensors)
# Update model updates and losses:
# Keep track of updates that depend on the inputs
# (e.g. BN updates).
self.add_update(layer.get_updates_for(computed_tensors), inputs)
# Keep track of unconditional updates (e.g. a counter).
self.add_update(layer.get_updates_for(None), None)
# Keep track of losses that depend on the inputs
# (e.g. activity regularizers).
self.add_loss(layer.get_losses_for(computed_tensors), inputs)
# Keep track of unconditional losses
# (e.g. weight regularizers).
self.add_loss(layer.get_losses_for(None), None)
if context.in_graph_mode():
# Update model updates and losses:
# Keep track of updates that depend on the inputs
# (e.g. BN updates).
self.add_update(layer.get_updates_for(computed_tensors), inputs)
# Keep track of unconditional updates (e.g. a counter).
self.add_update(layer.get_updates_for(None), None)
# Keep track of losses that depend on the inputs
# (e.g. activity regularizers).
self.add_loss(layer.get_losses_for(computed_tensors), inputs)
# Keep track of unconditional losses
# (e.g. weight regularizers).
self.add_loss(layer.get_losses_for(None), None)
# Update tensor_map.
for x, y, mask in zip(reference_output_tensors, output_tensors,
@@ -2149,31 +2204,26 @@ class Network(Layer):
output_tensors.append(tensor)
output_masks.append(mask)
# Update cache;
# keys are based on ids on input tensors and inputs masks.
cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
if len(output_tensors) == 1:
output_tensors = output_tensors[0]
self._output_tensor_cache[cache_key] = output_tensors
else:
self._output_tensor_cache[cache_key] = output_tensors
if len(output_masks) == 1:
output_masks = output_masks[0]
self._output_mask_cache[cache_key] = output_masks
else:
self._output_mask_cache[cache_key] = output_masks
if output_shapes is not None:
input_shapes = [_static_shape(x) for x in inputs]
cache_key = _object_list_uid(input_shapes)
if len(output_shapes) == 1:
if output_shapes is not None:
output_shapes = output_shapes[0]
if output_masks is not None:
output_masks = output_masks[0]
if context.in_graph_mode():
# Update cache;
# keys are based on ids on input tensors and inputs masks.
cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
self._output_tensor_cache[cache_key] = output_tensors
if output_masks is not None:
self._output_mask_cache[cache_key] = output_masks
if output_shapes is not None:
input_shapes = [_static_shape(x) for x in inputs]
cache_key = _object_list_uid(input_shapes)
self._output_shape_cache[cache_key] = output_shapes
else:
self._output_shape_cache[cache_key] = output_shapes
return output_tensors, output_masks, output_shapes
return output_tensors, output_masks
def _is_tensor_or_tensor_list(v):

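For a custom layer to participate in this deferred build path, it has to implement _compute_output_shape, which Layer.__call__ now consults in deferred mode to produce _DeferredTensor outputs without running call. A hedged sketch following the AddLayer pattern from the test changes below; the class name and docstring are illustrative:

# Sketch: a layer that can be called on _DeferredTensors in eager mode.
from tensorflow.python.layers import base as base_layers


class AddLayer(base_layers.Layer):
  """Adds two inputs; works in deferred mode because it declares its shape."""

  def call(self, inputs):
    # Only executed once the enclosing Network is called on real tensors.
    return inputs[0] + inputs[1]

  def _compute_output_shape(self, input_shape):
    # Consulted in deferred mode to build _DeferredTensor outputs.
    return input_shape[0]


# A _DeferredTensor itself carries only shape/dtype metadata, e.g.:
#   x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x')
#   str(x)  # => "DeferredTensor('x', shape=(?, 2), dtype=float32)"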

@@ -20,6 +20,8 @@ from __future__ import print_function
import copy
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -41,13 +43,13 @@ class BaseLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testLayerProperties(self):
layer = base_layers.Layer(name='my_layer')
self.assertListEqual(layer.variables, [])
self.assertListEqual(layer.trainable_variables, [])
self.assertListEqual(layer.non_trainable_variables, [])
self.assertEqual(layer.variables, [])
self.assertEqual(layer.trainable_variables, [])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
# updates, losses only supported in GRAPH mode
self.assertListEqual(layer.updates, [])
self.assertListEqual(layer.losses, [])
self.assertEqual(layer.updates, [])
self.assertEqual(layer.losses, [])
self.assertEqual(layer.built, False)
layer = base_layers.Layer(name='my_layer', trainable=False)
self.assertEqual(layer.trainable, False)
@@ -60,11 +62,11 @@ class BaseLayerTest(test.TestCase):
variable = layer.add_variable(
'my_var', [2, 2], initializer=init_ops.zeros_initializer())
self.assertEqual(variable.name, 'my_layer/my_var:0')
self.assertListEqual(layer.variables, [variable])
self.assertListEqual(layer.trainable_variables, [variable])
self.assertListEqual(layer.non_trainable_variables, [])
self.assertEqual(layer.variables, [variable])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [])
if context.in_graph_mode():
self.assertListEqual(
self.assertEqual(
layer.variables,
ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
@@ -74,9 +76,9 @@ class BaseLayerTest(test.TestCase):
'non_trainable_var', [2, 2],
initializer=init_ops.zeros_initializer(),
trainable=False)
self.assertListEqual(layer.variables, [variable, variable_2])
self.assertListEqual(layer.trainable_variables, [variable])
self.assertListEqual(layer.non_trainable_variables, [variable_2])
self.assertEqual(layer.variables, [variable, variable_2])
self.assertEqual(layer.trainable_variables, [variable])
self.assertEqual(layer.non_trainable_variables, [variable_2])
if context.in_graph_mode():
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
@@ -105,8 +107,8 @@ class BaseLayerTest(test.TestCase):
inputs = random_ops.random_uniform((5,), seed=1)
layer.apply(inputs)
layer.apply(inputs)
self.assertListEqual([v.name for v in layer.variables],
['my_layer/my_var:0'])
self.assertEqual([v.name for v in layer.variables],
['my_layer/my_var:0'])
# Creating a layer with no scope leads to lazy construction of
# the scope at apply() time. It uses scope "<current scope>/base_name"
@@ -120,7 +122,7 @@ class BaseLayerTest(test.TestCase):
# The variables were created outside of the Layer, and
# reuse=True, so the Layer does not own them and they are not
# stored in its collection.
self.assertListEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')
# Creating a layer with no scope leads to lazy construction of
@@ -135,7 +137,7 @@ class BaseLayerTest(test.TestCase):
# The variables were created outside of the Layer, and
# reuse=True, so the Layer does not own them and they are not
# stored in its collection.
self.assertListEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer.variables, [])
self.assertEqual(lazy_layer._scope.name, 'new_scope')
# Checking for graph equality is only done in GRAPH mode.
@@ -183,14 +185,14 @@ class BaseLayerTest(test.TestCase):
outputs = layer.apply(inputs)
self.assertEqual(layer.built, True)
self.assertEqual(outputs.op.name, 'my_layer/add')
self.assertListEqual([v.name
for v in layer.variables], ['my_layer/my_var:0'])
self.assertEqual([v.name
for v in layer.variables], ['my_layer/my_var:0'])
with self.assertRaisesRegexp(ValueError,
'my_layer/this_will_break_on_second_call'):
layer.apply(inputs)
# The list of variables hasn't changed.
self.assertListEqual([v.name
for v in layer.variables], ['my_layer/my_var:0'])
self.assertEqual([v.name
for v in layer.variables], ['my_layer/my_var:0'])
@test_util.run_in_graph_and_eager_modes()
def testDeepCopy(self):
@@ -435,8 +437,8 @@ class BaseLayerTest(test.TestCase):
dense_layer.add_update(0, inputs=a)
dense_layer.add_update(1, inputs=None)
self.assertListEqual(dense_layer.get_updates_for(a), [0])
self.assertListEqual(dense_layer.get_updates_for(None), [1])
self.assertEqual(dense_layer.get_updates_for(a), [0])
self.assertEqual(dense_layer.get_updates_for(None), [1])
def test_get_losses_for(self):
a = base_layers.Input(shape=(2,))
@@ -444,8 +446,8 @@ class BaseLayerTest(test.TestCase):
dense_layer.add_loss(0, inputs=a)
dense_layer.add_loss(1, inputs=None)
self.assertListEqual(dense_layer.get_losses_for(a), [0])
self.assertListEqual(dense_layer.get_losses_for(None), [1])
self.assertEqual(dense_layer.get_losses_for(a), [0])
self.assertEqual(dense_layer.get_losses_for(None), [1])
def testTopologicalAttributes(self):
# test layer attributes / methods related to cross-layer connectivity.
@@ -612,7 +614,7 @@ class NetworkTest(test.TestCase):
a = base_layers.Input(shape=(32,), name='input_a')
b = base_layers.Input(shape=(32,), name='input_b')
self.assertListEqual(a.get_shape().as_list(), [None, 32])
self.assertEqual(a.get_shape().as_list(), [None, 32])
a_layer, a_node_index, a_tensor_index = a._keras_history
b_layer, _, _ = b._keras_history
self.assertEqual(len(a_layer._inbound_nodes), 1)
@@ -620,11 +622,11 @@ class NetworkTest(test.TestCase):
node = a_layer._inbound_nodes[a_node_index]
self.assertEqual(node.outbound_layer, a_layer)
self.assertListEqual(node.inbound_layers, [])
self.assertListEqual(node.input_tensors, [a])
self.assertListEqual(node.input_shapes, [(None, 32)])
self.assertListEqual(node.output_tensors, [a])
self.assertListEqual(node.output_shapes, [(None, 32)])
self.assertEqual(node.inbound_layers, [])
self.assertEqual(node.input_tensors, [a])
self.assertEqual(node.input_shapes, [(None, 32)])
self.assertEqual(node.output_tensors, [a])
self.assertEqual(node.output_shapes, [(None, 32)])
dense = core_layers.Dense(16, name='dense_1')
dense(a)
@@ -632,12 +634,12 @@ class NetworkTest(test.TestCase):
self.assertEqual(len(dense._inbound_nodes), 2)
self.assertEqual(len(dense._outbound_nodes), 0)
self.assertListEqual(dense._inbound_nodes[0].inbound_layers, [a_layer])
self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer])
self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense)
self.assertListEqual(dense._inbound_nodes[1].inbound_layers, [b_layer])
self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer])
self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense)
self.assertListEqual(dense._inbound_nodes[0].input_tensors, [a])
self.assertListEqual(dense._inbound_nodes[1].input_tensors, [b])
self.assertEqual(dense._inbound_nodes[0].input_tensors, [a])
self.assertEqual(dense._inbound_nodes[1].input_tensors, [b])
# Test config
config_0 = dense._inbound_nodes[0].get_config()
@@ -889,5 +891,67 @@ class NetworkTest(test.TestCase):
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
class DeferredModeTest(test.TestCase):
def testDeferredTensorAttributes(self):
x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x')
self.assertEqual(str(x),
'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)')
self.assertEqual(repr(x),
'<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>')
@test_util.run_in_graph_and_eager_modes()
def testSimpleNetworkBuilding(self):
inputs = base_layers.Input(shape=(32,))
if context.in_eager_mode():
self.assertIsInstance(inputs, base_layers._DeferredTensor)
self.assertEqual(inputs.dtype.name, 'float32')
self.assertEqual(inputs.shape.as_list(), [None, 32])
x = core_layers.Dense(2)(inputs)
if context.in_eager_mode():
self.assertIsInstance(x, base_layers._DeferredTensor)
self.assertEqual(x.dtype.name, 'float32')
self.assertEqual(x.shape.as_list(), [None, 2])
outputs = core_layers.Dense(4)(x)
network = base_layers.Network(inputs, outputs)
self.assertIsInstance(network, base_layers.Network)
if context.in_eager_mode():
# It should be possible to call such a network on EagerTensors.
inputs = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
outputs = network(inputs)
self.assertEqual(outputs.shape.as_list(), [10, 4])
@test_util.run_in_graph_and_eager_modes()
def testMultiIONetworkbuilding(self):
input_a = base_layers.Input(shape=(32,))
input_b = base_layers.Input(shape=(16,))
a = core_layers.Dense(16)(input_a)
class AddLayer(base_layers.Layer):
def call(self, inputs):
return inputs[0] + inputs[1]
def _compute_output_shape(self, input_shape):
return input_shape[0]
c = AddLayer()([a, input_b]) # pylint: disable=not-callable
c = core_layers.Dense(2)(c)
network = base_layers.Network([input_a, input_b], [a, c])
if context.in_eager_mode():
a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
b_val = constant_op.constant(
np.random.random((10, 16)).astype('float32'))
outputs = network([a_val, b_val])
self.assertEqual(len(outputs), 2)
self.assertEqual(outputs[0].shape.as_list(), [10, 16])
self.assertEqual(outputs[1].shape.as_list(), [10, 2])
if __name__ == '__main__':
test.main()