Update documentation for the layer-input-casting feature.
PiperOrigin-RevId: 201152785
commit 707ac111cf
parent fc6ff59c0c
@@ -89,11 +89,19 @@ class Layer(checkpointable.CheckpointableBase):
     once. Should actually perform the logic of applying the layer to the
     input tensors (which should be passed in as the first argument).
 
-  By default, layers will cast all their inputs and arguments to the layer's
-  dtype, if set. This is useful for creating a model with multiple dtypes, as
-  the user does not need to explicitly cast tensors. If a `Layer` descendant
-  wants only a subset of inputs/arguments to be casted, or none of them,
-  `_cast_inputs_and_args()` should be overridden.
+  A note on a layer's `dtype` property:
+
+  A layer's dtype can be specified via the constructor `dtype` argument, and
+  defaults to the dtype of the first input when the layer is called. The dtype
+  cannot be changed once set.
+
+  All floating point tensor inputs and arguments are casted to the layer's
+  dtype before the body of the layer computation happens. For models with
+  layers of different dtypes, this helps get rid of the explicit casts
+  between layers.
+
+  The casting behavior can be customized in subclasses by overriding the
+  `_cast_inputs_and_args()` function, which is useful if certain or all inputs
+  should not be casted.
 
   Arguments:
     trainable: Boolean, whether the layer's variables should be trainable.
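The rule the new docstring describes can be sketched in a few lines. This is a hedged illustration, not the commit's implementation; `cast_if_floating` is a hypothetical helper name:

import tensorflow as tf


def cast_if_floating(x, dtype):
  # Hypothetical helper: cast only when both the source and target dtypes
  # are floating point, mirroring the docstring's "floating point tensor
  # inputs and arguments".
  dtype = tf.as_dtype(dtype)
  if x.dtype.is_floating and dtype.is_floating:
    return tf.cast(x, dtype)
  return x


x = tf.constant(1.0, dtype=tf.float64)
i = tf.constant(1)  # int32
print(cast_if_floating(x, 'float16').dtype)  # float16
print(cast_if_floating(i, 'float16').dtype)  # int32: non-floating inputs pass through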
@@ -675,10 +683,9 @@ class Layer(checkpointable.CheckpointableBase):
       kwargs['mask'] = previous_mask
 
     input_shapes = None
-    # We only cast inputs if self.dtype was previous set, which occurs when
-    # a dtype was passed to the constructor, or when this layer has previously
-    # been called. We cast floating point inputs to self.dtype to ensure the
-    # layer runs with the correct dtype.
+    # Inputs are only casted if a dtype is passed in the constructor, or if a
+    # layer's __call__() has been previously invoked. At present, only floating
+    # point tensor inputs are affected.
     # TODO(b/77478433): Perhaps we should only cast inputs if a dtype was passed
     # to the constructor, not when the layer has previously been called.
     inputs_should_be_cast = (self.dtype is not None)
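A minimal sketch of the guard described in the rewritten comment, assuming a toy stand-in class (`SketchLayer` is illustrative, not the real `Layer`): the dtype is fixed either by the constructor or by the first `__call__`, and casting only happens once it is set.

import tensorflow as tf


class SketchLayer(object):
  # Toy stand-in for Layer, only to illustrate the dtype-locking rule.

  def __init__(self, dtype=None):
    self._dtype = tf.as_dtype(dtype) if dtype is not None else None

  def __call__(self, inputs):
    if self._dtype is None:
      # No constructor dtype and this is the first call: adopt the input's
      # dtype, so no cast happens on this call.
      self._dtype = inputs.dtype.base_dtype
    # Plays the role of `inputs_should_be_cast = (self.dtype is not None)`:
    # from here on the dtype is always set, and floating inputs are casted.
    if inputs.dtype.is_floating and self._dtype.is_floating:
      inputs = tf.cast(inputs, self._dtype)
    return inputs


layer = SketchLayer(dtype='float16')
print(layer(tf.constant(1.0, tf.float32)).dtype)  # float16

layer = SketchLayer()
print(layer(tf.constant(1.0, tf.float64)).dtype)  # float64: dtype adopted, no cast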
@@ -810,10 +817,13 @@ class Layer(checkpointable.CheckpointableBase):
   def _cast_inputs_and_args(self, inputs, *args, **kwargs):
     """Casts the inputs, args, and kwargs of a layer to the layer's dtype.
 
-    This is intended to be potentially overridden by layer subclasses. By
-    default, inputs, args, and kwargs are automatically casted to the layer's
-    dtype. Overriding this method allows only some of the inputs, args, and
-    kwargs (or none of them) to be casted.
+    This is intended to be potentially overridden by subclasses. By default,
+    inputs, args, and kwargs are automatically casted to the layer's dtype.
+    Overriding this method allows only some of the parameters to be treated
+    differently.
+
+    Currently, this only casts floating point tensors to floating point dtypes,
+    but more types may be casted in the future.
 
     Does not modify inputs, args, or kwargs.
 
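To illustrate the override hook this docstring describes, here is a hedged sketch of a subclass that exempts one argument from casting. It assumes the `Layer` class from this diff is in scope and relies only on the signature and return tuple documented above; `NoMaskCast` and the `mask` handling are illustrative only.

class NoMaskCast(Layer):  # assumes the Layer class from this diff is in scope

  def _cast_inputs_and_args(self, inputs, *args, **kwargs):
    # Pull out the kwarg we do not want casted before delegating to the
    # default casting behavior, then re-attach it unchanged (None if it
    # was not passed).
    mask = kwargs.pop('mask', None)
    new_inputs, new_args, new_kwargs = super(
        NoMaskCast, self)._cast_inputs_and_args(inputs, *args, **kwargs)
    new_kwargs['mask'] = mask
    return new_inputs, new_args, new_kwargs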
@@ -823,7 +833,7 @@ class Layer(checkpointable.CheckpointableBase):
       **kwargs: The kwargs to self.__call__.
 
     Returns:
-      The tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
+      A tuple (new_inputs, new_args, new_kwargs), where tensors in inputs,
       args, and kwargs have been casted to self.dtype.
     """
     new_inputs = nest.map_structure(self._cast_fn, inputs)
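The `nest.map_structure(self._cast_fn, inputs)` call above is what lets a single cast function handle arbitrarily nested inputs. A small demonstration using the public `tf.nest` wrapper (the diff itself uses the internal `nest` module; `cast_fn` is a hypothetical stand-in for `self._cast_fn`):

import tensorflow as tf


def cast_fn(t):
  # Same rule as above: floating point tensors only.
  return tf.cast(t, tf.float16) if t.dtype.is_floating else t


nested = {'a': tf.constant(1.0, tf.float64),
          'b': [tf.constant(2.0, tf.float32), tf.constant(3)]}
casted = tf.nest.map_structure(cast_fn, nested)
print(casted['a'].dtype, casted['b'][0].dtype, casted['b'][1].dtype)
# float16 float16 int32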
@@ -1057,24 +1057,30 @@ class TopologyConstructionTest(test.TestCase):
       def compute_output_shape(self, input_shapes):
         return input_shapes[0]
 
-    x = keras.layers.Input((32,), dtype='float64')
-    layer1 = SingleInputLayer()
-    layer2 = SingleInputLayer(dtype='float32')
-    layer3 = MultiInputLayer(dtype='float16')
-    i1 = layer1(x)
-    i2 = layer2(i1)
-    y = layer3((i1, i2))
-    network = keras.engine.Network(x, y)
-    x2 = array_ops.ones((32,), dtype='float16')
-    y2 = network(x2)
-    self.assertEqual(layer1.dtype, dtypes.float64)
-    self.assertEqual(layer1.a.dtype, dtypes.float64)
-    self.assertEqual(layer2.dtype, dtypes.float32)
-    self.assertEqual(layer2.a.dtype, dtypes.float32)
-    self.assertEqual(layer3.dtype, dtypes.float16)
-    self.assertEqual(layer3.a.dtype, dtypes.float16)
-    self.assertEqual(layer3.b.dtype, dtypes.float16)
-    self.assertEqual(y2.dtype, dtypes.float16)
+    default_layer = SingleInputLayer()
+    fp32_layer = SingleInputLayer(dtype='float32')
+    fp16_layer = MultiInputLayer(dtype='float16')
+
+    input_t = keras.layers.Input((32,), dtype='float64')
+    o1 = default_layer(input_t)
+    o2 = fp32_layer(o1)
+    # fp16_layer has inputs of different dtypes.
+    output_t = fp16_layer((o1, o2))
+    network = keras.engine.Network(input_t, output_t)
+
+    x = array_ops.ones((32,), dtype='float16')
+    y = network(x)
+    self.assertEqual(default_layer.dtype, dtypes.float64)
+    self.assertEqual(default_layer.a.dtype, dtypes.float64)
+
+    self.assertEqual(fp32_layer.dtype, dtypes.float32)
+    self.assertEqual(fp32_layer.a.dtype, dtypes.float32)
+
+    self.assertEqual(fp16_layer.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.a.dtype, dtypes.float16)
+    self.assertEqual(fp16_layer.b.dtype, dtypes.float16)
+
+    self.assertEqual(y.dtype, dtypes.float16)
 
 
 class DeferredModeTest(test.TestCase):
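The rewritten test chains layers of three different dtypes without a single explicit cast between them. A hedged sketch of the same pattern using present-day tf.keras names (the layer choices here are illustrative; the test's SingleInputLayer/MultiInputLayer are custom classes defined earlier in the file, and casting behavior may differ between TensorFlow versions):

import tensorflow as tf

inputs = tf.keras.Input((32,), dtype='float64')
x = tf.keras.layers.Dense(16, dtype='float32')(inputs)  # float64 input is casted to float32
outputs = tf.keras.layers.Dense(4, dtype='float16')(x)  # float32 input is casted to float16
model = tf.keras.Model(inputs, outputs)

y = model(tf.ones((1, 32), dtype='float64'))
print(y.dtype)  # float16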
@@ -593,7 +593,8 @@ class BaseLayerTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testOnlyCastInputsWhenDtypeSpecified(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):
 
       def call(self, inputs):
         self.x = inputs[0]
@@ -603,13 +604,13 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass
 
     # Test inputs are casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer(dtype=dtypes.float16)
+    layer = MyTFLayer(dtype=dtypes.float16)
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float16)
     self.assertEqual(output2.dtype, dtypes.float16)
@@ -617,14 +618,15 @@ class BaseLayerTest(test.TestCase):
     # Test inputs are not casted.
     input1 = array_ops.constant(1.0, dtype=dtypes.float64)
     input2 = array_ops.constant(1.0, dtype=dtypes.float32)
-    layer = MyLayer()
+    layer = MyTFLayer()
     output1, output2 = layer([input1, input2])
     self.assertEqual(output1.dtype, dtypes.float64)
     self.assertEqual(output2.dtype, dtypes.float32)
 
   @test_util.run_in_graph_and_eager_modes()
   def testVariablesDefaultToFloat32(self):
-    class MyLayerBase(keras_base_layer.Layer):
+
+    class MyKerasLayer(keras_base_layer.Layer):
 
       def build(self, input_shape):
         self.x = self.add_weight('x', ())
@@ -635,14 +637,14 @@ class BaseLayerTest(test.TestCase):
     # Inherit from both the Keras Layer and base_layers.Layer to ensure we
     # still get the base_layers.Layer behavior when directly inheriting from
     # the Keras Layer.
-    class MyLayer(MyLayerBase, base_layers.Layer):
+    class MyTFLayer(MyKerasLayer, base_layers.Layer):
       pass
 
     try:
       # The behavior of Keras Layers is to default to floatx. Ensure that this
       # behavior is overridden to instead default to float32.
       backend.set_floatx('float16')
-      layer = MyLayer()
+      layer = MyTFLayer()
       layer.build(())
       self.assertEqual(layer.dtype, None)
       self.assertEqual(layer.x.dtype.base_dtype, dtypes.float32)