Merge pull request #32272 from reedwm/docstring_cherrypicks

[r2.0-rc1 CherryPick]: Improve Layer docstrings in regards to autocasting.
Goldie Gadde, 2019-09-06 14:17:49 -07:00 (committed by GitHub), commit d2e5f5a49e

@@ -160,17 +160,60 @@ class Layer(module.Module):
   y = layer(x)
   ```
-  A layer subclass can prevent its inputs from being autocasted by
+  Currently, only tensors in the first argument to the layer's `call` method are
+  casted. For example:
+
+  ```
+  class MyLayer(tf.keras.layers.Layer):
+    # Bug! `b` will not be casted.
+    def call(self, a, b):
+      return a + 1., b + 1.
+
+  a = tf.constant(1., dtype="float32")
+  b = tf.constant(1., dtype="float32")
+  layer = MyLayer(dtype="float64")
+  x, y = layer(a, b)
+  print(x.dtype)  # float64
+  print(y.dtype)  # float32. Not casted since `b` was not passed to first input
+  ```
+
+  It is recommended to accept tensors only in the first argument. This way,
+  all tensors are casted to the layer's dtype. `MyLayer` should therefore be
+  written as:
+
+  ```
+  class MyLayer(tf.keras.layers.Layer):
+    # Now, all tensor inputs will be casted.
+    def call(self, inputs):
+      a, b = inputs
+      return a + 1., b + 1.
+
+  a = tf.constant(1., dtype="float32")
+  b = tf.constant(1., dtype="float32")
+  layer = MyLayer(dtype="float64")
+  x, y = layer((a, b))
+  print(x.dtype)  # float64
+  print(y.dtype)  # float64.
+  ```
+
-  In a future minor release, tensors in other arguments may be casted as well.
+  Currently, other arguments are not automatically casted for
+  technical reasons, but this may change in a future minor release.
+
+  A layer subclass can prevent its inputs from being autocasted by passing
+  `autocast=False` to the layer constructor. For example:
   ```
   class MyLayer(tf.keras.layers.Layer):
-    def __init__(**kwargs):
+    def __init__(self, **kwargs):
       kwargs['autocast']=False
       super(MyLayer, self).__init__(**kwargs)
-    def call(inp):
+    def call(self, inp):
       return inp
   x = tf.ones((4, 4, 4, 4), dtype='float64')
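Both behaviors documented in the hunk above can be exercised with a short script. This is a minimal sketch assuming TensorFlow 2.0, as on the release branch; the class names `FirstArgOnly` and `NoAutocast` are illustrative, not part of the commit:

```
import tensorflow as tf

class FirstArgOnly(tf.keras.layers.Layer):
  # All tensors in the first `call` argument are casted to the layer dtype.
  def call(self, inputs):
    a, b = inputs
    return a + 1., b + 1.

class NoAutocast(tf.keras.layers.Layer):
  # Opts out of input autocasting via the constructor kwarg shown above.
  def __init__(self, **kwargs):
    kwargs['autocast'] = False
    super(NoAutocast, self).__init__(**kwargs)

  def call(self, inp):
    return inp

a = tf.constant(1., dtype='float32')
b = tf.constant(1., dtype='float32')

# Both tensors travel in the first argument, so both are casted.
x, y = FirstArgOnly(dtype='float64')((a, b))
print(x.dtype, y.dtype)  # float64 float64

# With autocast disabled, the float32 input passes through unchanged.
z = NoAutocast(dtype='float64')(tf.ones((4, 4), dtype='float32'))
print(z.dtype)  # float32
```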
@@ -188,8 +231,8 @@ class Layer(module.Module):
   ```
   tf.keras.backend.set_floatx('float64')
-  layer1 = tf.keras.layers.Dense(4),
-  layer2 = tf.keras.layers.Dense(4),
+  layer1 = tf.keras.layers.Dense(4)
+  layer2 = tf.keras.layers.Dense(4)
   x = tf.ones((4, 4))
   y = layer2(layer1(x))  # Both layers run in float64
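For a runnable version of the hunk above, the only additions needed are the import and restoring the default floatx afterwards, again a sketch assuming TensorFlow 2.0:

```
import tensorflow as tf

tf.keras.backend.set_floatx('float64')  # layers created from here on default to float64
layer1 = tf.keras.layers.Dense(4)
layer2 = tf.keras.layers.Dense(4)

x = tf.ones((4, 4))           # float32; casted to float64 by layer1
y = layer2(layer1(x))         # Both layers run in float64
print(layer1.dtype, y.dtype)  # float64 float64

tf.keras.backend.set_floatx('float32')  # restore the default
```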
@@ -201,18 +244,18 @@ class Layer(module.Module):
   well:
   ```
-  layer1 = tf.keras.layers.Dense(4, dtype='float64'),
-  layer2 = tf.keras.layers.Dense(4, dtype='float64),
+  layer1 = tf.keras.layers.Dense(4, dtype='float64')
+  layer2 = tf.keras.layers.Dense(4, dtype='float64')
   x = tf.ones((4, 4))
   y = layer2(layer1(x))  # Both layers run in float64
   class NestedLayer(tf.keras.layers.Layer):
-    def __init__(**kwargs):
-      super(MyLayer, self).__init__(**kwargs)
+    def __init__(self, **kwargs):
+      super(NestedLayer, self).__init__(**kwargs)
       self.dense = tf.keras.layers.Dense(4, dtype=kwargs.get('dtype'))
-    def call(inp):
+    def call(self, inp):
       return self.dense(inp)
   layer3 = NestedLayer(dtype='float64')
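Assembled as a self-contained script (a sketch assuming TensorFlow 2.0), the corrected `NestedLayer` makes the point of this hunk explicit: the outer dtype must be forwarded by hand, because the inner `Dense` would otherwise fall back to the global floatx rather than inherit it:

```
import tensorflow as tf

class NestedLayer(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(NestedLayer, self).__init__(**kwargs)
    # Forward the dtype explicitly; without it, self.dense would be float32.
    self.dense = tf.keras.layers.Dense(4, dtype=kwargs.get('dtype'))

  def call(self, inp):
    return self.dense(inp)

layer3 = NestedLayer(dtype='float64')
y = layer3(tf.ones((4, 4)))
print(layer3.dense.dtype, y.dtype)  # float64 float64
```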