From 707ac111cfed90f35c37417d8c79ab7cbcba152a Mon Sep 17 00:00:00 2001
From: James Qin
Date: Tue, 19 Jun 2018 04:15:27 -0700
Subject: [PATCH] Update documentation for the layer-input-casting feature.

PiperOrigin-RevId: 201152785
---
 tensorflow/python/keras/engine/base_layer.py | 38 ++++++++++-------
 .../python/keras/engine/topology_test.py     | 42 +++++++++++--------
 tensorflow/python/layers/base_test.py        | 16 +++----
 3 files changed, 57 insertions(+), 39 deletions(-)

diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 751cc5a8d56..b05bc96e28f 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -89,11 +89,19 @@ class Layer(checkpointable.CheckpointableBase):
       once. Should actually perform the logic of applying the layer to the
       input tensors (which should be passed in as the first argument).
 
-  By default, layers will cast all their inputs and arguments to the layer's
-  dtype, if set. This is useful for creating a model with multiple dtypes, as
-  the user does not need to explicitly cast tensors. If a `Layer` descendant
-  wants only a subset of inputs/arguments to be casted, or none of them,
-  `_cast_inputs_and_args()` should be overridden.
+  A note on a layer's `dtype` property:
+  A layer's dtype can be specified via the constructor `dtype` argument, and
+  defaults to the dtype of the first input when the layer is called. The dtype
+  cannot be changed once set.
+
+  All floating point tensor inputs and arguments are cast to the layer's
+  dtype before the body of the layer computation runs. For models with
+  layers of different dtypes, this removes the need for explicit casts
+  between layers.
+
+  The casting behavior can be customized in subclasses by overriding the
+  `_cast_inputs_and_args()` method, which is useful if some or all inputs
+  should not be cast.
 
   Arguments:
     trainable: Boolean, whether the layer's variables should be trainable.
@@ -675,10 +683,9 @@ class Layer(checkpointable.CheckpointableBase):
         kwargs['mask'] = previous_mask
 
     input_shapes = None
-    # We only cast inputs if self.dtype was previous set, which occurs when
-    # a dtype was passed to the constructor, or when this layer has previously
-    # been called. We cast floating point inputs to self.dtype to ensure the
-    # layer runs with the correct dtype.
+    # Inputs are only cast if a dtype is passed in the constructor, or if the
+    # layer's __call__() has been previously invoked. At present, only floating
+    # point tensor inputs are affected.
     # TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed
     # to the constructor, not when the layer has previously been called.
     inputs_should_be_cast = (self.dtype is not None)
@@ -810,10 +817,13 @@ class Layer(checkpointable.CheckpointableBase):
   def _cast_inputs_and_args(self, inputs, *args, **kwargs):
     """Casts the inputs, args, and kwargs of a layer to the layer's dtype.
 
-    This is intended to be potentially overridden by layer subclasses. By
-    default, inputs, args, and kwargs are automatically casted to the layer's
-    dtype. Overriding this method allows only some of the inputs, args, and
-    kwargs (or none of them) to be casted.
+    This is intended to be potentially overridden by subclasses. By default,
+    inputs, args, and kwargs are automatically cast to the layer's dtype.
+    Overriding this method allows only a subset of the inputs, args, and
+    kwargs (or none of them) to be cast.
+
+    Currently, this only casts floating point tensors to floating point dtypes,
+    but more types may be cast in the future.
 
     Does not modify inputs, args, or kwargs.
 
@@ -823,7 +833,7 @@ class Layer(checkpointable.CheckpointableBase):
       **kwargs: The kwargs to self.__call__.
 
     Returns:
-      The tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
+      A tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
       args, and kwargs have been casted to self.dtype.
     """
     new_inputs = nest.map_structure(self._cast_fn, inputs)
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 7fbe6b80adf..d28c30cb7df 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -1057,24 +1057,30 @@ class TopologyConstructionTest(test.TestCase):
       def compute_output_shape(self, input_shapes):
         return input_shapes[0]
 
-    x = keras.layers.Input((32,), dtype='float64')
-    layer1 = SingleInputLayer()
-    layer2 = SingleInputLayer(dtype='float32')
-    layer3 = MultiInputLayer(dtype='float16')
-    i1 = layer1(x)
-    i2 = layer2(i1)
-    y = layer3((i1, i2))
-    network = keras.engine.Network(x, y)
-    x2 = array_ops.ones((32,), dtype='float16')
-    y2 = network(x2)
-    self.assertEqual(layer1.dtype, dtypes.float64)
-    self.assertEqual(layer1.a.dtype, dtypes.float64)
-    self.assertEqual(layer2.dtype, dtypes.float32)
-    self.assertEqual(layer2.a.dtype, dtypes.float32)
-    self.assertEqual(layer3.dtype, dtypes.float16)
-    self.assertEqual(layer3.a.dtype, dtypes.float16)
-    self.assertEqual(layer3.b.dtype, dtypes.float16)
-    self.assertEqual(y2.dtype, dtypes.float16)
+    default_layer = SingleInputLayer()
+    fp32_layer = SingleInputLayer(dtype='float32')
+    fp16_layer = MultiInputLayer(dtype='float16')
+
+    input_t = keras.layers.Input((32,), dtype='float64')
+    o1 = default_layer(input_t)
+    o2 = fp32_layer(o1)
+    # fp16_layer has inputs of different dtypes.
+    output_t = fp16_layer((o1, o2))
+    network = keras.engine.Network(input_t, output_t)
+
+    x = array_ops.ones((32,), dtype='float16')
+    y = network(x)
+    self.assertEqual(default_layer.dtype, dtypes.float64)
+    self.assertEqual(default_layer.a.dtype, dtypes.float64)
+
+    self.assertEqual(fp32_layer.dtype, dtypes.float32)
+    self.assertEqual(fp32_layer.a.dtype, dtypes.float32)
+
+    self.assertEqual(fp16_layer.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.a.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.b.dtype, dtypes.float16)
+
+    self.assertEqual(y.dtype, dtypes.float16)
 
 
 class DeferredModeTest(test.TestCase):
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 15448c6be8d..ad44328aabf 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -593,7 +593,8 @@ class BaseLayerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testOnlyCastInputsWhenDtypeSpecified(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):
 
       def call(self, inputs):
         self.x = inputs[0]
@@ -603,13 +604,13 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass
 
     # Test inputs are casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer(dtype=dtypes.float16)
+    layer = MyTFLayer(dtype=dtypes.float16)
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float16)
     self.assertEqual(output2.dtype, dtypes.float16)
@@ -617,14 +618,15 @@ class BaseLayerTest(test.TestCase):
     # Test inputs are not casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer()
+    layer = MyTFLayer()
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float64)
     self.assertEqual(output2.dtype, dtypes.float32)
 
   @test_util.run_in_graph_and_eager_modes()
   def testVariablesDefaultToFloat32(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):
 
       def build(self, input_shape):
         self.x = self.add_weight('x', ())
@@ -635,14 +637,14 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass
 
     try:
       # The behavior of Keras Layers is to default to floatx. Ensure that this
       # behavior is overridden to instead default to float32.
       backend.set_floatx('float16')
-      layer = MyLayer()
+      layer = MyTFLayer()
       layer.build(())
       self.assertEqual(layer.dtype, None)
       self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)
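
A minimal sketch of the override hook documented above (not part of the patch). It assumes a TensorFlow build carrying this change, where the Keras `Layer` in tensorflow/python/keras/engine/base_layer.py exposes the private `_cast_inputs_and_args()` hook and the per-tensor `_cast_fn` helper seen in the diff; both are internal APIs and may change. The `SelectiveCastLayer` name is made up for illustration: it casts only its first input to the layer's dtype and leaves the second in its original dtype.

# Sketch only: exercises the private `_cast_inputs_and_args()` hook described
# in the docstring changes above. `_cast_fn` is the per-tensor cast helper the
# default implementation maps over the inputs; both are internal APIs.
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops


class SelectiveCastLayer(base_layer.Layer):
  """Casts only the first input to the layer's dtype."""

  def call(self, inputs):
    return inputs

  def _cast_inputs_and_args(self, inputs, *args, **kwargs):
    # The default implementation would cast every floating point tensor in
    # `inputs`. Here only inputs[0] is cast; inputs[1] keeps its dtype.
    first, second = inputs
    return ([self._cast_fn(first), second], args, kwargs)


layer = SelectiveCastLayer(dtype=dtypes.float16)
x = array_ops.constant(1.0, dtype=dtypes.float64)
y = array_ops.constant(1.0, dtype=dtypes.float32)
out_x, out_y = layer([x, y])
assert out_x.dtype == dtypes.float16  # cast by the overridden hook
assert out_y.dtype == dtypes.float32  # deliberately left uncast

Note that without a `dtype` constructor argument and before any call, `self.dtype` is None and `__call__` skips casting entirely, matching `inputs_should_be_cast = (self.dtype is not None)` in the diff above.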