Update documentation for the layer-input-casting feature.

PiperOrigin-RevId: 201152785
James Qin 2018-06-19 04:15:27 -07:00 committed by TensorFlower Gardener
parent fc6ff59c0c
commit 707ac111cf
3 changed files with 57 additions and 39 deletions


@@ -89,11 +89,19 @@ class Layer(checkpointable.CheckpointableBase):
       once. Should actually perform the logic of applying the layer to the
       input tensors (which should be passed in as the first argument).
 
-  By default, layers will cast all their inputs and arguments to the layer's
-  dtype, if set. This is useful for creating a model with multiple dtypes, as
-  the user does not need to explicitly cast tensors. If a `Layer` descendant
-  wants only a subset of inputs/arguments to be casted, or none of them,
-  `_cast_inputs_and_args()` should be overridden.
+  A note on a layer's `dtype` property:
+  A layer's dtype can be specified via the constructor `dtype` argument, and
+  defaults to the dtype of the first input when the layer is called. The dtype
+  cannot be changed once set.
+
+  All floating point tensor inputs and arguments are casted to the layer's
+  dtype before the body of the layer computation runs. For models with
+  layers of different dtypes, this removes the need for explicit casts
+  between layers.
+
+  The casting behavior can be customized in subclasses by overriding the
+  `_cast_inputs_and_args()` function, which is useful if some or all inputs
+  should not be casted.
 
   Arguments:
     trainable: Boolean, whether the layer's variables should be trainable.
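For illustration, a minimal sketch of the behavior the new docstring describes (not part of this commit; it assumes the public `tf.keras.layers.Layer` at this revision exposes the same casting logic):

```python
import tensorflow as tf

class Echo(tf.keras.layers.Layer):
  """Identity layer, used only to observe the dtype of the casted input."""

  def call(self, inputs):
    return inputs

layer = Echo(dtype='float16')                  # dtype fixed via the constructor
x = tf.constant([1.0, 2.0], dtype=tf.float32)  # floating point input of a different dtype
y = layer(x)                                   # input is casted to the layer's dtype
print(y.dtype)                                 # expected: float16
```

Without the automatic cast, mixing float16 and float32 layers in one model would require explicit `tf.cast` calls between them.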
@ -675,10 +683,9 @@ class Layer(checkpointable.CheckpointableBase):
kwargs['mask'] = previous_mask kwargs['mask'] = previous_mask
input_shapes = None input_shapes = None
# We only cast inputs if self.dtype was previous set, which occurs when # Inputs are only casted if a dtype is pased in the constructor, or if a
# a dtype was passed to the constructor, or when this layer has previously # layer's __call__() has been previously invoked. At present, only floating
# been called. We cast floating point inputs to self.dtype to ensure the # point tensor inputs are affected.
# layer runs with the correct dtype.
# TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed # TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed
# to the constructor, not when the layer has previously been called. # to the constructor, not when the layer has previously been called.
inputs_should_be_cast = (self.dtype is not None) inputs_should_be_cast = (self.dtype is not None)
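A sketch of the case the new comment calls out, under the same assumptions as the example above (not part of the commit): a layer constructed without a dtype adopts the dtype of its first input, and later floating point inputs are casted to match.

```python
import tensorflow as tf

class Echo(tf.keras.layers.Layer):
  def call(self, inputs):
    return inputs

layer = Echo()                                    # no dtype passed to the constructor
layer(tf.constant([1.0], dtype=tf.float64))       # first call fixes the layer's dtype to float64
y = layer(tf.constant([1.0], dtype=tf.float32))   # later float inputs are casted to float64
print(y.dtype)                                    # expected: float64
```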
@ -810,10 +817,13 @@ class Layer(checkpointable.CheckpointableBase):
def _cast_inputs_and_args(self, inputs, *args, **kwargs): def _cast_inputs_and_args(self, inputs, *args, **kwargs):
"""Casts the inputs, args, and kwargs of a layer to the layer's dtype. """Casts the inputs, args, and kwargs of a layer to the layer's dtype.
This is intended to be potentially overridden by layer subclasses. By This is intended to be potentially overridden by subclasses. By default,
default, inputs, args, and kwargs are automatically casted to the layer's inputs, args, and kwargs are automatically casted to the layer's dtype.
dtype. Overriding this method allows only some of the inputs, args, and Overriding this method allows only some of the parameters to be treated
kwargs (or none of them) to be casted. differently.
Currently, this only casts floating point tensors to floating point dtypes,
but more types may be casted in the future.
Does not modify inputs, args, or kwargs. Does not modify inputs, args, or kwargs.
@@ -823,7 +833,7 @@ class Layer(checkpointable.CheckpointableBase):
       **kwargs: The kwargs to self.__call__.
 
     Returns:
-      The tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
+      A tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
       args, and kwargs have been casted to self.dtype.
     """
     new_inputs = nest.map_structure(self._cast_fn, inputs)
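And a sketch of the customization hook documented above (not part of the commit; the subclass name is hypothetical): a layer can override `_cast_inputs_and_args()` to opt out of the automatic casting, or to cast only selected inputs.

```python
import tensorflow as tf

class KeepDtypesLayer(tf.keras.layers.Layer):
  """Hypothetical subclass that opts out of the automatic input casting."""

  def _cast_inputs_and_args(self, inputs, *args, **kwargs):
    # Return everything unchanged instead of casting to self.dtype.
    return inputs, args, kwargs

  def call(self, inputs):
    return inputs
```

A partial override could instead apply the base implementation only to the structures that should follow `self.dtype`.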


@@ -1057,24 +1057,30 @@ class TopologyConstructionTest(test.TestCase):
       def compute_output_shape(self, input_shapes):
         return input_shapes[0]
 
-    x = keras.layers.Input((32,), dtype='float64')
-    layer1 = SingleInputLayer()
-    layer2 = SingleInputLayer(dtype='float32')
-    layer3 = MultiInputLayer(dtype='float16')
-    i1 = layer1(x)
-    i2 = layer2(i1)
-    y = layer3((i1, i2))
-    network = keras.engine.Network(x, y)
-    x2 = array_ops.ones((32,), dtype='float16')
-    y2 = network(x2)
-    self.assertEqual(layer1.dtype, dtypes.float64)
-    self.assertEqual(layer1.a.dtype, dtypes.float64)
-    self.assertEqual(layer2.dtype, dtypes.float32)
-    self.assertEqual(layer2.a.dtype, dtypes.float32)
-    self.assertEqual(layer3.dtype, dtypes.float16)
-    self.assertEqual(layer3.a.dtype, dtypes.float16)
-    self.assertEqual(layer3.b.dtype, dtypes.float16)
-    self.assertEqual(y2.dtype, dtypes.float16)
+    default_layer = SingleInputLayer()
+    fp32_layer = SingleInputLayer(dtype='float32')
+    fp16_layer = MultiInputLayer(dtype='float16')
+
+    input_t = keras.layers.Input((32,), dtype='float64')
+    o1 = default_layer(input_t)
+    o2 = fp32_layer(o1)
+    # fp16_layer has inputs of different dtypes.
+    output_t = fp16_layer((o1, o2))
+    network = keras.engine.Network(input_t, output_t)
+
+    x = array_ops.ones((32,), dtype='float16')
+    y = network(x)
+    self.assertEqual(default_layer.dtype, dtypes.float64)
+    self.assertEqual(default_layer.a.dtype, dtypes.float64)
+
+    self.assertEqual(fp32_layer.dtype, dtypes.float32)
+    self.assertEqual(fp32_layer.a.dtype, dtypes.float32)
+
+    self.assertEqual(fp16_layer.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.a.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.b.dtype, dtypes.float16)
+    self.assertEqual(y.dtype, dtypes.float16)
 
 
 class DeferredModeTest(test.TestCase):


@@ -593,7 +593,8 @@ class BaseLayerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testOnlyCastInputsWhenDtypeSpecified(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):
 
       def call(self, inputs):
        self.x = inputs[0]
@@ -603,13 +604,13 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass
 
     # Test inputs are casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer(dtype=dtypes.float16)
+    layer = MyTFLayer(dtype=dtypes.float16)
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float16)
     self.assertEqual(output2.dtype, dtypes.float16)
@@ -617,14 +618,15 @@ class BaseLayerTest(test.TestCase):
     # Test inputs are not casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer()
+    layer = MyTFLayer()
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float64)
     self.assertEqual(output2.dtype, dtypes.float32)
 
   @test_util.run_in_graph_and_eager_modes()
   def testVariablesDefaultToFloat32(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):
 
       def build(self, input_shape):
         self.x = self.add_weight('x', ())
@@ -635,14 +637,14 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass
 
     try:
       # The behavior of Keras Layers is to default to floatx. Ensure that this
       # behavior is overridden to instead default to float32.
       backend.set_floatx('float16')
-      layer = MyLayer()
+      layer = MyTFLayer()
       layer.build(())
       self.assertEqual(layer.dtype, None)
       self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)