diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 2d1f05855ec..421d29d1c5d 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -160,17 +160,60 @@ class Layer(module.Module): y = layer(x) ``` - A layer subclass can prevent its inputs from being autocasted by + Currently, only tensors in the first argument to the layer's `call` method are + cast. For example: + + ``` + class MyLayer(tf.keras.layers.Layer): + # Bug! `b` will not be cast. + def call(self, a, b): + return a + 1., b + 1. + + a = tf.constant(1., dtype="float32") + b = tf.constant(1., dtype="float32") + + layer = MyLayer(dtype="float64") + x, y = layer(a, b) + print(x.dtype) # float64 + print(y.dtype) # float32. Not cast since `b` was not the first input + ``` + + It is recommended to accept tensors only in the first argument. This way, + all tensors are cast to the layer's dtype. `MyLayer` should therefore be + written as: + + ``` + class MyLayer(tf.keras.layers.Layer): + # Now, all tensor inputs will be cast. + def call(self, inputs): + a, b = inputs + return a + 1., b + 1. + + a = tf.constant(1., dtype="float32") + b = tf.constant(1., dtype="float32") + + layer = MyLayer(dtype="float64") + x, y = layer((a, b)) + print(x.dtype) # float64 + print(y.dtype) # float64. + ``` + + In a future minor release, tensors in other arguments may be cast as well. + + Currently, other arguments are not automatically cast for + technical reasons, but this may change in a future minor release. + + A layer subclass can prevent its inputs from being autocasted by passing `autocast=False` to the layer constructor.
For example: ``` class MyLayer(tf.keras.layers.Layer): - def __init__(**kwargs): + def __init__(self, **kwargs): kwargs['autocast']=False super(MyLayer, self).__init__(**kwargs) - def call(inp): + def call(self, inp): return inp x = tf.ones((4, 4, 4, 4), dtype='float64') @@ -188,8 +231,8 @@ class Layer(module.Module): ``` tf.keras.backend.set_floatx('float64') - layer1 = tf.keras.layers.Dense(4), - layer2 = tf.keras.layers.Dense(4), + layer1 = tf.keras.layers.Dense(4) + layer2 = tf.keras.layers.Dense(4) x = tf.ones((4, 4)) y = layer2(layer1(x)) # Both layers run in float64 @@ -201,18 +244,18 @@ class Layer(module.Module): well: ``` - layer1 = tf.keras.layers.Dense(4, dtype='float64'), - layer2 = tf.keras.layers.Dense(4, dtype='float64), + layer1 = tf.keras.layers.Dense(4, dtype='float64') + layer2 = tf.keras.layers.Dense(4, dtype='float64') x = tf.ones((4, 4)) y = layer2(layer1(x)) # Both layers run in float64 class NestedLayer(tf.keras.layers.Layer): - def __init__(**kwargs): - super(MyLayer, self).__init__(**kwargs) + def __init__(self, **kwargs): + super(NestedLayer, self).__init__(**kwargs) self.dense = tf.keras.layers.Dense(4, dtype=kwargs.get('dtype')) - def call(inp): + def call(self, inp): return self.dense(inp) layer3 = NestedLayer(dtype='float64')