Merge pull request #32272 from reedwm/docstring_cherrypicks
[r2.0-rc1 CherryPick]: Improve Layer docstrings in regards to autocasting.
commit d2e5f5a49e
@ -160,17 +160,60 @@ class Layer(module.Module):
y = layer(x)
```

Currently, only tensors in the first argument to the layer's `call` method are
casted. For example:

```
class MyLayer(tf.keras.layers.Layer):
  # Bug! `b` will not be casted.
  def call(self, a, b):
    return a + 1., b + 1.

a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")

layer = MyLayer(dtype="float64")
x, y = layer(a, b)
print(x.dtype)  # float64
print(y.dtype)  # float32. Not casted since `b` was not passed in the first argument
```

It is recommended to accept tensors only in the first argument. This way,
all tensors are casted to the layer's dtype. `MyLayer` should therefore be
written as:

```
class MyLayer(tf.keras.layers.Layer):
  # Now, all tensor inputs will be casted.
  def call(self, inputs):
    a, b = inputs
    return a + 1., b + 1.

a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")

layer = MyLayer(dtype="float64")
x, y = layer((a, b))
print(x.dtype)  # float64
print(y.dtype)  # float64
```

Currently, tensors in other arguments are not automatically casted for
technical reasons, but this may change in a future minor release.
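
Until then, a layer that needs tensors outside the first argument can cast them
itself in `call`. A minimal sketch (using `tf.cast` and the layer's `dtype`
attribute; illustrative only, not part of this change):

```
class MyLayer(tf.keras.layers.Layer):
  # `b` is not in the first argument, so it is not autocasted; cast it manually.
  def call(self, a, b):
    b = tf.cast(b, self.dtype)
    return a + 1., b + 1.

a = tf.constant(1., dtype="float32")
b = tf.constant(1., dtype="float32")

layer = MyLayer(dtype="float64")
x, y = layer(a, b)
print(x.dtype)  # float64
print(y.dtype)  # float64, because `b` was casted manually
```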

A layer subclass can prevent its inputs from being autocasted by passing
`autocast=False` to the layer constructor. For example:

```
class MyLayer(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    kwargs['autocast'] = False
    super(MyLayer, self).__init__(**kwargs)

  def call(self, inp):
    return inp

x = tf.ones((4, 4, 4, 4), dtype='float64')
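
# Illustrative sketch (not part of the original example): with autocasting
# disabled, the float64 input passes through this layer uncasted.
layer = MyLayer()
y = layer(x)
print(y.dtype)  # float64
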
@ -188,8 +231,8 @@ class Layer(module.Module):

```
tf.keras.backend.set_floatx('float64')
layer1 = tf.keras.layers.Dense(4)
layer2 = tf.keras.layers.Dense(4)

x = tf.ones((4, 4))
y = layer2(layer1(x))  # Both layers run in float64
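
# Sketch (assumes Dense's `kernel` attribute, created once the layer is built):
# the variables are float64 as well, since variable dtype follows the layer dtype.
print(layer1.kernel.dtype)  # float64
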
@ -201,18 +244,18 @@ class Layer(module.Module):
well:

```
layer1 = tf.keras.layers.Dense(4, dtype='float64')
layer2 = tf.keras.layers.Dense(4, dtype='float64')

x = tf.ones((4, 4))
y = layer2(layer1(x))  # Both layers run in float64

class NestedLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(NestedLayer, self).__init__(**kwargs)
    self.dense = tf.keras.layers.Dense(4, dtype=kwargs.get('dtype'))

  def call(self, inp):
    return self.dense(inp)

layer3 = NestedLayer(dtype='float64')
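
# Sketch (not part of the original example): `x` (float32) is autocasted to
# NestedLayer's float64 dtype, and the inner Dense layer also runs in float64.
y2 = layer3(x)
print(y2.dtype)  # float64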