From 00774b2d5e7843596158159ee1e213ae40d6ed52 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Wed, 14 Nov 2018 11:02:37 -0800 Subject: [PATCH] Sample change making a kernel unittest work for both TF 1.x and 2.x. This change makes all the tests that don't check gradients in relu_op_test.py work for both TensorFlow 1.x and 2.x (where eager execution, resource variables etc. are enabled by default). PiperOrigin-RevId: 221474407 --- .../python/kernel_tests/relu_op_test.py | 259 +++++++++--------- 1 file changed, 131 insertions(+), 128 deletions(-) diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py index b0f2796ede1..68243f27c05 100644 --- a/tensorflow/python/kernel_tests/relu_op_test.py +++ b/tensorflow/python/kernel_tests/relu_op_test.py @@ -25,6 +25,7 @@ from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl @@ -55,52 +56,52 @@ class ReluTest(test.TestCase): np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]]))) - def _testRelu(self, np_features, use_gpu=False): + def _testRelu(self, np_features): np_relu = self._npRelu(np_features) - with self.cached_session(use_gpu=use_gpu): - relu = nn_ops.relu(np_features) - tf_relu = relu.eval() + tf_relu = nn_ops.relu(np_features) self.assertAllClose(np_relu, tf_relu) - self.assertShapeEqual(np_relu, relu) + self.assertShapeEqual(np_relu, tf_relu) - def testNumbers(self): + def testNumbersCPU(self): for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: - self._testRelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=False) - if t in [np.float16, np.float32, np.float64]: + # Force execution on CPU even if a GPU kernel is available for the type. + with ops.device("/device:CPU:0"): self._testRelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=True) + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) - def _testReluInt8x4(self, np_inputs): - if not test.is_gpu_available(cuda_only=True): - return - np_relu = self._npRelu(np_inputs) - with self.cached_session(use_gpu=True): - relu = nn_ops.relu(constant_op.constant(np_inputs, dtypes.qint8)) - if np_inputs.size % 4 == 0: - tf_relu = relu.eval() - self.assertAllClose(np_relu, tf_relu) - self.assertShapeEqual(np_relu, relu) - else: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Tensor size must be a multiple of 4 for Relu. Got %d" % - np_inputs.size): - tf_relu = relu.eval() + def testNumbersGPU(self): + if not test.is_gpu_available(): + self.skipTest("No GPU available") + for t in [np.float16, np.float32, np.float64]: + self._testRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) def testReluInt8x4GoodShape(self): - self._testReluInt8x4(np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]])) + if not test.is_gpu_available(cuda_only=True): + self.skipTest("No GPU available") + inputs = np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]]) + np_relu = self._npRelu(inputs) + tf_relu = nn_ops.relu(constant_op.constant(inputs, dtypes.qint8)) + self.assertAllClose(np_relu, tf_relu) + self.assertShapeEqual(np_relu, tf_relu) def testReluInt8x4BadShape(self): - np_inputs = np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]]) - self.assertEqual(np_inputs.size, 9) - self._testReluInt8x4(np_inputs) - np_inputs = np.array( - [1, -2, 3, -4, 5, -6, 7, -8, 9, -8, 7, -6, 5, -4, 3, -2, 1]) - self.assertEqual(np_inputs.size, 17) - self._testReluInt8x4(np_inputs) + if not test.is_gpu_available(cuda_only=True): + self.skipTest("No GPU available") + inputs = constant_op.constant( + np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]]), dtypes.qint8) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Tensor size must be a multiple of 4 for Relu. Got 9"): + self.evaluate(nn_ops.relu(inputs)) + + inputs = constant_op.constant( + np.array([1, -2, 3, -4, 5, -6, 7, -8, 9, -8, 7, -6, 5, -4, 3, -2, 1]), + dtypes.qint8) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Tensor size must be a multiple of 4 for Relu. Got 17"): + self.evaluate(nn_ops.relu(inputs)) # The gradient test for ReLU is a bit tricky as the derivative is not well # defined at around zero and we want to avoid that in terms of input values. @@ -202,15 +203,15 @@ class ReluTest(test.TestCase): self.assertLess(err, 1e-10) def testGradientScalar(self): - with self.cached_session() as sess: - x = variables.Variable(100.) - y = nn_ops.relu(x) - loss = y**2 - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.25) - train_op = optimizer.minimize(loss) - sess.run(variables.global_variables_initializer()) - sess.run(train_op) - self.assertAllClose(x.eval(), 50.0) + x = variables.Variable(100.) + + def loss(): + return nn_ops.relu(x)**2 + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.25) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(optimizer.minimize(loss)) + self.assertAllClose(x.read_value(), 50.0) class Relu6Test(test.TestCase): @@ -228,23 +229,25 @@ class Relu6Test(test.TestCase): np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7, 0.9]]))) - def _testRelu6(self, np_features, use_gpu=False): + def _testRelu6(self, np_features): np_relu6 = self._npRelu6(np_features) - with self.cached_session(use_gpu=use_gpu): - relu6 = nn_ops.relu6(np_features) - tf_relu6 = relu6.eval() + tf_relu6 = nn_ops.relu6(np_features) self.assertAllClose(np_relu6, tf_relu6) - self.assertShapeEqual(np_relu6, relu6) + self.assertShapeEqual(np_relu6, tf_relu6) - def testNumbers(self): + def testNumbersCPU(self): for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: - self._testRelu6( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=False) - if t in [np.float16, np.float, np.double]: + # Force execution on CPU even if a GPU kernel is available for the type. + with ops.device("/device:CPU:0"): self._testRelu6( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=True) + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) + + def testNumbersGPU(self): + if not test.is_gpu_available(): + self.skipTest("No GPU available") + for t in [np.float16, np.float, np.double]: + self._testRelu6( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) # The gradient test for ReLU6 is a bit tricky as the derivative is # not well defined at around zero and six and we want to avoid that @@ -297,25 +300,27 @@ class LeakyReluTest(test.TestCase): 0.9]]), alpha=0.1)) - def _testLeakyRelu(self, np_features, alpha, use_gpu=False): + def _testLeakyRelu(self, np_features, alpha): np_leaky_relu = self._npLeakyRelu(np_features, alpha) - with self.test_session(use_gpu=use_gpu): - leaky_relu = nn_ops.leaky_relu(np_features, alpha) - tf_leaky_relu = leaky_relu.eval() + tf_leaky_relu = nn_ops.leaky_relu(np_features, alpha) self.assertAllClose(np_leaky_relu, tf_leaky_relu) - self.assertShapeEqual(np_leaky_relu, leaky_relu) + self.assertShapeEqual(np_leaky_relu, tf_leaky_relu) - def testNumbers(self): + def testNumbersCPU(self): for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: - self._testLeakyRelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - alpha=0.2, - use_gpu=False) - if t in [np.float16, np.float32, np.float64]: + # Force execution on CPU even if a GPU kernel is available for the type. + with ops.device("/device:CPU:0"): self._testLeakyRelu( np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - alpha=0.1, - use_gpu=True) + alpha=0.2) + + def testNumbersGPU(self): + if not test.is_gpu_available(): + self.skipTest("No GPU available") + for t in [np.float16, np.float32, np.float64]: + self._testLeakyRelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), + alpha=0.1) # The gradient test for Leaky ReLU is a bit tricky as the derivative is not # well defined at around zero and we want to avoid that in terms of input @@ -391,15 +396,15 @@ class LeakyReluTest(test.TestCase): self.assertLess(err, 1e-10) def testGradientScalar(self): - with self.test_session() as sess: - x = variables.Variable(-100.) - y = nn_ops.leaky_relu(x, 0.05) - loss = y**2 - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2) - train_op = optimizer.minimize(loss) - sess.run(variables.global_variables_initializer()) - sess.run(train_op) - self.assertAllClose(x.eval(), -99.9) + x = variables.Variable(-100.) + + def loss(): + return nn_ops.leaky_relu(x, 0.05)**2 + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(optimizer.minimize(loss)) + self.assertAllClose(x.read_value(), -99.9) class EluTest(test.TestCase): @@ -415,22 +420,24 @@ class EluTest(test.TestCase): np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]]))) - def _testElu(self, np_features, use_gpu=False): + def _testElu(self, np_features): np_elu = self._npElu(np_features) - with self.cached_session(use_gpu=use_gpu): - elu = nn_ops.elu(np_features) - tf_elu = elu.eval() + tf_elu = nn_ops.elu(np_features) self.assertAllClose(np_elu, tf_elu) - self.assertShapeEqual(np_elu, elu) + self.assertShapeEqual(np_elu, tf_elu) - def testNumbers(self): + def testNumbersCPU(self): for t in [np.float16, np.float32, np.float64]: - self._testElu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=False) - self._testElu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=True) + # Force execution on CPU even if a GPU kernel is available for the type. + with ops.device("/device:CPU:0"): + self._testElu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) + + def testNumbersGPU(self): + if not test.is_gpu_available(): + self.skipTest("No GPU available") + for t in [np.float16, np.float32, np.float64]: + self._testElu(np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) def testGradientFloat32(self): with self.cached_session(): @@ -517,22 +524,20 @@ class SeluTest(test.TestCase): np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]]))) - def _testSelu(self, np_features, use_gpu=False): + def _testSelu(self, np_features): np_selu = self._npSelu(np_features) - with self.cached_session(use_gpu=use_gpu): - selu = nn_ops.selu(np_features) - tf_selu = selu.eval() + tf_selu = nn_ops.selu(np_features) self.assertAllClose(np_selu, tf_selu) - self.assertShapeEqual(np_selu, selu) + self.assertShapeEqual(np_selu, tf_selu) def testNumbers(self): for t in [np.float16, np.float32, np.float64]: self._testSelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=False) - self._testSelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=True) + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) + # Force executed on CPU in case GPU kernels are avaiable. + with ops.device("/device:CPU:0"): + self._testSelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) def testGradientFloat32(self): with self.cached_session(): @@ -599,46 +604,44 @@ class CreluTest(test.TestCase): t = nn_ops.crelu(f) self.assertEqual([50, 5, 7, 20], t.get_shape()) - def _testCrelu(self, np_features, use_gpu=False): + def _testCrelu(self, np_features): np_relu = np.maximum(np_features, np.zeros_like(np_features)) np_neg_relu = np.maximum(-np_features, np.zeros_like(np_features)) np_crelu = np.concatenate((np_relu, np_neg_relu), len(np_features.shape) - 1) - with self.cached_session(use_gpu=use_gpu): - crelu = nn_ops.crelu(np_features) - tf_relu = crelu.eval() + tf_crelu = nn_ops.crelu(np_features) - self.assertAllClose(np_crelu, tf_relu) - self.assertShapeEqual(np_crelu, crelu) + self.assertAllClose(np_crelu, tf_crelu) + self.assertShapeEqual(np_crelu, tf_crelu) - def testNumbers(self): + def testNumbersCPU(self): for t in [np.int32, np.int64, np.float16, np.float32, np.float64]: - self._testCrelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=False) - if t in [np.float16, np.float32, np.float64]: + # Force execution on CPU even if a GPU kernel is available for the type. + with ops.device("/device:CPU:0"): self._testCrelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), - use_gpu=True) + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) + + def testNumbersGPU(self): + if not test.is_gpu_available(): + self.skipTest("No GPU available") + for t in [np.float16, np.float32, np.float64]: + self._testCrelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t)) def testNumbersWithAxis0(self): - with self.cached_session(): - crelu = nn_ops.crelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0) - tf_relu = crelu.eval() - np_crelu = np.array([[0, 7, 0, 3, 0], [1, 0, 5, 0, 9], [9, 0, 5, 0, 1], - [0, 3, 0, 7, 0]]) - self.assertAllEqual(np_crelu, tf_relu) + tf_crelu = nn_ops.crelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0) + np_crelu = np.array([[0, 7, 0, 3, 0], [1, 0, 5, 0, 9], [9, 0, 5, 0, 1], + [0, 3, 0, 7, 0]]) + self.assertAllEqual(np_crelu, tf_crelu) def testNumbersWithAxis1(self): - with self.cached_session(): - crelu = nn_ops.crelu( - np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1) - tf_relu = crelu.eval() - np_crelu = np.array([[0, 7, 0, 3, 0, 9, 0, 5, 0, 1], - [1, 0, 5, 0, 9, 0, 3, 0, 7, 0]]) - self.assertAllEqual(np_crelu, tf_relu) + tf_crelu = nn_ops.crelu( + np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1) + np_crelu = np.array([[0, 7, 0, 3, 0, 9, 0, 5, 0, 1], + [1, 0, 5, 0, 9, 0, 3, 0, 7, 0]]) + self.assertAllEqual(np_crelu, tf_crelu) if __name__ == "__main__":