From 66b1f43839ccbfe7e44df004fb92d505ab6ed942 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Fri, 20 Oct 2017 15:29:41 -0700 Subject: [PATCH] Make Network compatible with eager mode. Currently it only allows to instantiate a Network in eager mode using the regular Keras API, and call it on eager tensors. PiperOrigin-RevId: 172942569 --- .../keras/_impl/keras/engine/topology.py | 2 +- .../keras/_impl/keras/integration_test.py | 4 +- tensorflow/python/keras/_impl/keras/models.py | 2 + tensorflow/python/layers/base.py | 198 +++++++++++------- tensorflow/python/layers/base_test.py | 132 +++++++++--- 5 files changed, 228 insertions(+), 110 deletions(-) diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py index d9454ee8d18..c0be023b360 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology.py @@ -776,7 +776,7 @@ class Network(tf_base_layers.Network, Layer): if cache_key in self._output_mask_cache: return self._output_mask_cache[cache_key] else: - _, output_masks, _ = self._run_internal_graph(inputs, masks) + _, output_masks = self._run_internal_graph(inputs, masks) return output_masks def get_config(self): diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py index d7d20e5698a..71100368480 100644 --- a/tensorflow/python/keras/_impl/keras/integration_test.py +++ b/tensorflow/python/keras/_impl/keras/integration_test.py @@ -192,10 +192,12 @@ class KerasIntegrationTest(test.TestCase): model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) + self.assertEqual(len(model.losses), 2) + self.assertEqual(len(model.updates), 2) history = model.fit(x_train, y_train, epochs=10, batch_size=16, validation_data=(x_test, y_test), verbose=2) - self.assertGreater(history.history['val_acc'][-1], 0.85) + self.assertGreater(history.history['val_acc'][-1], 0.84) def test_vector_classification_shared_model(self): # Test that functional models that feature internal updates diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py index 6e55c429e95..06941e4bac0 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/_impl/keras/models.py @@ -420,6 +420,8 @@ class Sequential(Model): # Used by Layer base class. self._dtype = None self._activity_regularizer = None + self._per_input_losses = {} + self._per_input_updates = {} # The following properties are not actually used by Keras; # they exist for compatibility with TF's variable scoping mechanism. diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 99a30657ef7..91e18b2ba59 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -508,6 +508,7 @@ class Layer(object): input_list = nest.flatten(inputs) in_graph_mode = context.in_graph_mode() + in_deferred_mode = isinstance(input_list[0], _DeferredTensor) # Ensure the Layer, if being reused, is working with inputs from # the same graph as where it was created. if in_graph_mode: @@ -515,6 +516,7 @@ class Layer(object): ops._get_graph_from_inputs(input_list, graph=self.graph) # pylint: disable=protected-access except ValueError as e: raise ValueError('Input graph and Layer graph are not the same: %s' % e) + if in_graph_mode or in_deferred_mode: user_kwargs = copy.copy(kwargs) # Handle Keras mask propagation from previous layer to current layer. @@ -553,6 +555,7 @@ class Layer(object): raise ValueError('activity_regularizer currently unsupported in ' 'Eager mode. Found an activity_regularizer in ' '%s(%s).' % (self.__class__.__name__, self)) + if not in_graph_mode and not in_deferred_mode: # TODO(agarwal): support _keras_history in Eager mode. for x in input_list: if hasattr(x, '_keras_history'): @@ -581,13 +584,26 @@ class Layer(object): if call_has_scope_arg: kwargs['scope'] = scope # Check input assumptions set after layer building, e.g. input shape. - if in_graph_mode: + if in_graph_mode or in_deferred_mode: self._assert_input_compatibility(inputs) - outputs = self.call(inputs, *args, **kwargs) - if outputs is None: - raise ValueError('A layer\'s `call` method should return a Tensor ' - 'or a list of Tensors, not None.') + if not in_deferred_mode: + outputs = self.call(inputs, *args, **kwargs) + if outputs is None: + raise ValueError('A layer\'s `call` method should return a Tensor ' + 'or a list of Tensors, not None.') + else: + # Deferred mode behavior: use `_compute_output_shape` to + # infer the number of outputs of the layer and their shapes. + output_shapes = self._compute_output_shape(input_shapes) + output_shapes = nest.flatten(output_shapes) + outputs = [ + # TODO(fchollet): name the deferred tensors? + _DeferredTensor(shape=shape, dtype=self._dtype) + for shape in output_shapes + ] + if len(outputs) == 1: + outputs = outputs[0] if in_graph_mode: # Apply activity regularization. @@ -600,16 +616,18 @@ class Layer(object): activity_regularization = self._activity_regularizer(output) self.add_loss(activity_regularization) - # Handle mask computation and propagation to the next layer. - if hasattr(self, 'compute_mask'): - output_mask = self.compute_mask(inputs, previous_mask) - if isinstance(outputs, list): - if output_mask is None: - output_mask = [None for _ in range(len(outputs))] - for x, m in zip(outputs, output_mask): - x._keras_mask = m # pylint: disable=protected-access - else: - outputs._keras_mask = output_mask # pylint: disable=protected-access + if not in_deferred_mode: + # TODO(fchollet): consider how masking will work with deferred mode. + # Handle mask computation and propagation to the next layer. + if hasattr(self, 'compute_mask'): + output_mask = self.compute_mask(inputs, previous_mask) + if isinstance(outputs, list): + if output_mask is None: + output_mask = [None for _ in range(len(outputs))] + for x, m in zip(outputs, output_mask): + x._keras_mask = m # pylint: disable=protected-access + else: + outputs._keras_mask = output_mask # pylint: disable=protected-access if in_graph_mode: # If all input tensors have history metadata, @@ -631,14 +649,16 @@ class Layer(object): else: outputs = output_ls_copy + # Update global default collections. + _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) + + if in_deferred_mode or in_graph_mode: + if _have_all_keras_metadata(inputs): # Add an inbound node to the layer, so it can keep track of this call. # This updates the layer history of the output tensor(s). self._add_inbound_node( input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs) - # Update global default collections. - _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS) - self.built = True return outputs @@ -692,7 +712,6 @@ class Layer(object): arguments: dictionary of keyword arguments that were passed to the `call` method of the layer at the call that created the node. """ - assert context.in_graph_mode() input_tensors = nest.flatten(input_tensors) output_tensors = nest.flatten(output_tensors) @@ -1251,6 +1270,34 @@ class Node(object): } +class _DeferredTensor(object): + """Tensor-like object used to build graphs of layers in Eager mode. + + When calling a layer on a DeferredTensor, the layer will not perform any + computation and will simply perfom shape inference to return new + DeferredTensors with appropriate shape information. Thus DeferredTensor + behaves like a graph-mode Tensor when manipulated by layers. + """ + + def __init__(self, shape, dtype, name=None): + self.shape = tensor_shape.TensorShape(shape) + self.dtype = dtypes.as_dtype(dtype) + self.name = name + + def get_shape(self): + return self.shape + + def __str__(self): + return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name, + self.get_shape(), + self.dtype.name) + + def __repr__(self): + return "<_DeferredTensor '%s' shape=%s dtype=%s>" % (self.name, + self.get_shape(), + self.dtype.name) + + class InputLayer(Layer): """Layer to be used as an entry point into a Network (a graph of layers). @@ -1283,8 +1330,6 @@ class InputLayer(Layer): input_tensor=None, sparse=False, name=None): - if context.in_eager_mode(): - raise RuntimeError('InputLayer not supported in Eager mode.') super(InputLayer, self).__init__(dtype=dtype, name=name) self.built = True self.sparse = sparse @@ -1299,16 +1344,24 @@ class InputLayer(Layer): else: batch_input_shape = None - if sparse: - input_tensor = array_ops.sparse_placeholder( + if context.in_eager_mode(): + # In eager mode, create a temporary placeholder to call the layer on. + input_tensor = _DeferredTensor( shape=batch_input_shape, dtype=dtype, name=self.name) else: - input_tensor = array_ops.placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) + # In graph mode, create a graph placeholder to call the layer on. + if sparse: + input_tensor = array_ops.sparse_placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + input_tensor = array_ops.placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) # For compatibility with Keras API. self.is_placeholder = True @@ -1375,8 +1428,6 @@ def Input( # pylint: disable=invalid-name Raises: RuntimeError: If called in Eager mode. """ - if context.in_eager_mode(): - raise RuntimeError('Input not supported in Eager mode.') input_layer = InputLayer( input_shape=shape, batch_size=batch_size, @@ -1440,9 +1491,10 @@ class Network(Layer): """ def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called - # TODO(agarwal): Make Network work in Eager mode. if context.in_eager_mode(): - raise RuntimeError('Network not supported in Eager mode.') + # TODO(fchollet): check that all inputs and outputs are DeferredTensors. + pass + # Set layer name and scope if isinstance(name, vs.VariableScope): base_name = name.name @@ -1919,16 +1971,17 @@ class Network(Layer): masks = [None for _ in range(len(inputs))] else: masks = nest.flatten(mask) - # Try to retrieve cached outputs if the layer has already been called - # on these exact inputs. - cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) - if cache_key in self._output_tensor_cache: - # Cache hit. - return self._output_tensor_cache[cache_key] - else: - # Cache miss: actually apply the network graph to the new inputs. - output_tensors, _, _ = self._run_internal_graph(inputs, masks) - return output_tensors + + if context.in_graph_mode(): + # Try to retrieve cached outputs if the layer has already been called + # on these exact inputs. + cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) + if cache_key in self._output_tensor_cache: + # Cache hit. + return self._output_tensor_cache[cache_key] + # Actually apply the network graph to the new inputs. + outputs, _ = self._run_internal_graph(inputs, masks) + return outputs def _compute_output_shape(self, input_shape): if isinstance(input_shape, list): @@ -2091,6 +2144,7 @@ class Network(Layer): if 'mask' in estimator_util.fn_args(layer.call): if 'mask' not in kwargs: kwargs['mask'] = computed_mask + output_tensors = nest.flatten( layer.call(computed_tensor, **kwargs)) if hasattr(layer, 'compute_mask'): @@ -2121,18 +2175,19 @@ class Network(Layer): ] layer.add_loss(regularization_losses, computed_tensors) - # Update model updates and losses: - # Keep track of updates that depend on the inputs - # (e.g. BN updates). - self.add_update(layer.get_updates_for(computed_tensors), inputs) - # Keep track of unconditional updates (e.g. a counter). - self.add_update(layer.get_updates_for(None), None) - # Keep track of losses that depend on the inputs - # (e.g. activity regularizers). - self.add_loss(layer.get_losses_for(computed_tensors), inputs) - # Keep track of unconditional losses - # (e.g. weight regularizers). - self.add_loss(layer.get_losses_for(None), None) + if context.in_graph_mode(): + # Update model updates and losses: + # Keep track of updates that depend on the inputs + # (e.g. BN updates). + self.add_update(layer.get_updates_for(computed_tensors), inputs) + # Keep track of unconditional updates (e.g. a counter). + self.add_update(layer.get_updates_for(None), None) + # Keep track of losses that depend on the inputs + # (e.g. activity regularizers). + self.add_loss(layer.get_losses_for(computed_tensors), inputs) + # Keep track of unconditional losses + # (e.g. weight regularizers). + self.add_loss(layer.get_losses_for(None), None) # Update tensor_map. for x, y, mask in zip(reference_output_tensors, output_tensors, @@ -2149,31 +2204,26 @@ class Network(Layer): output_tensors.append(tensor) output_masks.append(mask) - # Update cache; - # keys are based on ids on input tensors and inputs masks. - cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) - if len(output_tensors) == 1: output_tensors = output_tensors[0] - self._output_tensor_cache[cache_key] = output_tensors - else: - self._output_tensor_cache[cache_key] = output_tensors - - if len(output_masks) == 1: - output_masks = output_masks[0] - self._output_mask_cache[cache_key] = output_masks - else: - self._output_mask_cache[cache_key] = output_masks - - if output_shapes is not None: - input_shapes = [_static_shape(x) for x in inputs] - cache_key = _object_list_uid(input_shapes) - if len(output_shapes) == 1: + if output_shapes is not None: output_shapes = output_shapes[0] + if output_masks is not None: + output_masks = output_masks[0] + + if context.in_graph_mode(): + # Update cache; + # keys are based on ids on input tensors and inputs masks. + cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) + self._output_tensor_cache[cache_key] = output_tensors + if output_masks is not None: + self._output_mask_cache[cache_key] = output_masks + if output_shapes is not None: + input_shapes = [_static_shape(x) for x in inputs] + cache_key = _object_list_uid(input_shapes) self._output_shape_cache[cache_key] = output_shapes - else: - self._output_shape_cache[cache_key] = output_shapes - return output_tensors, output_masks, output_shapes + + return output_tensors, output_masks def _is_tensor_or_tensor_list(v): diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 813a2fe755d..71eff2f9657 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import copy +import numpy as np + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -41,13 +43,13 @@ class BaseLayerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testLayerProperties(self): layer = base_layers.Layer(name='my_layer') - self.assertListEqual(layer.variables, []) - self.assertListEqual(layer.trainable_variables, []) - self.assertListEqual(layer.non_trainable_variables, []) + self.assertEqual(layer.variables, []) + self.assertEqual(layer.trainable_variables, []) + self.assertEqual(layer.non_trainable_variables, []) if context.in_graph_mode(): # updates, losses only suppported in GRAPH mode - self.assertListEqual(layer.updates, []) - self.assertListEqual(layer.losses, []) + self.assertEqual(layer.updates, []) + self.assertEqual(layer.losses, []) self.assertEqual(layer.built, False) layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) @@ -60,11 +62,11 @@ class BaseLayerTest(test.TestCase): variable = layer.add_variable( 'my_var', [2, 2], initializer=init_ops.zeros_initializer()) self.assertEqual(variable.name, 'my_layer/my_var:0') - self.assertListEqual(layer.variables, [variable]) - self.assertListEqual(layer.trainable_variables, [variable]) - self.assertListEqual(layer.non_trainable_variables, []) + self.assertEqual(layer.variables, [variable]) + self.assertEqual(layer.trainable_variables, [variable]) + self.assertEqual(layer.non_trainable_variables, []) if context.in_graph_mode(): - self.assertListEqual( + self.assertEqual( layer.variables, ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) @@ -74,9 +76,9 @@ class BaseLayerTest(test.TestCase): 'non_trainable_var', [2, 2], initializer=init_ops.zeros_initializer(), trainable=False) - self.assertListEqual(layer.variables, [variable, variable_2]) - self.assertListEqual(layer.trainable_variables, [variable]) - self.assertListEqual(layer.non_trainable_variables, [variable_2]) + self.assertEqual(layer.variables, [variable, variable_2]) + self.assertEqual(layer.trainable_variables, [variable]) + self.assertEqual(layer.non_trainable_variables, [variable_2]) if context.in_graph_mode(): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1) @@ -105,8 +107,8 @@ class BaseLayerTest(test.TestCase): inputs = random_ops.random_uniform((5,), seed=1) layer.apply(inputs) layer.apply(inputs) - self.assertListEqual([v.name for v in layer.variables], - ['my_layer/my_var:0']) + self.assertEqual([v.name for v in layer.variables], + ['my_layer/my_var:0']) # Creating a layer with no scope leads to lazy construction of # the scope at apply() time. It uses scope "/base_name" @@ -120,7 +122,7 @@ class BaseLayerTest(test.TestCase): # The variables were created outside of the Layer, and # reuse=True, so the Layer does not own them and they are not # stored in its collection. - self.assertListEqual(lazy_layer.variables, []) + self.assertEqual(lazy_layer.variables, []) self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer') # Creating a layer with no scope leads to lazy construction of @@ -135,7 +137,7 @@ class BaseLayerTest(test.TestCase): # The variables were created outside of the Layer, and # reuse=True, so the Layer does not own them and they are not # stored in its collection. - self.assertListEqual(lazy_layer.variables, []) + self.assertEqual(lazy_layer.variables, []) self.assertEqual(lazy_layer._scope.name, 'new_scope') # Checking for graph equality is only done in GRAPH mode. @@ -183,14 +185,14 @@ class BaseLayerTest(test.TestCase): outputs = layer.apply(inputs) self.assertEqual(layer.built, True) self.assertEqual(outputs.op.name, 'my_layer/add') - self.assertListEqual([v.name - for v in layer.variables], ['my_layer/my_var:0']) + self.assertEqual([v.name + for v in layer.variables], ['my_layer/my_var:0']) with self.assertRaisesRegexp(ValueError, 'my_layer/this_will_break_on_second_call'): layer.apply(inputs) # The list of variables hasn't changed. - self.assertListEqual([v.name - for v in layer.variables], ['my_layer/my_var:0']) + self.assertEqual([v.name + for v in layer.variables], ['my_layer/my_var:0']) @test_util.run_in_graph_and_eager_modes() def testDeepCopy(self): @@ -435,8 +437,8 @@ class BaseLayerTest(test.TestCase): dense_layer.add_update(0, inputs=a) dense_layer.add_update(1, inputs=None) - self.assertListEqual(dense_layer.get_updates_for(a), [0]) - self.assertListEqual(dense_layer.get_updates_for(None), [1]) + self.assertEqual(dense_layer.get_updates_for(a), [0]) + self.assertEqual(dense_layer.get_updates_for(None), [1]) def test_get_losses_for(self): a = base_layers.Input(shape=(2,)) @@ -444,8 +446,8 @@ class BaseLayerTest(test.TestCase): dense_layer.add_loss(0, inputs=a) dense_layer.add_loss(1, inputs=None) - self.assertListEqual(dense_layer.get_losses_for(a), [0]) - self.assertListEqual(dense_layer.get_losses_for(None), [1]) + self.assertEqual(dense_layer.get_losses_for(a), [0]) + self.assertEqual(dense_layer.get_losses_for(None), [1]) def testTopologicalAttributes(self): # test layer attributes / methods related to cross-layer connectivity. @@ -612,7 +614,7 @@ class NetworkTest(test.TestCase): a = base_layers.Input(shape=(32,), name='input_a') b = base_layers.Input(shape=(32,), name='input_b') - self.assertListEqual(a.get_shape().as_list(), [None, 32]) + self.assertEqual(a.get_shape().as_list(), [None, 32]) a_layer, a_node_index, a_tensor_index = a._keras_history b_layer, _, _ = b._keras_history self.assertEqual(len(a_layer._inbound_nodes), 1) @@ -620,11 +622,11 @@ class NetworkTest(test.TestCase): node = a_layer._inbound_nodes[a_node_index] self.assertEqual(node.outbound_layer, a_layer) - self.assertListEqual(node.inbound_layers, []) - self.assertListEqual(node.input_tensors, [a]) - self.assertListEqual(node.input_shapes, [(None, 32)]) - self.assertListEqual(node.output_tensors, [a]) - self.assertListEqual(node.output_shapes, [(None, 32)]) + self.assertEqual(node.inbound_layers, []) + self.assertEqual(node.input_tensors, [a]) + self.assertEqual(node.input_shapes, [(None, 32)]) + self.assertEqual(node.output_tensors, [a]) + self.assertEqual(node.output_shapes, [(None, 32)]) dense = core_layers.Dense(16, name='dense_1') dense(a) @@ -632,12 +634,12 @@ class NetworkTest(test.TestCase): self.assertEqual(len(dense._inbound_nodes), 2) self.assertEqual(len(dense._outbound_nodes), 0) - self.assertListEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) + self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) - self.assertListEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) + self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) - self.assertListEqual(dense._inbound_nodes[0].input_tensors, [a]) - self.assertListEqual(dense._inbound_nodes[1].input_tensors, [b]) + self.assertEqual(dense._inbound_nodes[0].input_tensors, [a]) + self.assertEqual(dense._inbound_nodes[1].input_tensors, [b]) # Test config config_0 = dense._inbound_nodes[0].get_config() @@ -889,5 +891,67 @@ class NetworkTest(test.TestCase): self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) +class DeferredModeTest(test.TestCase): + + def testDeferredTensorAttributes(self): + x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') + self.assertEqual(str(x), + 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') + self.assertEqual(repr(x), + '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') + + @test_util.run_in_graph_and_eager_modes() + def testSimpleNetworkBuilding(self): + inputs = base_layers.Input(shape=(32,)) + if context.in_eager_mode(): + self.assertIsInstance(inputs, base_layers._DeferredTensor) + self.assertEqual(inputs.dtype.name, 'float32') + self.assertEqual(inputs.shape.as_list(), [None, 32]) + + x = core_layers.Dense(2)(inputs) + if context.in_eager_mode(): + self.assertIsInstance(x, base_layers._DeferredTensor) + self.assertEqual(x.dtype.name, 'float32') + self.assertEqual(x.shape.as_list(), [None, 2]) + + outputs = core_layers.Dense(4)(x) + network = base_layers.Network(inputs, outputs) + self.assertIsInstance(network, base_layers.Network) + + if context.in_eager_mode(): + # It should be possible to call such a network on EagerTensors. + inputs = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + outputs = network(inputs) + self.assertEqual(outputs.shape.as_list(), [10, 4]) + + @test_util.run_in_graph_and_eager_modes() + def testMultiIONetworkbuilding(self): + input_a = base_layers.Input(shape=(32,)) + input_b = base_layers.Input(shape=(16,)) + a = core_layers.Dense(16)(input_a) + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + return inputs[0] + inputs[1] + + def _compute_output_shape(self, input_shape): + return input_shape[0] + + c = AddLayer()([a, input_b]) # pylint: disable=not-callable + c = core_layers.Dense(2)(c) + + network = base_layers.Network([input_a, input_b], [a, c]) + if context.in_eager_mode(): + a_val = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + b_val = constant_op.constant( + np.random.random((10, 16)).astype('float32')) + outputs = network([a_val, b_val]) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].shape.as_list(), [10, 16]) + self.assertEqual(outputs[1].shape.as_list(), [10, 2]) + if __name__ == '__main__': test.main()