Update documentation for the layer-input-casting feature.
PiperOrigin-RevId: 201152785
parent fc6ff59c0c
commit 707ac111cf
@@ -89,11 +89,19 @@ class Layer(checkpointable.CheckpointableBase):
       once. Should actually perform the logic of applying the layer to the
       input tensors (which should be passed in as the first argument).

-  By default, layers will cast all their inputs and arguments to the layer's
-  dtype, if set. This is useful for creating a model with multiple dtypes, as
-  the user does not need to explicitly cast tensors. If a `Layer` descendant
-  wants only a subset of inputs/arguments to be casted, or none of them,
-  `_cast_inputs_and_args()` should be overridden.
+  A note on a layer's `dtype` property:
+  A layer's dtype can be specified via the constructor `dtype` argument, and
+  defaults to the dtype of the first input when the layer is called. The dtype
+  cannot be changed once set.
+
+  All floating point tensor inputs and arguments are casted to the layer's
+  dtype before the body of the layer computation happens. For models with
+  layers of different dtypes, this helps get rid of the explicit casts
+  between layers.
+
+  The casting behavior can be customized in subclasses by overriding the
+  `_cast_inputs_and_args()` function, which is useful if certain or all inputs
+  should not be casted.

   Arguments:
     trainable: Boolean, whether the layer's variables should be trainable.
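To make the new docstring concrete: an illustrative usage sketch (not part of this commit; the layer choices and shapes are hypothetical) of a model whose layers run in different dtypes. With the casting behavior described above, each layer casts its own floating point inputs, so no explicit tf.cast is needed between layers.

import tensorflow as tf

# A float64 input feeding layers constructed with narrower dtypes. Each layer
# casts its floating point inputs to its own dtype, per the docstring above.
inputs = tf.keras.layers.Input((8,), dtype='float64')
hidden = tf.keras.layers.Dense(16, activation='relu', dtype='float32')(inputs)
outputs = tf.keras.layers.Dense(1, dtype='float16')(hidden)
model = tf.keras.models.Model(inputs, outputs)  # no hand-written tf.cast ops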
@@ -675,10 +683,9 @@ class Layer(checkpointable.CheckpointableBase):
         kwargs['mask'] = previous_mask

     input_shapes = None
-    # We only cast inputs if self.dtype was previous set, which occurs when
-    # a dtype was passed to the constructor, or when this layer has previously
-    # been called. We cast floating point inputs to self.dtype to ensure the
-    # layer runs with the correct dtype.
+    # Inputs are only casted if a dtype is passed in the constructor, or if a
+    # layer's __call__() has been previously invoked. At present, only floating
+    # point tensor inputs are affected.
     # TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed
     # to the constructor, not when the layer has previously been called.
     inputs_should_be_cast = (self.dtype is not None)
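A stand-alone sketch of the two trigger conditions named in the new comment, using a hypothetical pass-through layer (not from the commit); the dtype outcomes noted in the comments are what the comment above implies.

import tensorflow as tf


class Passthrough(tf.keras.layers.Layer):
  """Returns its input unchanged; used only to observe dtype handling."""

  def call(self, inputs):
    return inputs


# Condition 1: a dtype is passed to the constructor, so the first call casts.
layer_a = Passthrough(dtype='float16')
out_a = layer_a(tf.constant(1.0, dtype=tf.float64))   # expected dtype: float16

# Condition 2: no constructor dtype; the first __call__() fixes the layer's
# dtype, and later calls are cast to it.
layer_b = Passthrough()
layer_b(tf.constant(1.0, dtype=tf.float32))           # layer_b.dtype becomes float32
out_b = layer_b(tf.constant(1.0, dtype=tf.float64))   # expected dtype: float32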
@@ -810,10 +817,13 @@ class Layer(checkpointable.CheckpointableBase):
   def _cast_inputs_and_args(self, inputs, *args, **kwargs):
     """Casts the inputs, args, and kwargs of a layer to the layer's dtype.

-    This is intended to be potentially overridden by layer subclasses. By
-    default, inputs, args, and kwargs are automatically casted to the layer's
-    dtype. Overriding this method allows only some of the inputs, args, and
-    kwargs (or none of them) to be casted.
+    This is intended to be potentially overridden by subclasses. By default,
+    inputs, args, and kwargs are automatically casted to the layer's dtype.
+    Overriding this method allows only some of the parameters to be treated
+    differently.
+
+    Currently, this only casts floating point tensors to floating point dtypes,
+    but more types may be casted in the future.

     Does not modify inputs, args, or kwargs.

@@ -823,7 +833,7 @@ class Layer(checkpointable.CheckpointableBase):
       **kwargs: The kwargs to self.__call__.

     Returns:
-      The tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
+      A tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
       args, and kwargs have been casted to self.dtype.
     """
     new_inputs = nest.map_structure(self._cast_fn, inputs)
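A sketch of the override pattern the docstring allows (hypothetical layer, not from the commit; it assumes the `_cast_inputs_and_args()` hook and the `(new_inputs, new_args, new_kwargs)` return convention documented above). It casts only the floating point element of its inputs and leaves a boolean mask untouched.

import tensorflow as tf


class MaskedScale(tf.keras.layers.Layer):
  """Doubles the masked entries of `inputs[0]`; never casts the boolean mask."""

  def _cast_inputs_and_args(self, inputs, *args, **kwargs):
    value, mask = inputs
    if self.dtype is not None:
      # Cast only the floating point value; the mask keeps its dtype.
      value = tf.cast(value, self.dtype)
    return (value, mask), args, kwargs

  def call(self, inputs):
    value, mask = inputs
    return tf.where(mask, value * 2.0, value)

Building and returning new structures, rather than mutating the arguments, keeps with the "Does not modify inputs, args, or kwargs" contract stated above.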
@@ -1057,24 +1057,30 @@ class TopologyConstructionTest(test.TestCase):
       def compute_output_shape(self, input_shapes):
         return input_shapes[0]

-    x = keras.layers.Input((32,), dtype='float64')
-    layer1 = SingleInputLayer()
-    layer2 = SingleInputLayer(dtype='float32')
-    layer3 = MultiInputLayer(dtype='float16')
-    i1 = layer1(x)
-    i2 = layer2(i1)
-    y = layer3((i1, i2))
-    network = keras.engine.Network(x, y)
-    x2 = array_ops.ones((32,), dtype='float16')
-    y2 = network(x2)
-    self.assertEqual(layer1.dtype, dtypes.float64)
-    self.assertEqual(layer1.a.dtype, dtypes.float64)
-    self.assertEqual(layer2.dtype, dtypes.float32)
-    self.assertEqual(layer2.a.dtype, dtypes.float32)
-    self.assertEqual(layer3.dtype, dtypes.float16)
-    self.assertEqual(layer3.a.dtype, dtypes.float16)
-    self.assertEqual(layer3.b.dtype, dtypes.float16)
-    self.assertEqual(y2.dtype, dtypes.float16)
+    default_layer = SingleInputLayer()
+    fp32_layer = SingleInputLayer(dtype='float32')
+    fp16_layer = MultiInputLayer(dtype='float16')
+
+    input_t = keras.layers.Input((32,), dtype='float64')
+    o1 = default_layer(input_t)
+    o2 = fp32_layer(o1)
+    # fp16_layer has inputs of different dtypes.
+    output_t = fp16_layer((o1, o2))
+    network = keras.engine.Network(input_t, output_t)
+
+    x = array_ops.ones((32,), dtype='float16')
+    y = network(x)
+    self.assertEqual(default_layer.dtype, dtypes.float64)
+    self.assertEqual(default_layer.a.dtype, dtypes.float64)
+
+    self.assertEqual(fp32_layer.dtype, dtypes.float32)
+    self.assertEqual(fp32_layer.a.dtype, dtypes.float32)
+
+    self.assertEqual(fp16_layer.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.a.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.b.dtype, dtypes.float16)
+
+    self.assertEqual(y.dtype, dtypes.float16)


 class DeferredModeTest(test.TestCase):
@@ -593,7 +593,8 @@ class BaseLayerTest(test.TestCase):

   @test_util.run_in_graph_and_eager_modes()
   def testOnlyCastInputsWhenDtypeSpecified(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):

       def call(self, inputs):
         self.x = inputs[0]

@@ -603,13 +604,13 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass

     # Test inputs are casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer(dtype=dtypes.float16)
+    layer = MyTFLayer(dtype=dtypes.float16)
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float16)
     self.assertEqual(output2.dtype, dtypes.float16)

@@ -617,14 +618,15 @@ class BaseLayerTest(test.TestCase):
     # Test inputs are not casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer()
+    layer = MyTFLayer()
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float64)
     self.assertEqual(output2.dtype, dtypes.float32)

   @test_util.run_in_graph_and_eager_modes()
   def testVariablesDefaultToFloat32(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):

       def build(self, input_shape):
         self.x = self.add_weight('x', ())

@@ -635,14 +637,14 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass

     try:
       # The behavior of Keras Layers is to default to floatx. Ensure that this
       # behavior is overridden to instead default to float32.
       backend.set_floatx('float16')
-      layer = MyLayer()
+      layer = MyTFLayer()
       layer.build(())
       self.assertEqual(layer.dtype, None)
       self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)
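Outside the test harness, the two behaviors these tests pin down could be exercised with a plain tf.layers-style layer roughly like the sketch below (the layer name and body are hypothetical; the dtype outcomes noted in the comments restate the assertions above).

import tensorflow as tf


class Echo(tf.layers.Layer):
  """Pass-through layer with one scalar weight, for observing dtypes."""

  def build(self, input_shape):
    # Expected to default to float32 even if Keras' floatx is float16
    # (testVariablesDefaultToFloat32).
    self.w = self.add_weight('w', ())

  def call(self, inputs):
    return inputs


# No dtype given: inputs keep their dtype (testOnlyCastInputsWhenDtypeSpecified).
layer = Echo()
out = layer(tf.constant(1.0, dtype=tf.float64))          # stays float64

# Explicit dtype: floating point inputs are cast to it.
fp16_layer = Echo(dtype=tf.float16)
out16 = fp16_layer(tf.constant(1.0, dtype=tf.float64))   # cast to float16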