Update documentation for the layer-input-casting feature.

PiperOrigin-RevId: 201152785
James Qin 2018-06-19 04:15:27 -07:00 committed by TensorFlower Gardener
parent fc6ff59c0c
commit 707ac111cf
3 changed files with 57 additions and 39 deletions


@@ -89,11 +89,19 @@ class Layer(checkpointable.CheckpointableBase):
once. Should actually perform the logic of applying the layer to the
input tensors (which should be passed in as the first argument).
By default, layers will cast all their inputs and arguments to the layer's
dtype, if set. This is useful for creating a model with multiple dtypes, as
the user does not need to explicitly cast tensors. If a `Layer` descendant
wants only a subset of inputs/arguments to be cast, or none of them,
`_cast_inputs_and_args()` should be overridden.
A note on a layer's `dtype` property:
A layer's dtype can be specified via the constructor `dtype` argument, and
defaults to the dtype of the first input when the layer is called. The dtype
cannot be changed once set.
All floating point tensor inputs and arguments are cast to the layer's
dtype before the body of the layer computation runs. For models whose
layers use different dtypes, this removes the need for explicit casts
between layers.
The casting behavior can be customized in subclasses by overriding the
`_cast_inputs_and_args()` method, which is useful if some or all inputs
should not be cast.
Arguments:
trainable: Boolean, whether the layer's variables should be trainable.
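To illustrate the constructor-dtype case described in the new docstring, a minimal sketch; the Scale layer is hypothetical and the public tf.keras.layers.Layer class is assumed to expose the same casting behavior:

import tensorflow as tf

# Hypothetical layer used only to illustrate the casting described above.
class Scale(tf.keras.layers.Layer):
  def call(self, inputs):
    return inputs * 2.0

layer = Scale(dtype='float16')                  # dtype fixed at construction
x = tf.constant([1.0, 2.0], dtype=tf.float32)
y = layer(x)                                    # input is cast to float16 before call() runs
print(y.dtype)                                  # float16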
@@ -675,10 +683,9 @@ class Layer(checkpointable.CheckpointableBase):
kwargs['mask'] = previous_mask
input_shapes = None
# We only cast inputs if self.dtype was previously set, which occurs when
# a dtype was passed to the constructor, or when this layer has previously
# been called. We cast floating point inputs to self.dtype to ensure the
# layer runs with the correct dtype.
# Inputs are only cast if a dtype was passed to the constructor, or if the
# layer's __call__() has previously been invoked. At present, only floating
# point tensor inputs are affected.
# TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed
# to the constructor, not when the layer has previously been called.
inputs_should_be_cast = (self.dtype is not None)
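A sketch of the other branch of this check, where no dtype is passed to the constructor: the first call runs without casting, and later calls cast to the dtype inferred from that first input. The Identity layer and the use of the public tf.keras.layers.Layer class are assumptions for illustration:

import tensorflow as tf

class Identity(tf.keras.layers.Layer):            # hypothetical pass-through layer
  def call(self, inputs):
    return inputs

layer = Identity()                                # no dtype given
a = layer(tf.constant(1.0, dtype=tf.float64))     # self.dtype is still None: no cast
print(layer.dtype, a.dtype)                       # dtype picked up from the first input (float64)
b = layer(tf.constant(1.0, dtype=tf.float32))     # __call__ previously invoked: cast to float64
print(b.dtype)                                    # float64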
@@ -810,10 +817,13 @@ class Layer(checkpointable.CheckpointableBase):
def _cast_inputs_and_args(self, inputs, *args, **kwargs):
"""Casts the inputs, args, and kwargs of a layer to the layer's dtype.
This is intended to be potentially overridden by layer subclasses. By
default, inputs, args, and kwargs are automatically cast to the layer's
dtype. Overriding this method allows only some of the inputs, args, and
kwargs (or none of them) to be cast.
This is intended to be potentially overridden by subclasses. By default,
inputs, args, and kwargs are automatically cast to the layer's dtype.
Overriding this method allows some or all of the inputs, args, and kwargs
to be left uncast.
Currently, this only casts floating point tensors to floating point dtypes,
but more types may be cast in the future.
Does not modify inputs, args, or kwargs.
@@ -823,7 +833,7 @@ class Layer(checkpointable.CheckpointableBase):
**kwargs: The kwargs to self.__call__.
Returns:
The tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
A tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
args, and kwargs have been cast to self.dtype.
"""
new_inputs = nest.map_structure(self._cast_fn, inputs)
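As an example of customizing this hook, a subclass might cast only its first input and leave the second alone. This is a sketch: the NoCastSecondInput layer, the two-element input structure, and the use of tf.cast (instead of the internal self._cast_fn) are all illustrative assumptions, not part of this change:

import tensorflow as tf

# Hypothetical subclass: only the first of the two inputs is cast to self.dtype.
class NoCastSecondInput(tf.keras.layers.Layer):

  def _cast_inputs_and_args(self, inputs, *args, **kwargs):
    first, second = inputs
    new_inputs = (tf.cast(first, self.dtype), second)   # second keeps its original dtype
    return new_inputs, args, kwargs

  def call(self, inputs):
    first, second = inputs
    return first, second

Calling NoCastSecondInput(dtype='float16') on a (float64, float32) pair would then return a float16 first output and an unchanged float32 second output.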


@@ -1057,24 +1057,30 @@ class TopologyConstructionTest(test.TestCase):
def compute_output_shape(self, input_shapes):
return input_shapes[0]
x = keras.layers.Input((32,), dtype='float64')
layer1 = SingleInputLayer()
layer2 = SingleInputLayer(dtype='float32')
layer3 = MultiInputLayer(dtype='float16')
i1 = layer1(x)
i2 = layer2(i1)
y = layer3((i1, i2))
network = keras.engine.Network(x, y)
x2 = array_ops.ones((32,), dtype='float16')
y2 = network(x2)
self.assertEqual(layer1.dtype, dtypes.float64)
self.assertEqual(layer1.a.dtype, dtypes.float64)
self.assertEqual(layer2.dtype, dtypes.float32)
self.assertEqual(layer2.a.dtype, dtypes.float32)
self.assertEqual(layer3.dtype, dtypes.float16)
self.assertEqual(layer3.a.dtype, dtypes.float16)
self.assertEqual(layer3.b.dtype, dtypes.float16)
self.assertEqual(y2.dtype, dtypes.float16)
default_layer = SingleInputLayer()
fp32_layer = SingleInputLayer(dtype='float32')
fp16_layer = MultiInputLayer(dtype='float16')
input_t = keras.layers.Input((32,), dtype='float64')
o1 = default_layer(input_t)
o2 = fp32_layer(o1)
# fp16_layer has inputs of different dtypes.
output_t = fp16_layer((o1, o2))
network = keras.engine.Network(input_t, output_t)
x = array_ops.ones((32,), dtype='float16')
y = network(x)
self.assertEqual(default_layer.dtype, dtypes.float64)
self.assertEqual(default_layer.a.dtype, dtypes.float64)
self.assertEqual(fp32_layer.dtype, dtypes.float32)
self.assertEqual(fp32_layer.a.dtype, dtypes.float32)
self.assertEqual(fp16_layer.dtype, dtypes.float16)
self.assertEqual(fp16_layer.a.dtype, dtypes.float16)
self.assertEqual(fp16_layer.b.dtype, dtypes.float16)
self.assertEqual(y.dtype, dtypes.float16)
class DeferredModeTest(test.TestCase):


@@ -593,7 +593,8 @@ class BaseLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testOnlyCastInputsWhenDtypeSpecified(self):
class MyLayerBase(keras_base_layer.Layer):
class MyKerasLayer(keras_base_layer.Layer):
def call(self, inputs):
self.x = inputs[0]
@@ -603,13 +604,13 @@ class BaseLayerTest(test.TestCase):
# Inherit from both the Keras Layer and base_layers.Layer to ensure we
# still get the base_layers.Layer behavior when directly inheriting from
# the Keras Layer.
class MyLayer(MyLayerBase, base_layers.Layer):
class MyTFLayer(MyKerasLayer, base_layers.Layer):
pass
# Test inputs are casted.
input1 = array_ops.constant(1.0, dtype=dtypes.float64)
input2 = array_ops.constant(1.0, dtype=dtypes.float32)
layer = MyLayer(dtype=dtypes.float16)
layer = MyTFLayer(dtype=dtypes.float16)
output1, output2 = layer([input1, input2])
self.assertEqual(output1.dtype, dtypes.float16)
self.assertEqual(output2.dtype, dtypes.float16)
@@ -617,14 +618,15 @@ class BaseLayerTest(test.TestCase):
# Test inputs are not casted.
input1 = array_ops.constant(1.0, dtype=dtypes.float64)
input2 = array_ops.constant(1.0, dtype=dtypes.float32)
layer = MyLayer()
layer = MyTFLayer()
output1, output2 = layer([input1, input2])
self.assertEqual(output1.dtype, dtypes.float64)
self.assertEqual(output2.dtype, dtypes.float32)
@test_util.run_in_graph_and_eager_modes()
def testVariablesDefaultToFloat32(self):
class MyLayerBase(keras_base_layer.Layer):
class MyKerasLayer(keras_base_layer.Layer):
def build(self, input_shape):
self.x = self.add_weight('x', ())
@@ -635,14 +637,14 @@ class BaseLayerTest(test.TestCase):
# Inherit from both the Keras Layer and base_layers.Layer to ensure we
# still get the base_layers.Layer behavior when directly inheriting from
# the Keras Layer.
class MyLayer(MyLayerBase, base_layers.Layer):
class MyTFLayer(MyKerasLayer, base_layers.Layer):
pass
try:
# The behavior of Keras Layers is to default to floatx. Ensure that this
# behavior is overridden to instead default to float32.
backend.set_floatx('float16')
layer = MyLayer()
layer = MyTFLayer()
layer.build(())
self.assertEqual(layer.dtype, None)
self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)