From f9e899854cc96db28564fa65f22d32a647268fc1 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Mon, 27 Jan 2020 16:23:55 -0800
Subject: [PATCH] Fix crash when float64 or mixed precision used in certain layers.

Also add many more tests to layer_correctness_test.py, so that most Keras
layers are tested.

Also do not test distribution strategy without mixed precision in
layer_correctness_test.py. This was previously tested only to make debugging
easier if a test failed. But the distribution strategy tests take a very long
time, so it's not worth testing with distribution strategy without mixed
precision. I made the test a py_test instead of a cuda_py_test to save a bit
of GPU resources.

Fixes #35817 and fixes #35883.

PiperOrigin-RevId: 291825015
Change-Id: I8ece6de5b9d549f0de06b643774686f56775781e
---
 .../keras/layers/advanced_activations.py | 4 +-
 .../keras/layers/convolutional_recurrent.py | 3 +-
 tensorflow/python/keras/layers/noise.py | 2 +-
 tensorflow/python/keras/layers/noise_test.py | 28 ++-
 .../python/keras/layers/normalization.py | 10 +-
 .../keras/mixed_precision/experimental/BUILD | 5 +-
 .../experimental/layer_correctness_test.py | 222 ++++++++++++------
 7 files changed, 186 insertions(+), 88 deletions(-)

diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 51b22323955..cc46ed0d64b 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -246,8 +246,8 @@ class ThresholdedReLU(Layer):
     self.theta = K.cast_to_floatx(theta)
 
   def call(self, inputs):
-    return inputs * math_ops.cast(
-        math_ops.greater(inputs, self.theta), K.floatx())
+    theta = math_ops.cast(self.theta, inputs.dtype)
+    return inputs * math_ops.cast(math_ops.greater(inputs, theta), inputs.dtype)
 
   def get_config(self):
     config = {'theta': float(self.theta)}
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py
index bc4ee3ce5bd..e5fb30083a4 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent.py
@@ -283,7 +283,8 @@ class ConvRNN2D(RNN):
     shape = list(self.cell.kernel_shape)
     shape[-1] = self.cell.filters
     initial_state = self.cell.input_conv(initial_state,
-                                         array_ops.zeros(tuple(shape)),
+                                         array_ops.zeros(tuple(shape),
+                                                         initial_state.dtype),
                                          padding=self.cell.padding)
 
     if hasattr(self.cell.state_size, '__len__'):
diff --git a/tensorflow/python/keras/layers/noise.py b/tensorflow/python/keras/layers/noise.py
index d15e97af04b..9623be84f56 100644
--- a/tensorflow/python/keras/layers/noise.py
+++ b/tensorflow/python/keras/layers/noise.py
@@ -187,7 +187,7 @@ class AlphaDropout(Layer):
 
       kept_idx = math_ops.greater_equal(
           K.random_uniform(noise_shape, seed=seed), rate)
-      kept_idx = math_ops.cast(kept_idx, K.floatx())
+      kept_idx = math_ops.cast(kept_idx, inputs.dtype)
 
       # Get affine transformation params
       a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py
index 7f9f0391cd9..96c6d595bdf 100644
--- a/tensorflow/python/keras/layers/noise_test.py
+++ b/tensorflow/python/keras/layers/noise_test.py
@@ -47,16 +47,18 @@ class NoiseLayersTest(keras_parameterized.TestCase):
         keras.layers.AlphaDropout, kwargs={'rate': 0.2}, input_shape=(3, 2, 3))
 
   @staticmethod
-  def _make_model(dtype, gtype):
+  def _make_model(dtype, class_type):
     assert dtype in (dtypes_module.float32, dtypes_module.float64)
-    assert gtype in ('noise', 'dropout')
+    assert class_type in ('gaussian_noise', 'gaussian_dropout', 'alpha_noise')
     model = keras.Sequential()
     model.add(keras.layers.Dense(8, input_shape=(32,), dtype=dtype))
-    if gtype == 'noise':
-      gaussian = keras.layers.GaussianNoise(0.0003)
+    if class_type == 'gaussian_noise':
+      layer = keras.layers.GaussianNoise(0.0003, dtype=dtype)
+    elif class_type == 'gaussian_dropout':
+      layer = keras.layers.GaussianDropout(0.1, dtype=dtype)
     else:
-      gaussian = keras.layers.GaussianDropout(0.1)
-    model.add(gaussian)
+      layer = keras.layers.AlphaDropout(0.5, dtype=dtype)
+    model.add(layer)
     return model
 
   def _train_model(self, dtype, gtype):
@@ -68,16 +70,22 @@ class NoiseLayersTest(keras_parameterized.TestCase):
     model.train_on_batch(np.zeros((8, 32)), np.zeros((8, 8)))
 
   def test_noise_float32(self):
-    self._train_model(dtypes_module.float32, 'noise')
+    self._train_model(dtypes_module.float32, 'gaussian_noise')
 
   def test_noise_float64(self):
-    self._train_model(dtypes_module.float64, 'noise')
+    self._train_model(dtypes_module.float64, 'gaussian_noise')
 
   def test_dropout_float32(self):
-    self._train_model(dtypes_module.float32, 'dropout')
+    self._train_model(dtypes_module.float32, 'gaussian_dropout')
 
   def test_dropout_float64(self):
-    self._train_model(dtypes_module.float64, 'dropout')
+    self._train_model(dtypes_module.float64, 'gaussian_dropout')
+
+  def test_alpha_dropout_float32(self):
+    self._train_model(dtypes_module.float32, 'alpha_noise')
+
+  def test_alpha_dropout_float64(self):
+    self._train_model(dtypes_module.float64, 'alpha_noise')
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 08273dfc7d2..be686ad5e50 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -1112,9 +1112,9 @@ class LayerNormalization(Layer):
       # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
      # we cannot pass them as the scale and offset parameters. Therefore, we
      # create two constant tensors in correct shapes for fused_batch_norm and
-      # later contuct a separate calculation on the scale and offset.
-      scale = _set_const_tensor(1.0, inputs.dtype, [pre_dim])
-      offset = _set_const_tensor(0.0, inputs.dtype, [pre_dim])
+      # later construct a separate calculation on the scale and offset.
+      scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
+      offset = _set_const_tensor(0.0, self.dtype, [pre_dim])
 
       # Compute layer normalization using the fused_batch_norm function.
       outputs, _, _ = nn.fused_batch_norm(
@@ -1129,9 +1129,9 @@ class LayerNormalization(Layer):
       scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
 
       if scale is not None:
-        outputs = outputs * scale
+        outputs = outputs * math_ops.cast(scale, outputs.dtype)
       if offset is not None:
-        outputs = outputs + offset
+        outputs = outputs + math_ops.cast(offset, outputs.dtype)
 
     # If some components of the shape got lost due to adjustments, fix that.
outputs.set_shape(input_shape) diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index afe6827f3bd..d25954be694 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -223,14 +223,15 @@ cuda_py_test( ], ) -cuda_py_test( +py_test( name = "layer_correctness_test", size = "medium", srcs = ["layer_correctness_test.py"], python_version = "PY3", - tags = ["no_rocm"], + shard_count = 10, deps = [ "//tensorflow/python:client_testlib", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:one_device_strategy", "//tensorflow/python/keras", diff --git a/tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py b/tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py index 707418dfabe..210529ede36 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/layer_correctness_test.py @@ -17,125 +17,213 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np +from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.eager import context -from tensorflow.python.framework import test_util +from tensorflow.python.framework import config as config_module from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers from tensorflow.python.keras import models -from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.layers import advanced_activations +from tensorflow.python.keras.layers import convolutional +from tensorflow.python.keras.layers import convolutional_recurrent +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import dense_attention +from tensorflow.python.keras.layers import embeddings +from tensorflow.python.keras.layers import local +from tensorflow.python.keras.layers import merge +from tensorflow.python.keras.layers import noise +from tensorflow.python.keras.layers import normalization +from tensorflow.python.keras.layers import normalization_v2 +from tensorflow.python.keras.layers import pooling from tensorflow.python.keras.layers import recurrent from tensorflow.python.keras.layers import recurrent_v2 +from tensorflow.python.keras.layers import wrappers from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.platform import test def create_mirrored_strategy(): - if context.num_gpus() >= 1: - return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0']) - else: - return mirrored_strategy.MirroredStrategy(['cpu:0']) + # The test creates two virtual CPUs, and we use both of them to test with + # multiple devices. 
+ return mirrored_strategy.MirroredStrategy(['cpu:0', 'cpu:1']) -@test_util.run_all_in_graph_and_eager_modes class LayerCorrectnessTest(keras_parameterized.TestCase): - def _create_model_from_layer(self, layer, input_shape): - x = layers.Input(batch_input_shape=input_shape) - y = layer(x) - model = models.Model(x, y) + def setUp(self): + super(LayerCorrectnessTest, self).setUp() + # Set two virtual CPUs to test MirroredStrategy with multiple devices + cpus = config_module.list_physical_devices('CPU') + config_module.set_logical_device_configuration(cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration(), + ]) + + def _create_model_from_layer(self, layer, input_shapes): + inputs = [layers.Input(batch_input_shape=s) for s in input_shapes] + if len(inputs) == 1: + inputs = inputs[0] + y = layer(inputs) + model = models.Model(inputs, y) model.compile('sgd', 'mse') return model - def _test_layer(self, f32_layer, input_shape): + @parameterized.named_parameters( + ('LeakyReLU', advanced_activations.LeakyReLU, (2, 2)), + ('PReLU', advanced_activations.PReLU, (2, 2)), + ('ELU', advanced_activations.ELU, (2, 2)), + ('ThresholdedReLU', advanced_activations.ThresholdedReLU, (2, 2)), + ('Softmax', advanced_activations.Softmax, (2, 2)), + ('ReLU', advanced_activations.ReLU, (2, 2)), + ('Conv1D', lambda: convolutional.Conv1D(2, 2), (2, 2, 1)), + ('Conv2D', lambda: convolutional.Conv2D(2, 2), (2, 2, 2, 1)), + ('Conv3D', lambda: convolutional.Conv3D(2, 2), (2, 2, 2, 2, 1)), + ('Conv2DTranspose', lambda: convolutional.Conv2DTranspose(2, 2), + (2, 2, 2, 2)), + ('SeparableConv2D', lambda: convolutional.SeparableConv2D(2, 2), + (2, 2, 2, 1)), + ('DepthwiseConv2D', lambda: convolutional.DepthwiseConv2D(2, 2), + (2, 2, 2, 1)), + ('UpSampling2D', convolutional.UpSampling2D, (2, 2, 2, 1)), + ('ZeroPadding2D', convolutional.ZeroPadding2D, (2, 2, 2, 1)), + ('Cropping2D', convolutional.Cropping2D, (2, 3, 3, 1)), + ('ConvLSTM2D', + lambda: convolutional_recurrent.ConvLSTM2D(4, kernel_size=(2, 2)), + (4, 4, 4, 4, 4)), + ('Dense', lambda: core.Dense(2), (2, 2)), + ('Dropout', lambda: core.Dropout(0.5), (2, 2)), + ('SpatialDropout2D', lambda: core.SpatialDropout2D(0.5), (2, 2, 2, 2)), + ('Activation', lambda: core.Activation('sigmoid'), (2, 2)), + ('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)), + ('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)), + ('Attention', dense_attention.Attention, + [(2, 2, 3), (2, 3, 3), (2, 3, 3)]), + ('AdditiveAttention', dense_attention.AdditiveAttention, + [(2, 2, 3), (2, 3, 3), (2, 3, 3)]), + ('Embedding', lambda: embeddings.Embedding(4, 4), (2, 4), 2e-3, 2e-3, + np.random.randint(4, size=(2, 4))), + ('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2), (2, 2, 1)), + ('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2), + (2, 2, 2, 1)), + ('Add', merge.Add, [(2, 2), (2, 2)]), + ('Subtract', merge.Subtract, [(2, 2), (2, 2)]), + ('Multiply', merge.Multiply, [(2, 2), (2, 2)]), + ('Average', merge.Average, [(2, 2), (2, 2)]), + ('Maximum', merge.Maximum, [(2, 2), (2, 2)]), + ('Minimum', merge.Minimum, [(2, 2), (2, 2)]), + ('Concatenate', merge.Concatenate, [(2, 2), (2, 2)]), + ('Dot', lambda: merge.Dot(1), [(2, 2), (2, 2)]), + ('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)), + ('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)), + ('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)), + ('BatchNormalization', normalization_v2.BatchNormalization, (2, 2), + 1e-2, 1e-2), + ('LayerNormalization', 
normalization.LayerNormalization, (2, 2)), + ('MaxPooling2D', pooling.MaxPooling2D, (2, 2, 2, 1)), + ('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)), + ('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)), + ('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D, (2, 2, 2, 1)), + ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4), (4, 4, 4), + 1e-2, 1e-2), + ('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)), + ('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)), + ('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)), + ('LSTMV2', lambda: recurrent_v2.LSTM(units=4), (4, 4, 4)), + ('TimeDistributed', lambda: wrappers.TimeDistributed(core.Dense(2)), + (2, 2, 2)), + ('Bidirectional', + lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)), (2, 2, 2)), + ) + def test_layer(self, f32_layer_fn, input_shape, rtol=2e-3, atol=2e-3, + input_data=None): """Tests a layer by comparing the float32 and mixed precision weights. - A float32 layer, a mixed precision layer, a distributed float32 layer, and a - distributed mixed precision layer are run. The four layers are identical - other than their dtypes and distribution strategies. The weights after - running fit() are asserted to be close. - - Running the distributed float32 layer does not test mixed precision but we - still test it for debugging purposes. If the distributed mixed precision - layer fails, it's easier to debug if you know whether the issue also occurs - in the distributed float32 layer. + A float32 layer, a mixed precision layer, and a distributed mixed precision + layer are run. The three layers are identical other than their dtypes and + distribution strategies. The outputs after predict() and weights after fit() + are asserted to be close. Args: - f32_layer: A float32 layer. The other three layers will automatically - be created from this - input_shape: The shape of the inputs to the layer, including the batch - dimension. + f32_layer_fn: A function returning a float32 layer. The other two layers + will automatically be created from this + input_shape: The shape of the input to the layer, including the batch + dimension. Or a list of shapes if the layer takes multiple inputs. + rtol: The relative tolerance to be asserted. + atol: The absolute tolerance to be asserted. + input_data: A Numpy array with the data of the input. 
If None, input data + will be randomly generated """ + if isinstance(input_shape[0], int): + input_shapes = [input_shape] + else: + input_shapes = input_shape strategy = create_mirrored_strategy() + f32_layer = f32_layer_fn() # Create the layers assert f32_layer.dtype == f32_layer._compute_dtype == 'float32' config = f32_layer.get_config() - distributed_f32_layer = f32_layer.__class__.from_config(config) config['dtype'] = policy.Policy('mixed_float16') mp_layer = f32_layer.__class__.from_config(config) distributed_mp_layer = f32_layer.__class__.from_config(config) - # Compute per_replica_input_shape for the distributed models - global_batch_size = input_shape[0] - assert global_batch_size % strategy.num_replicas_in_sync == 0 + # Compute per_replica_input_shapes for the distributed model + global_batch_size = input_shapes[0][0] + assert global_batch_size % strategy.num_replicas_in_sync == 0, ( + 'The number of replicas, %d, does not divide the global batch size of ' + '%d' % (strategy.num_replicas_in_sync, global_batch_size)) per_replica_batch_size = ( global_batch_size // strategy.num_replicas_in_sync) - per_replica_input_shape = list(input_shape) - per_replica_input_shape[0] = per_replica_batch_size + per_replica_input_shapes = [(per_replica_batch_size,) + s[1:] + for s in input_shapes] # Create the models - f32_model = self._create_model_from_layer(f32_layer, input_shape) - mp_model = self._create_model_from_layer(mp_layer, input_shape) + f32_model = self._create_model_from_layer(f32_layer, input_shapes) + mp_model = self._create_model_from_layer(mp_layer, input_shapes) with strategy.scope(): - distributed_f32_model = self._create_model_from_layer( - distributed_f32_layer, per_replica_input_shape) distributed_mp_model = self._create_model_from_layer( - distributed_mp_layer, per_replica_input_shape) + distributed_mp_layer, per_replica_input_shapes) # Set all model weights to the same values f32_weights = f32_model.get_weights() - for model in mp_model, distributed_f32_model, distributed_mp_model: - model.set_weights(f32_weights) + mp_model.set_weights(f32_weights) + distributed_mp_model.set_weights(f32_weights) + + # Generate input data + if input_data is None: + # Cast inputs to float16 to avoid measuring error from having f16 layers + # cast to float16. + input_data = [np.random.normal(size=s).astype('float16') + for s in input_shapes] + if len(input_data) == 1: + input_data = input_data[0] + + # Assert all models have close outputs. 
+ f32_output = f32_model.predict(input_data) + mp_output = mp_model.predict(input_data) + self.assertAllClose( + mp_output, f32_output, rtol=rtol, atol=atol) + self.assertAllClose( + distributed_mp_model.predict(input_data), f32_output, rtol=rtol, + atol=atol) # Run fit() on models - x = np.random.normal(size=input_shape) - y = np.random.normal(size=input_shape) - for model in (f32_model, mp_model, distributed_f32_model, - distributed_mp_model): - model.fit(x, y, batch_size=global_batch_size) + output = np.random.normal(size=f32_model.outputs[0].shape).astype('float16') + for model in f32_model, mp_model, distributed_mp_model: + model.fit(input_data, output, batch_size=global_batch_size) # Assert all models have close weights f32_weights = f32_model.get_weights() self.assertAllClose( - mp_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4) + mp_model.get_weights(), f32_weights, rtol=rtol, atol=atol) self.assertAllClose( - distributed_f32_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4) - self.assertAllClose( - distributed_mp_model.get_weights(), f32_weights, rtol=1e-2, atol=1e-4) + distributed_mp_model.get_weights(), f32_weights, rtol=rtol, atol=atol) - # Note: There is no need to test every layer subclass here, as otherwise this - # test would take too long. Only layers which do something special or are - # unusual in regards to mixed precision need to be tested. - - # We test RNNs as some RNNs use the implementation_selector grappler pass, - # which can cause issues with AutoCastVariables. - @testing_utils.enable_v2_dtype_behavior - def test_simple_rnn(self): - self._test_layer(recurrent.SimpleRNN(units=4, return_sequences=True), - input_shape=(4, 4, 4)) - - @testing_utils.enable_v2_dtype_behavior - def test_gru(self): - self._test_layer(recurrent_v2.GRU(units=4, return_sequences=True), - input_shape=(4, 4, 4)) - - @testing_utils.enable_v2_dtype_behavior - def test_lstm(self): - self._test_layer(recurrent_v2.LSTM(units=4, return_sequences=True), - input_shape=(4, 4, 4)) if __name__ == '__main__': + v2_compat.enable_v2_behavior() test.main()
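For reference, a minimal usage sketch of the float64 case this patch addresses. It is not part of the patch itself, and it uses the public tf.keras API rather than the internal modules above. Before this fix, the hard-coded float32 casts in ThresholdedReLU and AlphaDropout could raise a dtype mismatch error on float64 inputs; with the fix applied, something like the following is expected to train without crashing:

# Illustrative sketch only: exercises ThresholdedReLU and AlphaDropout in
# float64, mirroring the new noise_test.py cases added by this change.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(32,), dtype='float64'),
    tf.keras.layers.ThresholdedReLU(theta=0.1, dtype='float64'),
    tf.keras.layers.AlphaDropout(0.2, dtype='float64'),
])
model.compile('sgd', 'mse')

x = np.zeros((8, 32), dtype='float64')
y = np.zeros((8, 8), dtype='float64')
model.fit(x, y, batch_size=8)  # Previously could fail on internal float32 casts.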