Fix crash when float64 or mixed precision is used in certain layers.

Also add many more tests to layer_correctness_test.py, so that most Keras layers are covered. Additionally, stop testing distribution strategy without mixed precision in layer_correctness_test.py: that combination was previously tested only to make debugging easier when a test failed, but the distribution-strategy tests take a very long time, so running them without mixed precision is not worth it. The test is now a py_test instead of a cuda_py_test to save a bit of GPU resources.

Fixes #35817 and fixes #35883.

PiperOrigin-RevId: 291825015
Change-Id: I8ece6de5b9d549f0de06b643774686f56775781e
commit f9e899854c (parent e8376142f5)
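Editorial note, not part of the diff: the sketch below illustrates the pattern this fix applies in several layers, namely casting layer-internal constants to the input's dtype instead of the global floatx, so float64 and float16/mixed-precision inputs no longer hit dtype-mismatch errors. ScaledThreshold is a hypothetical toy layer used only for illustration.

import tensorflow as tf

class ScaledThreshold(tf.keras.layers.Layer):
  """Hypothetical toy layer showing the dtype-cast pattern used by this fix."""

  def __init__(self, theta=1.0, **kwargs):
    super(ScaledThreshold, self).__init__(**kwargs)
    self.theta = theta

  def call(self, inputs):
    # Before the fix, similar code compared against a float32 constant and
    # built a float32 mask via K.floatx(), which fails for float64 inputs.
    theta = tf.cast(self.theta, inputs.dtype)
    return inputs * tf.cast(inputs > theta, inputs.dtype)

layer = ScaledThreshold(0.5, dtype='float64')
print(layer(tf.constant([[0.2, 0.8]], dtype=tf.float64)).dtype)  # float64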
@@ -246,8 +246,8 @@ class ThresholdedReLU(Layer):
     self.theta = K.cast_to_floatx(theta)
 
   def call(self, inputs):
-    return inputs * math_ops.cast(
-        math_ops.greater(inputs, self.theta), K.floatx())
+    theta = math_ops.cast(self.theta, inputs.dtype)
+    return inputs * math_ops.cast(math_ops.greater(inputs, theta), inputs.dtype)
 
   def get_config(self):
     config = {'theta': float(self.theta)}
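A quick, hedged check of the ThresholdedReLU change above (not part of the diff): with the threshold cast to the input's dtype, a float64 layer now returns float64 output instead of failing when the mask is built in float32.

import tensorflow as tf

x = tf.constant([[0.1, 0.9]], dtype=tf.float64)
y = tf.keras.layers.ThresholdedReLU(theta=0.5, dtype='float64')(x)
print(y.dtype)  # float64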
@@ -283,7 +283,8 @@ class ConvRNN2D(RNN):
     shape = list(self.cell.kernel_shape)
     shape[-1] = self.cell.filters
     initial_state = self.cell.input_conv(initial_state,
-                                         array_ops.zeros(tuple(shape)),
+                                         array_ops.zeros(tuple(shape),
+                                                         initial_state.dtype),
                                          padding=self.cell.padding)
 
     if hasattr(self.cell.state_size, '__len__'):
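For context (not part of the diff), the extra dtype argument matters because zeros default to float32, so a zero kernel built without an explicit dtype cannot be convolved with an initial state of a different dtype. A minimal standalone sketch using the public API:

import tensorflow as tf

state = tf.zeros((1, 4, 4, 2), dtype=tf.float64)       # float64 initial state
kernel = tf.zeros((2, 2, 2, 3), dtype=state.dtype)      # match the state's dtype
out = tf.nn.conv2d(state, kernel, strides=1, padding='SAME')
print(out.dtype)  # float64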
@@ -187,7 +187,7 @@ class AlphaDropout(Layer):
 
       kept_idx = math_ops.greater_equal(
           K.random_uniform(noise_shape, seed=seed), rate)
-      kept_idx = math_ops.cast(kept_idx, K.floatx())
+      kept_idx = math_ops.cast(kept_idx, inputs.dtype)
 
       # Get affine transformation params
       a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
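A hedged standalone check of the AlphaDropout change (not part of the diff): the keep mask now follows the input's dtype, so a training-mode call on float64 inputs keeps float64.

import tensorflow as tf

x = tf.random.normal((4, 8), dtype=tf.float64)
y = tf.keras.layers.AlphaDropout(0.2, dtype='float64')(x, training=True)
print(y.dtype)  # float64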
@@ -47,16 +47,18 @@ class NoiseLayersTest(keras_parameterized.TestCase):
         keras.layers.AlphaDropout, kwargs={'rate': 0.2}, input_shape=(3, 2, 3))
 
   @staticmethod
-  def _make_model(dtype, gtype):
+  def _make_model(dtype, class_type):
     assert dtype in (dtypes_module.float32, dtypes_module.float64)
-    assert gtype in ('noise', 'dropout')
+    assert class_type in ('gaussian_noise', 'gaussian_dropout', 'alpha_noise')
     model = keras.Sequential()
     model.add(keras.layers.Dense(8, input_shape=(32,), dtype=dtype))
-    if gtype == 'noise':
-      gaussian = keras.layers.GaussianNoise(0.0003)
+    if class_type == 'gaussian_noise':
+      layer = keras.layers.GaussianNoise(0.0003, dtype=dtype)
+    elif class_type == 'gaussian_dropout':
+      layer = keras.layers.GaussianDropout(0.1, dtype=dtype)
     else:
-      gaussian = keras.layers.GaussianDropout(0.1)
-    model.add(gaussian)
+      layer = keras.layers.AlphaDropout(0.5, dtype=dtype)
+    model.add(layer)
     return model
 
   def _train_model(self, dtype, gtype):
@@ -68,16 +70,22 @@ class NoiseLayersTest(keras_parameterized.TestCase):
     model.train_on_batch(np.zeros((8, 32)), np.zeros((8, 8)))
 
   def test_noise_float32(self):
-    self._train_model(dtypes_module.float32, 'noise')
+    self._train_model(dtypes_module.float32, 'gaussian_noise')
 
   def test_noise_float64(self):
-    self._train_model(dtypes_module.float64, 'noise')
+    self._train_model(dtypes_module.float64, 'gaussian_noise')
 
   def test_dropout_float32(self):
-    self._train_model(dtypes_module.float32, 'dropout')
+    self._train_model(dtypes_module.float32, 'gaussian_dropout')
 
   def test_dropout_float64(self):
-    self._train_model(dtypes_module.float64, 'dropout')
+    self._train_model(dtypes_module.float64, 'gaussian_dropout')
 
+  def test_alpha_dropout_float32(self):
+    self._train_model(dtypes_module.float32, 'alpha_noise')
+
+  def test_alpha_dropout_float64(self):
+    self._train_model(dtypes_module.float64, 'alpha_noise')
+
 
 if __name__ == '__main__':
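A standalone sketch of what the rewritten noise tests exercise (illustrative only, not the test code itself): a float64 Sequential model containing a noise layer should train for a batch without dtype errors. Layer choices and sizes loosely mirror the test helper above.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(32,), dtype='float64'),
    tf.keras.layers.GaussianNoise(0.0003, dtype='float64'),
    tf.keras.layers.Dense(8, dtype='float64'),
])
model.compile('sgd', 'mse')
model.train_on_batch(np.zeros((8, 32)), np.zeros((8, 8)))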
@@ -1112,9 +1112,9 @@ class LayerNormalization(Layer):
       # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
       # we cannot pass them as the scale and offset parameters. Therefore, we
       # create two constant tensors in correct shapes for fused_batch_norm and
-      # later contuct a separate calculation on the scale and offset.
-      scale = _set_const_tensor(1.0, inputs.dtype, [pre_dim])
-      offset = _set_const_tensor(0.0, inputs.dtype, [pre_dim])
+      # later constuct a separate calculation on the scale and offset.
+      scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
+      offset = _set_const_tensor(0.0, self.dtype, [pre_dim])
 
       # Compute layer normalization using the fused_batch_norm function.
       outputs, _, _ = nn.fused_batch_norm(
@@ -1129,9 +1129,9 @@ class LayerNormalization(Layer):
       scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
 
       if scale is not None:
-        outputs = outputs * scale
+        outputs = outputs * math_ops.cast(scale, outputs.dtype)
       if offset is not None:
-        outputs = outputs + offset
+        outputs = outputs + math_ops.cast(offset, outputs.dtype)
 
     # If some components of the shape got lost due to adjustments, fix that.
     outputs.set_shape(input_shape)
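A hedged sketch of why the casts above are needed (not part of the diff): under mixed precision the normalized output is float16 while gamma and beta remain float32 variables, so they must be cast to the output dtype before the affine step.

import tensorflow as tf

outputs = tf.random.normal((2, 4), dtype=tf.float16)  # stand-in for the normalized value
gamma = tf.Variable(tf.ones((4,)))                    # float32 variable
beta = tf.Variable(tf.zeros((4,)))                    # float32 variable
outputs = outputs * tf.cast(gamma, outputs.dtype) + tf.cast(beta, outputs.dtype)
print(outputs.dtype)  # float16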
@@ -223,14 +223,15 @@ cuda_py_test(
     ],
 )
 
-cuda_py_test(
+py_test(
     name = "layer_correctness_test",
     size = "medium",
     srcs = ["layer_correctness_test.py"],
     python_version = "PY3",
-    tags = ["no_rocm"],
+    shard_count = 10,
     deps = [
         "//tensorflow/python:client_testlib",
         "//tensorflow/python/compat:v2_compat",
         "//tensorflow/python/distribute:mirrored_strategy",
+        "//tensorflow/python/distribute:one_device_strategy",
         "//tensorflow/python/keras",
@@ -17,125 +17,213 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
-from tensorflow.python.framework import test_util
+from tensorflow.python.framework import config as config_module
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import layers
 from tensorflow.python.keras import models
 from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.layers import advanced_activations
+from tensorflow.python.keras.layers import convolutional
+from tensorflow.python.keras.layers import convolutional_recurrent
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import dense_attention
+from tensorflow.python.keras.layers import embeddings
+from tensorflow.python.keras.layers import local
+from tensorflow.python.keras.layers import merge
+from tensorflow.python.keras.layers import noise
+from tensorflow.python.keras.layers import normalization
+from tensorflow.python.keras.layers import normalization_v2
+from tensorflow.python.keras.layers import pooling
 from tensorflow.python.keras.layers import recurrent
 from tensorflow.python.keras.layers import recurrent_v2
+from tensorflow.python.keras.layers import wrappers
 from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.platform import test
 
 
 def create_mirrored_strategy():
-  if context.num_gpus() >= 1:
-    return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
-  else:
-    return mirrored_strategy.MirroredStrategy(['cpu:0'])
+  # The test creates two virtual CPUs, and we use both of them to test with
+  # multiple devices.
+  return mirrored_strategy.MirroredStrategy(['cpu:0', 'cpu:1'])
 
 
-@test_util.run_all_in_graph_and_eager_modes
 class LayerCorrectnessTest(keras_parameterized.TestCase):
 
-  def _create_model_from_layer(self, layer, input_shape):
-    x = layers.Input(batch_input_shape=input_shape)
-    y = layer(x)
-    model = models.Model(x, y)
+  def setUp(self):
+    super(LayerCorrectnessTest, self).setUp()
+    # Set two virtual CPUs to test MirroredStrategy with multiple devices
+    cpus = config_module.list_physical_devices('CPU')
+    config_module.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+    ])
+
+  def _create_model_from_layer(self, layer, input_shapes):
+    inputs = [layers.Input(batch_input_shape=s) for s in input_shapes]
+    if len(inputs) == 1:
+      inputs = inputs[0]
+    y = layer(inputs)
+    model = models.Model(inputs, y)
     model.compile('sgd', 'mse')
     return model
 
-  def _test_layer(self, f32_layer, input_shape):
+  @parameterized.named_parameters(
+      ('LeakyReLU', advanced_activations.LeakyReLU, (2, 2)),
+      ('PReLU', advanced_activations.PReLU, (2, 2)),
+      ('ELU', advanced_activations.ELU, (2, 2)),
+      ('ThresholdedReLU', advanced_activations.ThresholdedReLU, (2, 2)),
+      ('Softmax', advanced_activations.Softmax, (2, 2)),
+      ('ReLU', advanced_activations.ReLU, (2, 2)),
+      ('Conv1D', lambda: convolutional.Conv1D(2, 2), (2, 2, 1)),
+      ('Conv2D', lambda: convolutional.Conv2D(2, 2), (2, 2, 2, 1)),
+      ('Conv3D', lambda: convolutional.Conv3D(2, 2), (2, 2, 2, 2, 1)),
+      ('Conv2DTranspose', lambda: convolutional.Conv2DTranspose(2, 2),
+       (2, 2, 2, 2)),
+      ('SeparableConv2D', lambda: convolutional.SeparableConv2D(2, 2),
+       (2, 2, 2, 1)),
+      ('DepthwiseConv2D', lambda: convolutional.DepthwiseConv2D(2, 2),
+       (2, 2, 2, 1)),
+      ('UpSampling2D', convolutional.UpSampling2D, (2, 2, 2, 1)),
+      ('ZeroPadding2D', convolutional.ZeroPadding2D, (2, 2, 2, 1)),
+      ('Cropping2D', convolutional.Cropping2D, (2, 3, 3, 1)),
+      ('ConvLSTM2D',
+       lambda: convolutional_recurrent.ConvLSTM2D(4, kernel_size=(2, 2)),
+       (4, 4, 4, 4, 4)),
+      ('Dense', lambda: core.Dense(2), (2, 2)),
+      ('Dropout', lambda: core.Dropout(0.5), (2, 2)),
+      ('SpatialDropout2D', lambda: core.SpatialDropout2D(0.5), (2, 2, 2, 2)),
+      ('Activation', lambda: core.Activation('sigmoid'), (2, 2)),
+      ('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)),
+      ('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)),
+      ('Attention', dense_attention.Attention,
+       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
+      ('AdditiveAttention', dense_attention.AdditiveAttention,
+       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
+      ('Embedding', lambda: embeddings.Embedding(4, 4), (2, 4), 2e-3, 2e-3,
+       np.random.randint(4, size=(2, 4))),
+      ('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2), (2, 2, 1)),
+      ('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2),
+       (2, 2, 2, 1)),
+      ('Add', merge.Add, [(2, 2), (2, 2)]),
+      ('Subtract', merge.Subtract, [(2, 2), (2, 2)]),
+      ('Multiply', merge.Multiply, [(2, 2), (2, 2)]),
+      ('Average', merge.Average, [(2, 2), (2, 2)]),
+      ('Maximum', merge.Maximum, [(2, 2), (2, 2)]),
+      ('Minimum', merge.Minimum, [(2, 2), (2, 2)]),
+      ('Concatenate', merge.Concatenate, [(2, 2), (2, 2)]),
+      ('Dot', lambda: merge.Dot(1), [(2, 2), (2, 2)]),
+      ('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)),
+      ('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)),
+      ('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)),
+      ('BatchNormalization', normalization_v2.BatchNormalization, (2, 2),
+       1e-2, 1e-2),
+      ('LayerNormalization', normalization.LayerNormalization, (2, 2)),
+      ('MaxPooling2D', pooling.MaxPooling2D, (2, 2, 2, 1)),
+      ('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)),
+      ('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)),
+      ('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D, (2, 2, 2, 1)),
+      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4), (4, 4, 4),
+       1e-2, 1e-2),
+      ('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)),
+      ('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)),
+      ('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)),
+      ('LSTMV2', lambda: recurrent_v2.LSTM(units=4), (4, 4, 4)),
+      ('TimeDistributed', lambda: wrappers.TimeDistributed(core.Dense(2)),
+       (2, 2, 2)),
+      ('Bidirectional',
+       lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)), (2, 2, 2)),
+  )
+  def test_layer(self, f32_layer_fn, input_shape, rtol=2e-3, atol=2e-3,
+                 input_data=None):
     """Tests a layer by comparing the float32 and mixed precision weights.
 
-    A float32 layer, a mixed precision layer, a distributed float32 layer, and a
-    distributed mixed precision layer are run. The four layers are identical
-    other than their dtypes and distribution strategies. The weights after
-    running fit() are asserted to be close.
-
-    Running the distributed float32 layer does not test mixed precision but we
-    still test it for debugging purposes. If the distributed mixed precision
-    layer fails, it's easier to debug if you know whether the issue also occurs
-    in the distributed float32 layer.
+    A float32 layer, a mixed precision layer, and a distributed mixed precision
+    layer are run. The three layers are identical other than their dtypes and
+    distribution strategies. The outputs after predict() and weights after fit()
+    are asserted to be close.
 
     Args:
-      f32_layer: A float32 layer. The other three layers will automatically
-        be created from this
-      input_shape: The shape of the inputs to the layer, including the batch
-        dimension.
+      f32_layer_fn: A function returning a float32 layer. The other two layers
+        will automatically be created from this
+      input_shape: The shape of the input to the layer, including the batch
+        dimension. Or a list of shapes if the layer takes multiple inputs.
+      rtol: The relative tolerance to be asserted.
+      atol: The absolute tolerance to be asserted.
+      input_data: A Numpy array with the data of the input. If None, input data
+        will be randomly generated
     """
+    if isinstance(input_shape[0], int):
+      input_shapes = [input_shape]
+    else:
+      input_shapes = input_shape
     strategy = create_mirrored_strategy()
+    f32_layer = f32_layer_fn()
 
     # Create the layers
     assert f32_layer.dtype == f32_layer._compute_dtype == 'float32'
     config = f32_layer.get_config()
-    distributed_f32_layer = f32_layer.__class__.from_config(config)
     config['dtype'] = policy.Policy('mixed_float16')
     mp_layer = f32_layer.__class__.from_config(config)
     distributed_mp_layer = f32_layer.__class__.from_config(config)
 
-    # Compute per_replica_input_shape for the distributed models
-    global_batch_size = input_shape[0]
-    assert global_batch_size % strategy.num_replicas_in_sync == 0
+    # Compute per_replica_input_shapes for the distributed model
+    global_batch_size = input_shapes[0][0]
+    assert global_batch_size % strategy.num_replicas_in_sync == 0, (
+        'The number of replicas, %d, does not divide the global batch size of '
+        '%d' % (strategy.num_replicas_in_sync, global_batch_size))
     per_replica_batch_size = (
         global_batch_size // strategy.num_replicas_in_sync)
-    per_replica_input_shape = list(input_shape)
-    per_replica_input_shape[0] = per_replica_batch_size
+    per_replica_input_shapes = [(per_replica_batch_size,) + s[1:]
+                                for s in input_shapes]
 
     # Create the models
-    f32_model = self._create_model_from_layer(f32_layer, input_shape)
-    mp_model = self._create_model_from_layer(mp_layer, input_shape)
+    f32_model = self._create_model_from_layer(f32_layer, input_shapes)
+    mp_model = self._create_model_from_layer(mp_layer, input_shapes)
     with strategy.scope():
-      distributed_f32_model = self._create_model_from_layer(
-          distributed_f32_layer, per_replica_input_shape)
       distributed_mp_model = self._create_model_from_layer(
-          distributed_mp_layer, per_replica_input_shape)
+          distributed_mp_layer, per_replica_input_shapes)
 
     # Set all model weights to the same values
     f32_weights = f32_model.get_weights()
-    for model in mp_model, distributed_f32_model, distributed_mp_model:
-      model.set_weights(f32_weights)
+    mp_model.set_weights(f32_weights)
+    distributed_mp_model.set_weights(f32_weights)
 
+    # Generate input data
+    if input_data is None:
+      # Cast inputs to float16 to avoid measuring error from having f16 layers
+      # cast to float16.
+      input_data = [np.random.normal(size=s).astype('float16')
+                    for s in input_shapes]
+      if len(input_data) == 1:
+        input_data = input_data[0]
+
+    # Assert all models have close outputs.
+    f32_output = f32_model.predict(input_data)
+    mp_output = mp_model.predict(input_data)
+    self.assertAllClose(
+        mp_output, f32_output, rtol=rtol, atol=atol)
+    self.assertAllClose(
+        distributed_mp_model.predict(input_data), f32_output, rtol=rtol,
+        atol=atol)
+
     # Run fit() on models
-    x = np.random.normal(size=input_shape)
-    y = np.random.normal(size=input_shape)
-    for model in (f32_model, mp_model, distributed_f32_model,
-                  distributed_mp_model):
-      model.fit(x, y, batch_size=global_batch_size)
+    output = np.random.normal(size=f32_model.outputs[0].shape).astype('float16')
+    for model in f32_model, mp_model, distributed_mp_model:
+      model.fit(input_data, output, batch_size=global_batch_size)
 
     # Assert all models have close weights
     f32_weights = f32_model.get_weights()
     self.assertAllClose(
-        mp_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4)
+        mp_model.get_weights(), f32_weights, rtol=rtol, atol=atol)
     self.assertAllClose(
-        distributed_f32_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4)
-    self.assertAllClose(
-        distributed_mp_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4)
+        distributed_mp_model.get_weights(), f32_weights, rtol=rtol, atol=atol)
 
-  # Note: There is no need to test every layer subclass here, as otherwise this
-  # test would take too long. Only layers which do something special or are
-  # unusual in regards to mixed precision need to be tested.
-
-  # We test RNNs as some RNNs use the implementation_selector grappler pass,
-  # which can cause issues with AutoCastVariables.
-  @testing_utils.enable_v2_dtype_behavior
-  def test_simple_rnn(self):
-    self._test_layer(recurrent.SimpleRNN(units=4, return_sequences=True),
-                     input_shape=(4, 4, 4))
-
-  @testing_utils.enable_v2_dtype_behavior
-  def test_gru(self):
-    self._test_layer(recurrent_v2.GRU(units=4, return_sequences=True),
-                     input_shape=(4, 4, 4))
-
-  @testing_utils.enable_v2_dtype_behavior
-  def test_lstm(self):
-    self._test_layer(recurrent_v2.LSTM(units=4, return_sequences=True),
-                     input_shape=(4, 4, 4))
-
 
 if __name__ == '__main__':
   v2_compat.enable_v2_behavior()
   test.main()
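For readers who want to reproduce the core of this test outside the TensorFlow source tree, the following hedged sketch uses only the public tf.keras API. It assumes tf.keras.mixed_precision.Policy is available (older TF 2.x releases expose the same class under tf.keras.mixed_precision.experimental), and names such as compare_f32_and_mixed are illustrative rather than part of the test.

import numpy as np
import tensorflow as tf

def compare_f32_and_mixed(make_layer, input_shape, rtol=2e-3, atol=2e-3):
  # Build the same layer twice: once in float32, once under a mixed_float16
  # policy (float16 compute, float32 variables).
  f32_layer = make_layer()
  mp_layer = make_layer(dtype=tf.keras.mixed_precision.Policy('mixed_float16'))

  # Use float16 inputs so the comparison does not also measure the error from
  # the mixed-precision layer casting its inputs down to float16.
  data = np.random.normal(size=input_shape).astype('float16')

  f32_out = f32_layer(data)                      # float32 compute and output
  mp_layer.build(data.shape)
  mp_layer.set_weights(f32_layer.get_weights())  # identical starting weights
  mp_out = mp_layer(data)                        # float16 compute and output

  np.testing.assert_allclose(mp_out, f32_out, rtol=rtol, atol=atol)

compare_f32_and_mixed(lambda **kwargs: tf.keras.layers.Dense(2, **kwargs),
                      input_shape=(2, 2))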