Make Network compatible with eager mode. For now this only makes it possible to instantiate a Network in eager mode using the regular Keras API, and to call it on eager tensors.
PiperOrigin-RevId: 172942569
This commit is contained in:
parent 37fd951790
commit 66b1f43839

tensorflow/python
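In practice, this enables the following kind of usage. The sketch below is distilled from the DeferredModeTest added in this change; the eager_mode() context manager, the layer widths, and the final print are illustrative assumptions, not part of the commit:

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers

with context.eager_mode():
  # In eager mode, Input() now returns a _DeferredTensor, so a graph of
  # layers can be assembled without executing any computation.
  inputs = base_layers.Input(shape=(32,))
  x = core_layers.Dense(2)(inputs)
  outputs = core_layers.Dense(4)(x)
  network = base_layers.Network(inputs, outputs)

  # The resulting Network can then be called directly on EagerTensors.
  data = constant_op.constant(np.random.random((10, 32)).astype('float32'))
  result = network(data)
  print(result.shape)  # expected: (10, 4)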
@@ -776,7 +776,7 @@ class Network(tf_base_layers.Network, Layer):
     if cache_key in self._output_mask_cache:
       return self._output_mask_cache[cache_key]
     else:
-      _, output_masks, _ = self._run_internal_graph(inputs, masks)
+      _, output_masks = self._run_internal_graph(inputs, masks)
       return output_masks
 
   def get_config(self):
@@ -192,10 +192,12 @@ class KerasIntegrationTest(test.TestCase):
     model.compile(loss='categorical_crossentropy',
                   optimizer='rmsprop',
                   metrics=['accuracy'])
+    self.assertEqual(len(model.losses), 2)
+    self.assertEqual(len(model.updates), 2)
     history = model.fit(x_train, y_train, epochs=10, batch_size=16,
                         validation_data=(x_test, y_test),
                         verbose=2)
-    self.assertGreater(history.history['val_acc'][-1], 0.85)
+    self.assertGreater(history.history['val_acc'][-1], 0.84)
 
   def test_vector_classification_shared_model(self):
     # Test that functional models that feature internal updates
@@ -420,6 +420,8 @@ class Sequential(Model):
     # Used by Layer base class.
     self._dtype = None
     self._activity_regularizer = None
+    self._per_input_losses = {}
+    self._per_input_updates = {}
 
     # The following properties are not actually used by Keras;
     # they exist for compatibility with TF's variable scoping mechanism.
@@ -508,6 +508,7 @@ class Layer(object):
     input_list = nest.flatten(inputs)
 
     in_graph_mode = context.in_graph_mode()
+    in_deferred_mode = isinstance(input_list[0], _DeferredTensor)
     # Ensure the Layer, if being reused, is working with inputs from
     # the same graph as where it was created.
     if in_graph_mode:
@@ -515,6 +516,7 @@ class Layer(object):
         ops._get_graph_from_inputs(input_list, graph=self.graph)  # pylint: disable=protected-access
       except ValueError as e:
         raise ValueError('Input graph and Layer graph are not the same: %s' % e)
+    if in_graph_mode or in_deferred_mode:
       user_kwargs = copy.copy(kwargs)
 
       # Handle Keras mask propagation from previous layer to current layer.
@@ -553,6 +555,7 @@ class Layer(object):
         raise ValueError('activity_regularizer currently unsupported in '
                          'Eager mode. Found an activity_regularizer in '
                          '%s(%s).' % (self.__class__.__name__, self))
+    if not in_graph_mode and not in_deferred_mode:
       # TODO(agarwal): support _keras_history in Eager mode.
       for x in input_list:
         if hasattr(x, '_keras_history'):
@@ -581,13 +584,26 @@ class Layer(object):
       if call_has_scope_arg:
         kwargs['scope'] = scope
       # Check input assumptions set after layer building, e.g. input shape.
-      if in_graph_mode:
+      if in_graph_mode or in_deferred_mode:
         self._assert_input_compatibility(inputs)
-      outputs = self.call(inputs, *args, **kwargs)
 
-      if outputs is None:
-        raise ValueError('A layer\'s `call` method should return a Tensor '
-                         'or a list of Tensors, not None.')
+      if not in_deferred_mode:
+        outputs = self.call(inputs, *args, **kwargs)
+        if outputs is None:
+          raise ValueError('A layer\'s `call` method should return a Tensor '
+                           'or a list of Tensors, not None.')
+      else:
+        # Deferred mode behavior: use `_compute_output_shape` to
+        # infer the number of outputs of the layer and their shapes.
+        output_shapes = self._compute_output_shape(input_shapes)
+        output_shapes = nest.flatten(output_shapes)
+        outputs = [
+            # TODO(fchollet): name the deferred tensors?
+            _DeferredTensor(shape=shape, dtype=self._dtype)
+            for shape in output_shapes
+        ]
+        if len(outputs) == 1:
+          outputs = outputs[0]
 
       if in_graph_mode:
         # Apply activity regularization.
@@ -600,16 +616,18 @@ class Layer(object):
             activity_regularization = self._activity_regularizer(output)
             self.add_loss(activity_regularization)
 
-      # Handle mask computation and propagation to the next layer.
-      if hasattr(self, 'compute_mask'):
-        output_mask = self.compute_mask(inputs, previous_mask)
-        if isinstance(outputs, list):
-          if output_mask is None:
-            output_mask = [None for _ in range(len(outputs))]
-          for x, m in zip(outputs, output_mask):
-            x._keras_mask = m  # pylint: disable=protected-access
-        else:
-          outputs._keras_mask = output_mask  # pylint: disable=protected-access
+      if not in_deferred_mode:
+        # TODO(fchollet): consider how masking will work with deferred mode.
+        # Handle mask computation and propagation to the next layer.
+        if hasattr(self, 'compute_mask'):
+          output_mask = self.compute_mask(inputs, previous_mask)
+          if isinstance(outputs, list):
+            if output_mask is None:
+              output_mask = [None for _ in range(len(outputs))]
+            for x, m in zip(outputs, output_mask):
+              x._keras_mask = m  # pylint: disable=protected-access
+          else:
+            outputs._keras_mask = output_mask  # pylint: disable=protected-access
 
       if in_graph_mode:
         # If all input tensors have history metadata,
@@ -631,14 +649,16 @@ class Layer(object):
           else:
             outputs = output_ls_copy
 
-      # Update global default collections.
-      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
-
+    if in_deferred_mode or in_graph_mode:
       if _have_all_keras_metadata(inputs):
         # Add an inbound node to the layer, so it can keep track of this call.
         # This updates the layer history of the output tensor(s).
         self._add_inbound_node(
             input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs)
 
+      # Update global default collections.
+      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
+
     self.built = True
     return outputs
@@ -692,7 +712,6 @@ class Layer(object):
       arguments: dictionary of keyword arguments that were passed to the
         `call` method of the layer at the call that created the node.
     """
-    assert context.in_graph_mode()
     input_tensors = nest.flatten(input_tensors)
     output_tensors = nest.flatten(output_tensors)
 
@@ -1251,6 +1270,34 @@ class Node(object):
     }
 
 
+class _DeferredTensor(object):
+  """Tensor-like object used to build graphs of layers in Eager mode.
+
+  When calling a layer on a DeferredTensor, the layer will not perform any
+  computation and will simply perfom shape inference to return new
+  DeferredTensors with appropriate shape information. Thus DeferredTensor
+  behaves like a graph-mode Tensor when manipulated by layers.
+  """
+
+  def __init__(self, shape, dtype, name=None):
+    self.shape = tensor_shape.TensorShape(shape)
+    self.dtype = dtypes.as_dtype(dtype)
+    self.name = name
+
+  def get_shape(self):
+    return self.shape
+
+  def __str__(self):
+    return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
+                                                         self.get_shape(),
+                                                         self.dtype.name)
+
+  def __repr__(self):
+    return "<_DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
+                                                         self.get_shape(),
+                                                         self.dtype.name)
+
+
 class InputLayer(Layer):
   """Layer to be used as an entry point into a Network (a graph of layers).
 
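For readers skimming the diff, a hedged illustration of the behavior the docstring above describes: applying a layer to a _DeferredTensor only runs shape inference and yields new deferred tensors. The use of context.eager_mode(), core_layers.Dense, and the concrete shapes below are assumptions for the sketch, not taken from the commit:

from tensorflow.python.eager import context
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers

with context.eager_mode():
  # Building a deferred tensor by hand; Input() does this internally in eager mode.
  x = base_layers._DeferredTensor(shape=(None, 32), dtype='float32', name='x')
  y = core_layers.Dense(4)(x)   # no computation runs; only shape inference
  print(y.shape.as_list())      # expected: [None, 4]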
@@ -1283,8 +1330,6 @@ class InputLayer(Layer):
                input_tensor=None,
                sparse=False,
                name=None):
-    if context.in_eager_mode():
-      raise RuntimeError('InputLayer not supported in Eager mode.')
     super(InputLayer, self).__init__(dtype=dtype, name=name)
     self.built = True
     self.sparse = sparse
@@ -1299,16 +1344,24 @@
       else:
         batch_input_shape = None
 
-      if sparse:
-        input_tensor = array_ops.sparse_placeholder(
+      if context.in_eager_mode():
+        # In eager mode, create a temporary placeholder to call the layer on.
+        input_tensor = _DeferredTensor(
             shape=batch_input_shape,
             dtype=dtype,
             name=self.name)
       else:
-        input_tensor = array_ops.placeholder(
-            shape=batch_input_shape,
-            dtype=dtype,
-            name=self.name)
+        # In graph mode, create a graph placeholder to call the layer on.
+        if sparse:
+          input_tensor = array_ops.sparse_placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name)
+        else:
+          input_tensor = array_ops.placeholder(
+              shape=batch_input_shape,
+              dtype=dtype,
+              name=self.name)
 
       # For compatibility with Keras API.
       self.is_placeholder = True
@@ -1375,8 +1428,6 @@ def Input(  # pylint: disable=invalid-name
   Raises:
     RuntimeError: If called in Eager mode.
   """
-  if context.in_eager_mode():
-    raise RuntimeError('Input not supported in Eager mode.')
   input_layer = InputLayer(
       input_shape=shape,
       batch_size=batch_size,
@@ -1440,9 +1491,10 @@ class Network(Layer):
   """
 
   def __init__(self, inputs, outputs, name=None):  # pylint: disable=super-init-not-called
-    # TODO(agarwal): Make Network work in Eager mode.
     if context.in_eager_mode():
-      raise RuntimeError('Network not supported in Eager mode.')
+      # TODO(fchollet): check that all inputs and outputs are DeferredTensors.
+      pass
 
     # Set layer name and scope
     if isinstance(name, vs.VariableScope):
       base_name = name.name
@@ -1919,16 +1971,17 @@
       masks = [None for _ in range(len(inputs))]
     else:
       masks = nest.flatten(mask)
-    # Try to retrieve cached outputs if the layer has already been called
-    # on these exact inputs.
-    cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
-    if cache_key in self._output_tensor_cache:
-      # Cache hit.
-      return self._output_tensor_cache[cache_key]
-    else:
-      # Cache miss: actually apply the network graph to the new inputs.
-      output_tensors, _, _ = self._run_internal_graph(inputs, masks)
-      return output_tensors
+
+    if context.in_graph_mode():
+      # Try to retrieve cached outputs if the layer has already been called
+      # on these exact inputs.
+      cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
+      if cache_key in self._output_tensor_cache:
+        # Cache hit.
+        return self._output_tensor_cache[cache_key]
+    # Actually apply the network graph to the new inputs.
+    outputs, _ = self._run_internal_graph(inputs, masks)
+    return outputs
 
   def _compute_output_shape(self, input_shape):
     if isinstance(input_shape, list):
@@ -2091,6 +2144,7 @@
             if 'mask' in estimator_util.fn_args(layer.call):
               if 'mask' not in kwargs:
                 kwargs['mask'] = computed_mask
+
             output_tensors = nest.flatten(
                 layer.call(computed_tensor, **kwargs))
             if hasattr(layer, 'compute_mask'):
@@ -2121,18 +2175,19 @@
             ]
             layer.add_loss(regularization_losses, computed_tensors)
 
-          # Update model updates and losses:
-          # Keep track of updates that depend on the inputs
-          # (e.g. BN updates).
-          self.add_update(layer.get_updates_for(computed_tensors), inputs)
-          # Keep track of unconditional updates (e.g. a counter).
-          self.add_update(layer.get_updates_for(None), None)
-          # Keep track of losses that depend on the inputs
-          # (e.g. activity regularizers).
-          self.add_loss(layer.get_losses_for(computed_tensors), inputs)
-          # Keep track of unconditional losses
-          # (e.g. weight regularizers).
-          self.add_loss(layer.get_losses_for(None), None)
+          if context.in_graph_mode():
+            # Update model updates and losses:
+            # Keep track of updates that depend on the inputs
+            # (e.g. BN updates).
+            self.add_update(layer.get_updates_for(computed_tensors), inputs)
+            # Keep track of unconditional updates (e.g. a counter).
+            self.add_update(layer.get_updates_for(None), None)
+            # Keep track of losses that depend on the inputs
+            # (e.g. activity regularizers).
+            self.add_loss(layer.get_losses_for(computed_tensors), inputs)
+            # Keep track of unconditional losses
+            # (e.g. weight regularizers).
+            self.add_loss(layer.get_losses_for(None), None)
 
           # Update tensor_map.
           for x, y, mask in zip(reference_output_tensors, output_tensors,
@@ -2149,31 +2204,26 @@
       output_tensors.append(tensor)
       output_masks.append(mask)
 
-    # Update cache;
-    # keys are based on ids on input tensors and inputs masks.
-    cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
-
-    if len(output_tensors) == 1:
-      output_tensors = output_tensors[0]
-      self._output_tensor_cache[cache_key] = output_tensors
-    else:
-      self._output_tensor_cache[cache_key] = output_tensors
-
-    if len(output_masks) == 1:
-      output_masks = output_masks[0]
-      self._output_mask_cache[cache_key] = output_masks
-    else:
-      self._output_mask_cache[cache_key] = output_masks
-
-    if output_shapes is not None:
-      input_shapes = [_static_shape(x) for x in inputs]
-      cache_key = _object_list_uid(input_shapes)
-      if len(output_shapes) == 1:
-        output_shapes = output_shapes[0]
-        self._output_shape_cache[cache_key] = output_shapes
-      else:
-        self._output_shape_cache[cache_key] = output_shapes
-    return output_tensors, output_masks, output_shapes
+    if len(output_tensors) == 1:
+      output_tensors = output_tensors[0]
+      if output_shapes is not None:
+        output_shapes = output_shapes[0]
+      if output_masks is not None:
+        output_masks = output_masks[0]
+
+    if context.in_graph_mode():
+      # Update cache;
+      # keys are based on ids on input tensors and inputs masks.
+      cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks)
+      self._output_tensor_cache[cache_key] = output_tensors
+      if output_masks is not None:
+        self._output_mask_cache[cache_key] = output_masks
+      if output_shapes is not None:
+        input_shapes = [_static_shape(x) for x in inputs]
+        cache_key = _object_list_uid(input_shapes)
+        self._output_shape_cache[cache_key] = output_shapes
+
+    return output_tensors, output_masks
 
 
 def _is_tensor_or_tensor_list(v):
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import copy
 
+import numpy as np
+
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -41,13 +43,13 @@ class BaseLayerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testLayerProperties(self):
     layer = base_layers.Layer(name='my_layer')
-    self.assertListEqual(layer.variables, [])
-    self.assertListEqual(layer.trainable_variables, [])
-    self.assertListEqual(layer.non_trainable_variables, [])
+    self.assertEqual(layer.variables, [])
+    self.assertEqual(layer.trainable_variables, [])
+    self.assertEqual(layer.non_trainable_variables, [])
     if context.in_graph_mode():
       # updates, losses only suppported in GRAPH mode
-      self.assertListEqual(layer.updates, [])
-      self.assertListEqual(layer.losses, [])
+      self.assertEqual(layer.updates, [])
+      self.assertEqual(layer.losses, [])
     self.assertEqual(layer.built, False)
     layer = base_layers.Layer(name='my_layer', trainable=False)
     self.assertEqual(layer.trainable, False)
@@ -60,11 +62,11 @@
     variable = layer.add_variable(
         'my_var', [2, 2], initializer=init_ops.zeros_initializer())
     self.assertEqual(variable.name, 'my_layer/my_var:0')
-    self.assertListEqual(layer.variables, [variable])
-    self.assertListEqual(layer.trainable_variables, [variable])
-    self.assertListEqual(layer.non_trainable_variables, [])
+    self.assertEqual(layer.variables, [variable])
+    self.assertEqual(layer.trainable_variables, [variable])
+    self.assertEqual(layer.non_trainable_variables, [])
     if context.in_graph_mode():
-      self.assertListEqual(
+      self.assertEqual(
           layer.variables,
           ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
@@ -74,9 +76,9 @@
         'non_trainable_var', [2, 2],
         initializer=init_ops.zeros_initializer(),
         trainable=False)
-    self.assertListEqual(layer.variables, [variable, variable_2])
-    self.assertListEqual(layer.trainable_variables, [variable])
-    self.assertListEqual(layer.non_trainable_variables, [variable_2])
+    self.assertEqual(layer.variables, [variable, variable_2])
+    self.assertEqual(layer.trainable_variables, [variable])
+    self.assertEqual(layer.non_trainable_variables, [variable_2])
     if context.in_graph_mode():
       self.assertEqual(
           len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
@@ -105,8 +107,8 @@
       inputs = random_ops.random_uniform((5,), seed=1)
       layer.apply(inputs)
       layer.apply(inputs)
-      self.assertListEqual([v.name for v in layer.variables],
-                           ['my_layer/my_var:0'])
+      self.assertEqual([v.name for v in layer.variables],
+                       ['my_layer/my_var:0'])
 
     # Creating a layer with no scope leads to lazy construction of
     # the scope at apply() time. It uses scope "<current scope>/base_name"
@@ -120,7 +122,7 @@
       # The variables were created outside of the Layer, and
       # reuse=True, so the Layer does not own them and they are not
      # stored in its collection.
-      self.assertListEqual(lazy_layer.variables, [])
+      self.assertEqual(lazy_layer.variables, [])
      self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')
 
    # Creating a layer with no scope leads to lazy construction of
@@ -135,7 +137,7 @@
      # The variables were created outside of the Layer, and
      # reuse=True, so the Layer does not own them and they are not
      # stored in its collection.
-      self.assertListEqual(lazy_layer.variables, [])
+      self.assertEqual(lazy_layer.variables, [])
      self.assertEqual(lazy_layer._scope.name, 'new_scope')
 
    # Checking for graph equality is only done in GRAPH mode.
@@ -183,14 +185,14 @@
       outputs = layer.apply(inputs)
       self.assertEqual(layer.built, True)
       self.assertEqual(outputs.op.name, 'my_layer/add')
-      self.assertListEqual([v.name
-                            for v in layer.variables], ['my_layer/my_var:0'])
+      self.assertEqual([v.name
+                        for v in layer.variables], ['my_layer/my_var:0'])
       with self.assertRaisesRegexp(ValueError,
                                    'my_layer/this_will_break_on_second_call'):
         layer.apply(inputs)
       # The list of variables hasn't changed.
-      self.assertListEqual([v.name
-                            for v in layer.variables], ['my_layer/my_var:0'])
+      self.assertEqual([v.name
+                        for v in layer.variables], ['my_layer/my_var:0'])
 
   @test_util.run_in_graph_and_eager_modes()
   def testDeepCopy(self):
@@ -435,8 +437,8 @@
     dense_layer.add_update(0, inputs=a)
     dense_layer.add_update(1, inputs=None)
 
-    self.assertListEqual(dense_layer.get_updates_for(a), [0])
-    self.assertListEqual(dense_layer.get_updates_for(None), [1])
+    self.assertEqual(dense_layer.get_updates_for(a), [0])
+    self.assertEqual(dense_layer.get_updates_for(None), [1])
 
   def test_get_losses_for(self):
     a = base_layers.Input(shape=(2,))
@@ -444,8 +446,8 @@
     dense_layer.add_loss(0, inputs=a)
     dense_layer.add_loss(1, inputs=None)
 
-    self.assertListEqual(dense_layer.get_losses_for(a), [0])
-    self.assertListEqual(dense_layer.get_losses_for(None), [1])
+    self.assertEqual(dense_layer.get_losses_for(a), [0])
+    self.assertEqual(dense_layer.get_losses_for(None), [1])
 
   def testTopologicalAttributes(self):
     # test layer attributes / methods related to cross-layer connectivity.
@@ -612,7 +614,7 @@ class NetworkTest(test.TestCase):
     a = base_layers.Input(shape=(32,), name='input_a')
     b = base_layers.Input(shape=(32,), name='input_b')
 
-    self.assertListEqual(a.get_shape().as_list(), [None, 32])
+    self.assertEqual(a.get_shape().as_list(), [None, 32])
     a_layer, a_node_index, a_tensor_index = a._keras_history
     b_layer, _, _ = b._keras_history
     self.assertEqual(len(a_layer._inbound_nodes), 1)
@@ -620,11 +622,11 @@
     node = a_layer._inbound_nodes[a_node_index]
     self.assertEqual(node.outbound_layer, a_layer)
 
-    self.assertListEqual(node.inbound_layers, [])
-    self.assertListEqual(node.input_tensors, [a])
-    self.assertListEqual(node.input_shapes, [(None, 32)])
-    self.assertListEqual(node.output_tensors, [a])
-    self.assertListEqual(node.output_shapes, [(None, 32)])
+    self.assertEqual(node.inbound_layers, [])
+    self.assertEqual(node.input_tensors, [a])
+    self.assertEqual(node.input_shapes, [(None, 32)])
+    self.assertEqual(node.output_tensors, [a])
+    self.assertEqual(node.output_shapes, [(None, 32)])
 
     dense = core_layers.Dense(16, name='dense_1')
     dense(a)
@@ -632,12 +634,12 @@
 
     self.assertEqual(len(dense._inbound_nodes), 2)
     self.assertEqual(len(dense._outbound_nodes), 0)
-    self.assertListEqual(dense._inbound_nodes[0].inbound_layers, [a_layer])
+    self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer])
     self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense)
-    self.assertListEqual(dense._inbound_nodes[1].inbound_layers, [b_layer])
+    self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer])
     self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense)
-    self.assertListEqual(dense._inbound_nodes[0].input_tensors, [a])
-    self.assertListEqual(dense._inbound_nodes[1].input_tensors, [b])
+    self.assertEqual(dense._inbound_nodes[0].input_tensors, [a])
+    self.assertEqual(dense._inbound_nodes[1].input_tensors, [b])
 
     # Test config
     config_0 = dense._inbound_nodes[0].get_config()
@@ -889,5 +891,67 @@
     self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
 
 
+class DeferredModeTest(test.TestCase):
+
+  def testDeferredTensorAttributes(self):
+    x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x')
+    self.assertEqual(str(x),
+                     'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)')
+    self.assertEqual(repr(x),
+                     '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>')
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testSimpleNetworkBuilding(self):
+    inputs = base_layers.Input(shape=(32,))
+    if context.in_eager_mode():
+      self.assertIsInstance(inputs, base_layers._DeferredTensor)
+      self.assertEqual(inputs.dtype.name, 'float32')
+      self.assertEqual(inputs.shape.as_list(), [None, 32])
+
+    x = core_layers.Dense(2)(inputs)
+    if context.in_eager_mode():
+      self.assertIsInstance(x, base_layers._DeferredTensor)
+      self.assertEqual(x.dtype.name, 'float32')
+      self.assertEqual(x.shape.as_list(), [None, 2])
+
+    outputs = core_layers.Dense(4)(x)
+    network = base_layers.Network(inputs, outputs)
+    self.assertIsInstance(network, base_layers.Network)
+
+    if context.in_eager_mode():
+      # It should be possible to call such a network on EagerTensors.
+      inputs = constant_op.constant(
+          np.random.random((10, 32)).astype('float32'))
+      outputs = network(inputs)
+      self.assertEqual(outputs.shape.as_list(), [10, 4])
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testMultiIONetworkbuilding(self):
+    input_a = base_layers.Input(shape=(32,))
+    input_b = base_layers.Input(shape=(16,))
+    a = core_layers.Dense(16)(input_a)
+
+    class AddLayer(base_layers.Layer):
+
+      def call(self, inputs):
+        return inputs[0] + inputs[1]
+
+      def _compute_output_shape(self, input_shape):
+        return input_shape[0]
+
+    c = AddLayer()([a, input_b])  # pylint: disable=not-callable
+    c = core_layers.Dense(2)(c)
+
+    network = base_layers.Network([input_a, input_b], [a, c])
+    if context.in_eager_mode():
+      a_val = constant_op.constant(
+          np.random.random((10, 32)).astype('float32'))
+      b_val = constant_op.constant(
+          np.random.random((10, 16)).astype('float32'))
+      outputs = network([a_val, b_val])
+      self.assertEqual(len(outputs), 2)
+      self.assertEqual(outputs[0].shape.as_list(), [10, 16])
+      self.assertEqual(outputs[1].shape.as_list(), [10, 2])
+
+
 if __name__ == '__main__':
   test.main()