Fix crashes when float64 or mixed precision is used in certain layers.
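
For context, a minimal repro sketch of the kind of failure being fixed (the choice of layer and dtype here is illustrative, not copied from the linked issues): a layer whose internal casts hard-code float32 breaks as soon as its inputs are float64 or float16.

import tensorflow as tf

# Hypothetical repro: before this change, ThresholdedReLU built its mask in
# K.floatx() (float32), so multiplying it with float64 inputs raised a dtype
# mismatch error.
layer = tf.keras.layers.ThresholdedReLU(theta=0.5, dtype='float64')
x = tf.random.normal((2, 2), dtype=tf.float64)
y = layer(x)  # crashed before the fix; returns a float64 tensor after it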

Also add many more tests to layer_correctness_test.py, so that most Keras layers are tested.

Also, do not test distribution strategy without mixed precision in layer_correctness_test.py. That combination was previously tested only to make debugging easier when a test failed, but the distribution strategy tests take a very long time, so testing a distribution strategy without mixed precision is not worth it.

I made the test a py_test instead of a cuda_py_test to save a bit of GPU resources.

Fixes #35817 and fixes #35883.

PiperOrigin-RevId: 291825015
Change-Id: I8ece6de5b9d549f0de06b643774686f56775781e
Reed Wanderman-Milne 2020-01-27 16:23:55 -08:00 committed by TensorFlower Gardener
parent e8376142f5
commit f9e899854c
7 changed files with 186 additions and 88 deletions


@@ -246,8 +246,8 @@ class ThresholdedReLU(Layer):
self.theta = K.cast_to_floatx(theta)
def call(self, inputs):
return inputs * math_ops.cast(
math_ops.greater(inputs, self.theta), K.floatx())
theta = math_ops.cast(self.theta, inputs.dtype)
return inputs * math_ops.cast(math_ops.greater(inputs, theta), inputs.dtype)
def get_config(self):
config = {'theta': float(self.theta)}
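
The mask now follows the input's compute dtype instead of K.floatx(). A rough sketch of the old versus new mask computation, assuming float16 inputs under a 'mixed_float16' policy (not taken verbatim from the diff):

import tensorflow as tf

inputs = tf.random.normal((2, 2), dtype=tf.float16)
theta = 1.0

# Old: the mask was float32 (K.floatx()), so inputs * mask mixed float16 and
# float32 and failed.
old_mask = tf.cast(tf.greater(inputs, theta), tf.keras.backend.floatx())

# New: theta and the mask are cast to inputs.dtype, so the multiply stays in
# float16.
new_mask = tf.cast(tf.greater(inputs, tf.cast(theta, inputs.dtype)), inputs.dtype)
outputs = inputs * new_mask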


@@ -283,7 +283,8 @@ class ConvRNN2D(RNN):
shape = list(self.cell.kernel_shape)
shape[-1] = self.cell.filters
initial_state = self.cell.input_conv(initial_state,
array_ops.zeros(tuple(shape)),
array_ops.zeros(tuple(shape),
initial_state.dtype),
padding=self.cell.padding)
if hasattr(self.cell.state_size, '__len__'):
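
The point of this hunk: array_ops.zeros defaults to float32, which clashes with a non-float32 initial state. A rough illustration with made-up shapes:

import tensorflow as tf

initial_state = tf.zeros((1, 4, 4, 4), dtype=tf.float16)  # e.g. mixed precision state
kernel_zeros = tf.zeros((3, 3, 4, 4))                      # defaults to float32
# Convolving a float16 state with a float32 kernel raises a dtype error, so the
# zeros are now created in the state's dtype instead:
kernel_zeros = tf.zeros((3, 3, 4, 4), dtype=initial_state.dtype)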


@@ -187,7 +187,7 @@ class AlphaDropout(Layer):
kept_idx = math_ops.greater_equal(
K.random_uniform(noise_shape, seed=seed), rate)
kept_idx = math_ops.cast(kept_idx, K.floatx())
kept_idx = math_ops.cast(kept_idx, inputs.dtype)
# Get affine transformation params
a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
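
With the keep mask cast to inputs.dtype, AlphaDropout no longer assumes float32 inputs. A small usage sketch (the rate and shapes are arbitrary):

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.AlphaDropout(0.5, dtype='float64')
x = np.random.normal(size=(8, 4)).astype('float64')
# Previously the float32 keep mask could not be combined with the float64
# inputs during training; now the mask matches the input dtype.
y = layer(x, training=True)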


@@ -47,16 +47,18 @@ class NoiseLayersTest(keras_parameterized.TestCase):
keras.layers.AlphaDropout, kwargs={'rate': 0.2}, input_shape=(3, 2, 3))
@staticmethod
def _make_model(dtype, gtype):
def _make_model(dtype, class_type):
assert dtype in (dtypes_module.float32, dtypes_module.float64)
assert gtype in ('noise', 'dropout')
assert class_type in ('gaussian_noise', 'gaussian_dropout', 'alpha_noise')
model = keras.Sequential()
model.add(keras.layers.Dense(8, input_shape=(32,), dtype=dtype))
if gtype == 'noise':
gaussian = keras.layers.GaussianNoise(0.0003)
if class_type == 'gaussian_noise':
layer = keras.layers.GaussianNoise(0.0003, dtype=dtype)
elif class_type == 'gaussian_dropout':
layer = keras.layers.GaussianDropout(0.1, dtype=dtype)
else:
gaussian = keras.layers.GaussianDropout(0.1)
model.add(gaussian)
layer = keras.layers.AlphaDropout(0.5, dtype=dtype)
model.add(layer)
return model
def _train_model(self, dtype, gtype):
@@ -68,16 +70,22 @@ class NoiseLayersTest(keras_parameterized.TestCase):
model.train_on_batch(np.zeros((8, 32)), np.zeros((8, 8)))
def test_noise_float32(self):
self._train_model(dtypes_module.float32, 'noise')
self._train_model(dtypes_module.float32, 'gaussian_noise')
def test_noise_float64(self):
self._train_model(dtypes_module.float64, 'noise')
self._train_model(dtypes_module.float64, 'gaussian_noise')
def test_dropout_float32(self):
self._train_model(dtypes_module.float32, 'dropout')
self._train_model(dtypes_module.float32, 'gaussian_dropout')
def test_dropout_float64(self):
self._train_model(dtypes_module.float64, 'dropout')
self._train_model(dtypes_module.float64, 'gaussian_dropout')
def test_alpha_dropout_float32(self):
self._train_model(dtypes_module.float32, 'alpha_noise')
def test_alpha_dropout_float64(self):
self._train_model(dtypes_module.float64, 'alpha_noise')
if __name__ == '__main__':


@@ -1112,9 +1112,9 @@ class LayerNormalization(Layer):
# self.gamma and self.beta have the wrong shape for fused_batch_norm, so
# we cannot pass them as the scale and offset parameters. Therefore, we
# create two constant tensors in correct shapes for fused_batch_norm and
# later contuct a separate calculation on the scale and offset.
scale = _set_const_tensor(1.0, inputs.dtype, [pre_dim])
offset = _set_const_tensor(0.0, inputs.dtype, [pre_dim])
# later construct a separate calculation on the scale and offset.
scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
offset = _set_const_tensor(0.0, self.dtype, [pre_dim])
# Compute layer normalization using the fused_batch_norm function.
outputs, _, _ = nn.fused_batch_norm(
@@ -1129,9 +1129,9 @@ class LayerNormalization(Layer):
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
if scale is not None:
outputs = outputs * scale
outputs = outputs * math_ops.cast(scale, outputs.dtype)
if offset is not None:
outputs = outputs + offset
outputs = outputs + math_ops.cast(offset, outputs.dtype)
# If some components of the shape got lost due to adjustments, fix that.
outputs.set_shape(input_shape)
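
For reference, a hedged sketch of the dtype split this hunk handles, using the experimental mixed precision API as it existed around this commit (treat the exact API names as assumptions):

import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
layer = tf.keras.layers.LayerNormalization()
y = layer(tf.random.normal((2, 4), dtype=tf.float16))
# gamma/beta and self.dtype are float32 (the variable dtype), while the outputs
# are float16 (the compute dtype). The constant scale/offset passed to
# fused_batch_norm therefore use self.dtype, and gamma/beta are cast to
# outputs.dtype before being applied.
print(layer.dtype, layer.gamma.dtype, y.dtype)  # float32, float32, float16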


@@ -223,14 +223,15 @@ cuda_py_test(
],
)
cuda_py_test(
py_test(
name = "layer_correctness_test",
size = "medium",
srcs = ["layer_correctness_test.py"],
python_version = "PY3",
tags = ["no_rocm"],
shard_count = 10,
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:one_device_strategy",
"//tensorflow/python/keras",


@@ -17,125 +17,213 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.framework import config as config_module
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras.layers import convolutional_recurrent
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import dense_attention
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.layers import local
from tensorflow.python.keras.layers import merge
from tensorflow.python.keras.layers import noise
from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.layers import normalization_v2
from tensorflow.python.keras.layers import pooling
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.keras.layers import wrappers
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.platform import test
def create_mirrored_strategy():
if context.num_gpus() >= 1:
return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
else:
return mirrored_strategy.MirroredStrategy(['cpu:0'])
# The test creates two virtual CPUs, and we use both of them to test with
# multiple devices.
return mirrored_strategy.MirroredStrategy(['cpu:0', 'cpu:1'])
@test_util.run_all_in_graph_and_eager_modes
class LayerCorrectnessTest(keras_parameterized.TestCase):
def _create_model_from_layer(self, layer, input_shape):
x = layers.Input(batch_input_shape=input_shape)
y = layer(x)
model = models.Model(x, y)
def setUp(self):
super(LayerCorrectnessTest, self).setUp()
# Set two virtual CPUs to test MirroredStrategy with multiple devices
cpus = config_module.list_physical_devices('CPU')
config_module.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
])
def _create_model_from_layer(self, layer, input_shapes):
inputs = [layers.Input(batch_input_shape=s) for s in input_shapes]
if len(inputs) == 1:
inputs = inputs[0]
y = layer(inputs)
model = models.Model(inputs, y)
model.compile('sgd', 'mse')
return model
def _test_layer(self, f32_layer, input_shape):
@parameterized.named_parameters(
('LeakyReLU', advanced_activations.LeakyReLU, (2, 2)),
('PReLU', advanced_activations.PReLU, (2, 2)),
('ELU', advanced_activations.ELU, (2, 2)),
('ThresholdedReLU', advanced_activations.ThresholdedReLU, (2, 2)),
('Softmax', advanced_activations.Softmax, (2, 2)),
('ReLU', advanced_activations.ReLU, (2, 2)),
('Conv1D', lambda: convolutional.Conv1D(2, 2), (2, 2, 1)),
('Conv2D', lambda: convolutional.Conv2D(2, 2), (2, 2, 2, 1)),
('Conv3D', lambda: convolutional.Conv3D(2, 2), (2, 2, 2, 2, 1)),
('Conv2DTranspose', lambda: convolutional.Conv2DTranspose(2, 2),
(2, 2, 2, 2)),
('SeparableConv2D', lambda: convolutional.SeparableConv2D(2, 2),
(2, 2, 2, 1)),
('DepthwiseConv2D', lambda: convolutional.DepthwiseConv2D(2, 2),
(2, 2, 2, 1)),
('UpSampling2D', convolutional.UpSampling2D, (2, 2, 2, 1)),
('ZeroPadding2D', convolutional.ZeroPadding2D, (2, 2, 2, 1)),
('Cropping2D', convolutional.Cropping2D, (2, 3, 3, 1)),
('ConvLSTM2D',
lambda: convolutional_recurrent.ConvLSTM2D(4, kernel_size=(2, 2)),
(4, 4, 4, 4, 4)),
('Dense', lambda: core.Dense(2), (2, 2)),
('Dropout', lambda: core.Dropout(0.5), (2, 2)),
('SpatialDropout2D', lambda: core.SpatialDropout2D(0.5), (2, 2, 2, 2)),
('Activation', lambda: core.Activation('sigmoid'), (2, 2)),
('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)),
('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)),
('Attention', dense_attention.Attention,
[(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
('AdditiveAttention', dense_attention.AdditiveAttention,
[(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
('Embedding', lambda: embeddings.Embedding(4, 4), (2, 4), 2e-3, 2e-3,
np.random.randint(4, size=(2, 4))),
('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2), (2, 2, 1)),
('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2),
(2, 2, 2, 1)),
('Add', merge.Add, [(2, 2), (2, 2)]),
('Subtract', merge.Subtract, [(2, 2), (2, 2)]),
('Multiply', merge.Multiply, [(2, 2), (2, 2)]),
('Average', merge.Average, [(2, 2), (2, 2)]),
('Maximum', merge.Maximum, [(2, 2), (2, 2)]),
('Minimum', merge.Minimum, [(2, 2), (2, 2)]),
('Concatenate', merge.Concatenate, [(2, 2), (2, 2)]),
('Dot', lambda: merge.Dot(1), [(2, 2), (2, 2)]),
('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)),
('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)),
('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)),
('BatchNormalization', normalization_v2.BatchNormalization, (2, 2),
1e-2, 1e-2),
('LayerNormalization', normalization.LayerNormalization, (2, 2)),
('MaxPooling2D', pooling.MaxPooling2D, (2, 2, 2, 1)),
('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)),
('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)),
('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D, (2, 2, 2, 1)),
('SimpleRNN', lambda: recurrent.SimpleRNN(units=4), (4, 4, 4),
1e-2, 1e-2),
('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)),
('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)),
('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)),
('LSTMV2', lambda: recurrent_v2.LSTM(units=4), (4, 4, 4)),
('TimeDistributed', lambda: wrappers.TimeDistributed(core.Dense(2)),
(2, 2, 2)),
('Bidirectional',
lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)), (2, 2, 2)),
)
def test_layer(self, f32_layer_fn, input_shape, rtol=2e-3, atol=2e-3,
input_data=None):
"""Tests a layer by comparing the float32 and mixed precision weights.
A float32 layer, a mixed precision layer, a distributed float32 layer, and a
distributed mixed precision layer are run. The four layers are identical
other than their dtypes and distribution strategies. The weights after
running fit() are asserted to be close.
Running the distributed float32 layer does not test mixed precision but we
still test it for debugging purposes. If the distributed mixed precision
layer fails, it's easier to debug if you know whether the issue also occurs
in the distributed float32 layer.
A float32 layer, a mixed precision layer, and a distributed mixed precision
layer are run. The three layers are identical other than their dtypes and
distribution strategies. The outputs after predict() and weights after fit()
are asserted to be close.
Args:
f32_layer: A float32 layer. The other three layers will automatically
be created from this
input_shape: The shape of the inputs to the layer, including the batch
dimension.
f32_layer_fn: A function returning a float32 layer. The other two layers
will automatically be created from this
input_shape: The shape of the input to the layer, including the batch
dimension. Or a list of shapes if the layer takes multiple inputs.
rtol: The relative tolerance to be asserted.
atol: The absolute tolerance to be asserted.
input_data: A Numpy array with the data of the input. If None, input data
will be randomly generated
"""
if isinstance(input_shape[0], int):
input_shapes = [input_shape]
else:
input_shapes = input_shape
strategy = create_mirrored_strategy()
f32_layer = f32_layer_fn()
# Create the layers
assert f32_layer.dtype == f32_layer._compute_dtype == 'float32'
config = f32_layer.get_config()
distributed_f32_layer = f32_layer.__class__.from_config(config)
config['dtype'] = policy.Policy('mixed_float16')
mp_layer = f32_layer.__class__.from_config(config)
distributed_mp_layer = f32_layer.__class__.from_config(config)
# Compute per_replica_input_shape for the distributed models
global_batch_size = input_shape[0]
assert global_batch_size % strategy.num_replicas_in_sync == 0
# Compute per_replica_input_shapes for the distributed model
global_batch_size = input_shapes[0][0]
assert global_batch_size % strategy.num_replicas_in_sync == 0, (
'The number of replicas, %d, does not divide the global batch size of '
'%d' % (strategy.num_replicas_in_sync, global_batch_size))
per_replica_batch_size = (
global_batch_size // strategy.num_replicas_in_sync)
per_replica_input_shape = list(input_shape)
per_replica_input_shape[0] = per_replica_batch_size
per_replica_input_shapes = [(per_replica_batch_size,) + s[1:]
for s in input_shapes]
# Create the models
f32_model = self._create_model_from_layer(f32_layer, input_shape)
mp_model = self._create_model_from_layer(mp_layer, input_shape)
f32_model = self._create_model_from_layer(f32_layer, input_shapes)
mp_model = self._create_model_from_layer(mp_layer, input_shapes)
with strategy.scope():
distributed_f32_model = self._create_model_from_layer(
distributed_f32_layer, per_replica_input_shape)
distributed_mp_model = self._create_model_from_layer(
distributed_mp_layer, per_replica_input_shape)
distributed_mp_layer, per_replica_input_shapes)
# Set all model weights to the same values
f32_weights = f32_model.get_weights()
for model in mp_model, distributed_f32_model, distributed_mp_model:
model.set_weights(f32_weights)
mp_model.set_weights(f32_weights)
distributed_mp_model.set_weights(f32_weights)
# Generate input data
if input_data is None:
# Cast inputs to float16 to avoid measuring error from having f16 layers
# cast to float16.
input_data = [np.random.normal(size=s).astype('float16')
for s in input_shapes]
if len(input_data) == 1:
input_data = input_data[0]
# Assert all models have close outputs.
f32_output = f32_model.predict(input_data)
mp_output = mp_model.predict(input_data)
self.assertAllClose(
mp_output, f32_output, rtol=rtol, atol=atol)
self.assertAllClose(
distributed_mp_model.predict(input_data), f32_output, rtol=rtol,
atol=atol)
# Run fit() on models
x = np.random.normal(size=input_shape)
y = np.random.normal(size=input_shape)
for model in (f32_model, mp_model, distributed_f32_model,
distributed_mp_model):
model.fit(x, y, batch_size=global_batch_size)
output = np.random.normal(size=f32_model.outputs[0].shape).astype('float16')
for model in f32_model, mp_model, distributed_mp_model:
model.fit(input_data, output, batch_size=global_batch_size)
# Assert all models have close weights
f32_weights = f32_model.get_weights()
self.assertAllClose(
mp_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4)
mp_model.get_weights(), f32_weights, rtol=rtol, atol=atol)
self.assertAllClose(
distributed_f32_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4)
self.assertAllClose(
distributed_mp_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4)
distributed_mp_model.get_weights(), f32_weights, rtol=rtol, atol=atol)
# Note: There is no need to test every layer subclass here, as otherwise this
# test would take too long. Only layers which do something special or are
# unusual in regards to mixed precision need to be tested.
# We test RNNs as some RNNs use the implementation_selector grappler pass,
# which can cause issues with AutoCastVariables.
@testing_utils.enable_v2_dtype_behavior
def test_simple_rnn(self):
self._test_layer(recurrent.SimpleRNN(units=4, return_sequences=True),
input_shape=(4, 4, 4))
@testing_utils.enable_v2_dtype_behavior
def test_gru(self):
self._test_layer(recurrent_v2.GRU(units=4, return_sequences=True),
input_shape=(4, 4, 4))
@testing_utils.enable_v2_dtype_behavior
def test_lstm(self):
self._test_layer(recurrent_v2.LSTM(units=4, return_sequences=True),
input_shape=(4, 4, 4))
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()
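
To summarize the pattern test_layer applies to each entry above, here is a rough standalone sketch outside the test harness (a Dense layer and the default tolerances are used purely as an example):

import numpy as np
import tensorflow as tf

f32_layer = tf.keras.layers.Dense(2, dtype='float32')
mp_layer = tf.keras.layers.Dense(
    2, dtype=tf.keras.mixed_precision.experimental.Policy('mixed_float16'))

x32 = tf.keras.Input(shape=(2,))
f32_model = tf.keras.Model(x32, f32_layer(x32))
x16 = tf.keras.Input(shape=(2,))
mp_model = tf.keras.Model(x16, mp_layer(x16))

# Give both models the same starting weights, then check that float32 and
# mixed precision outputs are close.
mp_model.set_weights(f32_model.get_weights())
data = np.random.normal(size=(2, 2)).astype('float16')
np.testing.assert_allclose(
    mp_model.predict(data), f32_model.predict(data), rtol=2e-3, atol=2e-3)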