Merge pull request #35666 from ROCmSoftwarePlatform:google_upstream_rocblas_complex
PiperOrigin-RevId: 304218949 Change-Id: Ic8b3408a71502444b8caf3846d4662ae21f1c325
This commit is contained in:
		
						commit
						c58518c893
					
				| @ -3411,7 +3411,6 @@ tf_py_test( | ||||
|     data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"], | ||||
|     shard_count = 20, | ||||
|     tags = [ | ||||
|         "no_rocm",  # flaky test | ||||
|         "no_windows", | ||||
|     ], | ||||
|     deps = [ | ||||
|  | ||||
| @ -262,10 +262,9 @@ class BatchMatMulBenchmark(test.Benchmark): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   dtypes_to_test = [np.float16, np.float32, np.float64, np.int32] | ||||
|   if not test.is_built_with_rocm(): | ||||
|     # ROCm does not support BLAS operations for complex types | ||||
|     dtypes_to_test += [np.complex64, np.complex128] | ||||
|   dtypes_to_test = [ | ||||
|       np.float16, np.float32, np.float64, np.int32, np.complex64, np.complex128 | ||||
|   ] | ||||
|   for dtype_ in dtypes_to_test: | ||||
|     for adjoint_a_ in False, True: | ||||
|       for adjoint_b_ in False, True: | ||||
|  | ||||
| @ -183,10 +183,10 @@ def _GetEigTest(dtype_, shape_, compute_v_): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64] | ||||
|   if not test.is_built_with_rocm(): | ||||
|     # ROCm does not support BLAS operations for complex types | ||||
|     dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128] | ||||
|   dtypes_to_test = [ | ||||
|       dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, | ||||
|       dtypes_lib.complex128 | ||||
|   ] | ||||
|   for compute_v in True, False: | ||||
|     for dtype in dtypes_to_test: | ||||
|       for size in 1, 2, 5, 10: | ||||
|  | ||||
| @ -49,6 +49,7 @@ def identicaltest(tc, init1, init2, shape=None): | ||||
|     init2: An Initializer that generates a tensor of a given shape | ||||
|     shape: Shape of the tensor to initialize or `None` to use a vector of length | ||||
|       100. | ||||
| 
 | ||||
|   Returns: | ||||
|     True or False as determined by test. | ||||
|   """ | ||||
| @ -75,6 +76,7 @@ def duplicated_initializer(tc, init, graph_seed, shape=None): | ||||
|     graph_seed: A graph-level seed to use. | ||||
|     shape: Shape of the tensor to initialize or `None` to use a vector of length | ||||
|       100. | ||||
| 
 | ||||
|   Returns: | ||||
|     True or False as determined by test. | ||||
|   """ | ||||
| @ -94,6 +96,7 @@ def _init_sampler(tc, init, num): | ||||
|     tc: An instance of TensorFlowTestCase. | ||||
|     init: An Initializer that generates a tensor of a given shape | ||||
|     num: Size of 1D tensor to create. | ||||
| 
 | ||||
|   Returns: | ||||
|     Function to generate a random tensor. | ||||
|   """ | ||||
| @ -187,8 +190,8 @@ class ConstantInitializersTest(test.TestCase): | ||||
|     expected = list(value) | ||||
| 
 | ||||
|     self._testNDimConstantInitializer("list", value, shape, expected) | ||||
|     self._testNDimConstantInitializer("ndarray", | ||||
|                                       np.asarray(value), shape, expected) | ||||
|     self._testNDimConstantInitializer("ndarray", np.asarray(value), shape, | ||||
|                                       expected) | ||||
|     self._testNDimConstantInitializer("2D-ndarray", | ||||
|                                       np.asarray(value).reshape(tuple(shape)), | ||||
|                                       shape, expected) | ||||
| @ -214,11 +217,11 @@ class ConstantInitializersTest(test.TestCase): | ||||
|     expected = list(value) | ||||
| 
 | ||||
|     self._testNDimConstantInitializerLessValues("list", value, shape, expected) | ||||
|     self._testNDimConstantInitializerLessValues("ndarray", | ||||
|                                                 np.asarray(value), shape, | ||||
|                                                 expected) | ||||
|     self._testNDimConstantInitializerLessValues("ndarray", np.asarray(value), | ||||
|                                                 shape, expected) | ||||
|     self._testNDimConstantInitializerLessValues( | ||||
|         "2D-ndarray", np.asarray(value).reshape(tuple([2, 3])), shape, expected) | ||||
|         "2D-ndarray", | ||||
|         np.asarray(value).reshape(tuple([2, 3])), shape, expected) | ||||
| 
 | ||||
|   def _testNDimConstantInitializerMoreValues(self, value, shape): | ||||
|     ops.reset_default_graph() | ||||
| @ -242,8 +245,8 @@ class ConstantInitializersTest(test.TestCase): | ||||
| 
 | ||||
|   def testInvalidValueTypeForConstantInitializerCausesTypeError(self): | ||||
|     c = constant_op.constant([1.0, 2.0, 3.0]) | ||||
|     with self.assertRaisesRegexp( | ||||
|         TypeError, r"Invalid type for initial value: .*Tensor.*"): | ||||
|     with self.assertRaisesRegexp(TypeError, | ||||
|                                  r"Invalid type for initial value: .*Tensor.*"): | ||||
|       init_ops.constant_initializer(c, dtype=dtypes.float32) | ||||
|     v = variables.Variable([3.0, 2.0, 1.0]) | ||||
|     with self.assertRaisesRegexp( | ||||
| @ -393,11 +396,11 @@ class VarianceScalingInitializationTest(test.TestCase): | ||||
|     expect_mean = 0. | ||||
|     expect_var = 1. / shape[0] | ||||
|     init = init_ops.variance_scaling_initializer( | ||||
|         distribution='truncated_normal') | ||||
|         distribution="truncated_normal") | ||||
| 
 | ||||
|     with self.session(use_gpu=True), \ | ||||
|       test.mock.patch.object( | ||||
|           random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ | ||||
|           random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \ | ||||
|           as mock_truncated_normal: | ||||
|       x = init(shape).eval() | ||||
|       self.assertTrue(mock_truncated_normal.called) | ||||
| @ -410,11 +413,11 @@ class VarianceScalingInitializationTest(test.TestCase): | ||||
|     shape = [100, 100] | ||||
|     expect_mean = 0. | ||||
|     expect_var = 1. / shape[0] | ||||
|     init = init_ops.variance_scaling_initializer(distribution='normal') | ||||
|     init = init_ops.variance_scaling_initializer(distribution="normal") | ||||
| 
 | ||||
|     with self.session(use_gpu=True), \ | ||||
|       test.mock.patch.object( | ||||
|           random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ | ||||
|           random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \ | ||||
|           as mock_truncated_normal: | ||||
|       x = init(shape).eval() | ||||
|       self.assertTrue(mock_truncated_normal.called) | ||||
| @ -428,11 +431,11 @@ class VarianceScalingInitializationTest(test.TestCase): | ||||
|     expect_mean = 0. | ||||
|     expect_var = 1. / shape[0] | ||||
|     init = init_ops.variance_scaling_initializer( | ||||
|         distribution='untruncated_normal') | ||||
|         distribution="untruncated_normal") | ||||
| 
 | ||||
|     with self.session(use_gpu=True), \ | ||||
|       test.mock.patch.object( | ||||
|           random_ops, 'random_normal', wraps=random_ops.random_normal) \ | ||||
|           random_ops, "random_normal", wraps=random_ops.random_normal) \ | ||||
|           as mock_random_normal: | ||||
|       x = init(shape).eval() | ||||
|       self.assertTrue(mock_random_normal.called) | ||||
| @ -445,7 +448,7 @@ class VarianceScalingInitializationTest(test.TestCase): | ||||
|     shape = [100, 100] | ||||
|     expect_mean = 0. | ||||
|     expect_var = 1. / shape[0] | ||||
|     init = init_ops.variance_scaling_initializer(distribution='uniform') | ||||
|     init = init_ops.variance_scaling_initializer(distribution="uniform") | ||||
| 
 | ||||
|     with self.session(use_gpu=True): | ||||
|       x = init(shape).eval() | ||||
| @ -525,17 +528,13 @@ class RangeTest(test.TestCase): | ||||
|         math_ops.range(zero_float64, zero_int32, 1).dtype, dtypes.float64) | ||||
| 
 | ||||
|     self.assertEqual( | ||||
|         math_ops.range( | ||||
|             0, 0, 1, dtype=dtypes.int32).dtype, dtypes.int32) | ||||
|         math_ops.range(0, 0, 1, dtype=dtypes.int32).dtype, dtypes.int32) | ||||
|     self.assertEqual( | ||||
|         math_ops.range( | ||||
|             0, 0, 1, dtype=dtypes.int64).dtype, dtypes.int64) | ||||
|         math_ops.range(0, 0, 1, dtype=dtypes.int64).dtype, dtypes.int64) | ||||
|     self.assertEqual( | ||||
|         math_ops.range( | ||||
|             0, 0, 1, dtype=dtypes.float32).dtype, dtypes.float32) | ||||
|         math_ops.range(0, 0, 1, dtype=dtypes.float32).dtype, dtypes.float32) | ||||
|     self.assertEqual( | ||||
|         math_ops.range( | ||||
|             0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64) | ||||
|         math_ops.range(0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64) | ||||
| 
 | ||||
|   def testMixedDType(self): | ||||
|     # Test case for GitHub issue 35710 | ||||
| @ -578,8 +577,8 @@ class LinSpaceTest(test.TestCase): | ||||
|       self.assertArrayNear( | ||||
|           self._LinSpace(-1., -5., 3), np.array([-1., -3., -5.]), 1e-5) | ||||
|       self.assertArrayNear( | ||||
|           self._LinSpace(-1., -5., 4), | ||||
|           np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5) | ||||
|           self._LinSpace(-1., -5., 4), np.array([-1., -7. / 3., -11. / 3., | ||||
|                                                  -5.]), 1e-5) | ||||
| 
 | ||||
|   def testNegativeToPositive(self): | ||||
|     for self.force_gpu in self._gpu_modes(): | ||||
| @ -859,7 +858,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   def testInvalidDataType(self): | ||||
|     self.assertRaises( | ||||
|         ValueError, init_ops.convolutional_delta_orthogonal, | ||||
|         ValueError, | ||||
|         init_ops.convolutional_delta_orthogonal, | ||||
|         dtype=dtypes.string) | ||||
| 
 | ||||
|   def testInvalidShape(self): | ||||
| @ -872,8 +872,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): | ||||
|     shape = (3, 3, 10, 10) | ||||
|     for dtype in [dtypes.float32, dtypes.float64]: | ||||
|       init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_delta_orthogonal(gain=3.14, | ||||
|                                                       seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_delta_orthogonal( | ||||
|           gain=3.14, seed=1, dtype=dtype) | ||||
|       with self.session(graph=ops.Graph(), use_gpu=True): | ||||
|         t1 = init1(shape).eval() | ||||
|         t2 = init2(shape).eval() | ||||
| @ -896,18 +896,14 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): | ||||
|         else: | ||||
|           shape = [4, 16, 16, 16, 64] | ||||
|           convolution = convolutional.conv3d | ||||
| 
 | ||||
|           if test.is_built_with_rocm(): | ||||
|             # This subtest triggers a known bug in ROCm runtime code | ||||
|             # The bug has been fixed and will be available in ROCm 2.7 | ||||
|             # Re-enable this test once ROCm 2.7 is released | ||||
|             continue | ||||
| 
 | ||||
|         inputs = random_ops.random_normal(shape, dtype=dtype) | ||||
|         inputs_2norm = linalg_ops.norm(inputs) | ||||
|         outputs = convolution( | ||||
|             inputs, padding="same", filters=128, | ||||
|             kernel_size=kernel_size, use_bias=False, | ||||
|             inputs, | ||||
|             padding="same", | ||||
|             filters=128, | ||||
|             kernel_size=kernel_size, | ||||
|             use_bias=False, | ||||
|             kernel_initializer=init_ops.convolutional_delta_orthogonal( | ||||
|                 gain=gain)) | ||||
|         outputs_shape = shape[0:-1] + [128] | ||||
| @ -931,9 +927,10 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase): | ||||
|     tol = 1e-5 | ||||
|     with self.session(use_gpu=True): | ||||
|       for i in range(count): | ||||
|         x = variable_scope.get_variable("{}".format(i), shape=shape, | ||||
|                                         initializer= | ||||
|                                         init_ops.convolutional_delta_orthogonal) | ||||
|         x = variable_scope.get_variable( | ||||
|             "{}".format(i), | ||||
|             shape=shape, | ||||
|             initializer=init_ops.convolutional_delta_orthogonal) | ||||
|         x.initializer.run() | ||||
|         y = self.evaluate(x)[1, 1, :, :] | ||||
|         determinant = np.linalg.det(y) | ||||
| @ -971,8 +968,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   def testInvalidDataType(self): | ||||
|     self.assertRaises( | ||||
|         ValueError, init_ops.convolutional_orthogonal_1d, | ||||
|         dtype=dtypes.string) | ||||
|         ValueError, init_ops.convolutional_orthogonal_1d, dtype=dtypes.string) | ||||
| 
 | ||||
|   def testInvalidShape(self): | ||||
|     init1 = init_ops.convolutional_orthogonal_1d() | ||||
| @ -984,8 +980,8 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): | ||||
|     shape = (3, 10, 10) | ||||
|     for dtype in [dtypes.float32, dtypes.float64]: | ||||
|       init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_orthogonal_1d(gain=3.14, | ||||
|                                                    seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_orthogonal_1d( | ||||
|           gain=3.14, seed=1, dtype=dtype) | ||||
|       with self.session(graph=ops.Graph(), use_gpu=True): | ||||
|         t1 = init1(shape).eval() | ||||
|         t2 = init2(shape).eval() | ||||
| @ -1000,9 +996,10 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): | ||||
|     tol = 1e-5 | ||||
|     with self.session(use_gpu=True): | ||||
|       for i in range(count): | ||||
|         x = variable_scope.get_variable("{}".format(i), shape=shape, | ||||
|                                         initializer= | ||||
|                                         init_ops.convolutional_orthogonal_1d) | ||||
|         x = variable_scope.get_variable( | ||||
|             "{}".format(i), | ||||
|             shape=shape, | ||||
|             initializer=init_ops.convolutional_orthogonal_1d) | ||||
|         x.initializer.run() | ||||
|         y = np.sum(x.eval(), axis=0) | ||||
|         determinant = np.linalg.det(y) | ||||
| @ -1018,6 +1015,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_deprecated_v1 | ||||
|   def testShapesValues(self): | ||||
| 
 | ||||
|     def circular_pad(input_, width, kernel_size): | ||||
|       """Pad input_ for computing (circular) convolution. | ||||
| 
 | ||||
| @ -1025,6 +1023,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): | ||||
|         input_: the input tensor | ||||
|         width: the width of the tensor. | ||||
|         kernel_size: the kernel size of the filter. | ||||
| 
 | ||||
|       Returns: | ||||
|         a tensor whose width is (width + kernel_size - 1). | ||||
|       """ | ||||
| @ -1053,8 +1052,11 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): | ||||
|       inputs_2norm = linalg_ops.norm(inputs) | ||||
|       input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0]) | ||||
|       outputs = convolution( | ||||
|           input_with_circular_pad, padding="valid", filters=cout, | ||||
|           kernel_size=kernel_size[0], use_bias=False, | ||||
|           input_with_circular_pad, | ||||
|           padding="valid", | ||||
|           filters=cout, | ||||
|           kernel_size=kernel_size[0], | ||||
|           use_bias=False, | ||||
|           kernel_initializer=init_ops.convolutional_orthogonal_1d(gain=gain)) | ||||
|       outputs_2norm = linalg_ops.norm(outputs) | ||||
|       ratio = outputs_2norm / inputs_2norm | ||||
| @ -1091,8 +1093,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   def testInvalidDataType(self): | ||||
|     self.assertRaises( | ||||
|         ValueError, init_ops.convolutional_orthogonal_2d, | ||||
|         dtype=dtypes.string) | ||||
|         ValueError, init_ops.convolutional_orthogonal_2d, dtype=dtypes.string) | ||||
| 
 | ||||
|   def testInvalidShape(self): | ||||
|     init1 = init_ops.convolutional_orthogonal_2d() | ||||
| @ -1104,8 +1105,8 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): | ||||
|     shape = (3, 3, 10, 10) | ||||
|     for dtype in [dtypes.float32, dtypes.float64]: | ||||
|       init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_orthogonal_2d(gain=3.14, | ||||
|                                                    seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_orthogonal_2d( | ||||
|           gain=3.14, seed=1, dtype=dtype) | ||||
|       with self.session(graph=ops.Graph(), use_gpu=True): | ||||
|         t1 = init1(shape).eval() | ||||
|         t2 = init2(shape).eval() | ||||
| @ -1113,6 +1114,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_deprecated_v1 | ||||
|   def testShapesValues(self): | ||||
| 
 | ||||
|     def circular_pad(input_, width, kernel_size): | ||||
|       """Pad input_ for computing (circular) convolution. | ||||
| 
 | ||||
| @ -1120,6 +1122,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): | ||||
|         input_: the input tensor | ||||
|         width: the width of the tensor. | ||||
|         kernel_size: the kernel size of the filter. | ||||
| 
 | ||||
|       Returns: | ||||
|         a tensor whose width is (width + kernel_size - 1). | ||||
|       """ | ||||
| @ -1153,8 +1156,11 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): | ||||
|       inputs_2norm = linalg_ops.norm(inputs) | ||||
|       input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0]) | ||||
|       outputs = convolution( | ||||
|           input_with_circular_pad, padding="valid", filters=cout, | ||||
|           kernel_size=kernel_size, use_bias=False, | ||||
|           input_with_circular_pad, | ||||
|           padding="valid", | ||||
|           filters=cout, | ||||
|           kernel_size=kernel_size, | ||||
|           use_bias=False, | ||||
|           kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain)) | ||||
|       outputs_2norm = linalg_ops.norm(outputs) | ||||
|       ratio = outputs_2norm / inputs_2norm | ||||
| @ -1191,8 +1197,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   def testInvalidDataType(self): | ||||
|     self.assertRaises( | ||||
|         ValueError, init_ops.convolutional_orthogonal_3d, | ||||
|         dtype=dtypes.string) | ||||
|         ValueError, init_ops.convolutional_orthogonal_3d, dtype=dtypes.string) | ||||
| 
 | ||||
|   def testInvalidShape(self): | ||||
|     init1 = init_ops.convolutional_orthogonal_3d() | ||||
| @ -1204,8 +1209,8 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): | ||||
|     shape = (3, 3, 3, 10, 10) | ||||
|     for dtype in [dtypes.float32, dtypes.float64]: | ||||
|       init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_orthogonal_3d(gain=3.14, | ||||
|                                                    seed=1, dtype=dtype) | ||||
|       init2 = init_ops.convolutional_orthogonal_3d( | ||||
|           gain=3.14, seed=1, dtype=dtype) | ||||
|       with self.session(graph=ops.Graph(), use_gpu=True): | ||||
|         t1 = init1(shape).eval() | ||||
|         t2 = init2(shape).eval() | ||||
| @ -1220,9 +1225,10 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): | ||||
|     tol = 1e-5 | ||||
|     with self.session(use_gpu=True): | ||||
|       for i in range(count): | ||||
|         x = variable_scope.get_variable("{}".format(i), shape=shape, | ||||
|                                         initializer= | ||||
|                                         init_ops.convolutional_orthogonal_3d) | ||||
|         x = variable_scope.get_variable( | ||||
|             "{}".format(i), | ||||
|             shape=shape, | ||||
|             initializer=init_ops.convolutional_orthogonal_3d) | ||||
|         x.initializer.run() | ||||
|         y = np.sum(x.eval(), axis=(0, 1, 2)) | ||||
|         determinant = np.linalg.det(y) | ||||
| @ -1238,6 +1244,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_deprecated_v1 | ||||
|   def testShapesValues(self): | ||||
| 
 | ||||
|     def circular_pad(input_, width, kernel_size): | ||||
|       """Padding input_ for computing circular convolution. | ||||
| 
 | ||||
| @ -1255,14 +1262,12 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): | ||||
| 
 | ||||
|       tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0, 0], | ||||
|                                [-1, beginning, -1, -1, -1]) | ||||
|       tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0], | ||||
|                                  [-1, end, -1, -1, -1]) | ||||
|       tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0], [-1, end, -1, -1, -1]) | ||||
|       tmp = array_ops.concat([tmp_up, input_, tmp_down], 1) | ||||
| 
 | ||||
|       tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0, 0], | ||||
|                                  [-1, -1, beginning, -1, -1]) | ||||
|       tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0], | ||||
|                                   [-1, -1, end, -1, -1]) | ||||
|       tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0], [-1, -1, end, -1, -1]) | ||||
|       tmp = array_ops.concat([tmp_left, tmp, tmp_right], 2) | ||||
| 
 | ||||
|       tmp_front = array_ops.slice(tmp, [0, 0, 0, width - beginning, 0], | ||||
| @ -1284,8 +1289,11 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): | ||||
|       inputs_2norm = linalg_ops.norm(inputs) | ||||
|       input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0]) | ||||
|       outputs = convolution( | ||||
|           input_with_circular_pad, padding="valid", filters=cout, | ||||
|           kernel_size=kernel_size[0], use_bias=False, | ||||
|           input_with_circular_pad, | ||||
|           padding="valid", | ||||
|           filters=cout, | ||||
|           kernel_size=kernel_size[0], | ||||
|           use_bias=False, | ||||
|           kernel_initializer=init_ops.convolutional_orthogonal_3d(gain=gain)) | ||||
|       outputs_2norm = linalg_ops.norm(outputs) | ||||
|       ratio = outputs_2norm / inputs_2norm | ||||
|  | ||||
| @ -141,8 +141,6 @@ class LinearOperatorAdjointTest( | ||||
|                 full_matrix2, adjoint=True, adjoint_arg=True).to_dense())) | ||||
| 
 | ||||
|   def test_matmul_adjoint_complex_operator(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) | ||||
|     matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4) | ||||
|     full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1) | ||||
| @ -201,7 +199,8 @@ class LinearOperatorAdjointTest( | ||||
| 
 | ||||
|   def test_solve_adjoint_complex_operator(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|       self.skipTest("ROCm does not support BLAS solve operations" | ||||
|                     " for complex types") | ||||
|     matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix( | ||||
|         [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) + | ||||
|                             1j * linear_operator_test_util.random_tril_matrix( | ||||
|  | ||||
| @ -114,18 +114,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator( | ||||
|     # real, the matrix will not be real. | ||||
|     return [dtypes.complex64, dtypes.complex128] | ||||
| 
 | ||||
|   def operator_and_matrix( | ||||
|       self, shape_info, dtype, use_placeholder, | ||||
|       ensure_self_adjoint_and_pd=False): | ||||
|   def operator_and_matrix(self, | ||||
|                           shape_info, | ||||
|                           dtype, | ||||
|                           use_placeholder, | ||||
|                           ensure_self_adjoint_and_pd=False): | ||||
|     shape = shape_info.shape | ||||
|     # For this test class, we are creating real spectrums. | ||||
|     # We also want the spectrum to have eigenvalues bounded away from zero. | ||||
|     # | ||||
|     # spectrum is bounded away from zero. | ||||
|     spectrum = linear_operator_test_util.random_sign_uniform( | ||||
|         shape=self._shape_to_spectrum_shape(shape), | ||||
|         minval=1., | ||||
|         maxval=2.) | ||||
|         shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.) | ||||
|     if ensure_self_adjoint_and_pd: | ||||
|       spectrum = math_ops.abs(spectrum) | ||||
|     # If dtype is complex, cast spectrum to complex.  The imaginary part will be | ||||
| @ -176,9 +176,11 @@ class LinearOperatorCirculantTestHermitianSpectrum( | ||||
|   zero imaginary part. | ||||
|   """ | ||||
| 
 | ||||
|   def operator_and_matrix( | ||||
|       self, shape_info, dtype, use_placeholder, | ||||
|       ensure_self_adjoint_and_pd=False): | ||||
|   def operator_and_matrix(self, | ||||
|                           shape_info, | ||||
|                           dtype, | ||||
|                           use_placeholder, | ||||
|                           ensure_self_adjoint_and_pd=False): | ||||
|     shape = shape_info.shape | ||||
|     # For this test class, we are creating Hermitian spectrums. | ||||
|     # We also want the spectrum to have eigenvalues bounded away from zero. | ||||
| @ -259,9 +261,11 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( | ||||
|   def skip_these_tests(): | ||||
|     return ["cholesky", "eigvalsh"] | ||||
| 
 | ||||
|   def operator_and_matrix( | ||||
|       self, shape_info, dtype, use_placeholder, | ||||
|       ensure_self_adjoint_and_pd=False): | ||||
|   def operator_and_matrix(self, | ||||
|                           shape_info, | ||||
|                           dtype, | ||||
|                           use_placeholder, | ||||
|                           ensure_self_adjoint_and_pd=False): | ||||
|     del ensure_self_adjoint_and_pd | ||||
|     shape = shape_info.shape | ||||
|     # Will be well conditioned enough to get accurate solves. | ||||
| @ -357,11 +361,6 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( | ||||
|         self.evaluate(operator.assert_non_singular()) | ||||
| 
 | ||||
|   def test_assert_non_singular_does_not_fail_for_non_singular_operator(self): | ||||
| 
 | ||||
|     if test.is_built_with_rocm(): | ||||
|       # ROCm does not yet support BLAS operations with complex types. | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
| 
 | ||||
|     spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64) | ||||
|     operator = linalg.LinearOperatorCirculant(spectrum) | ||||
|     with self.cached_session(): | ||||
| @ -486,9 +485,11 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( | ||||
|   def skip_these_tests(): | ||||
|     return ["cond"] | ||||
| 
 | ||||
|   def operator_and_matrix( | ||||
|       self, shape_info, dtype, use_placeholder, | ||||
|       ensure_self_adjoint_and_pd=False): | ||||
|   def operator_and_matrix(self, | ||||
|                           shape_info, | ||||
|                           dtype, | ||||
|                           use_placeholder, | ||||
|                           ensure_self_adjoint_and_pd=False): | ||||
|     shape = shape_info.shape | ||||
|     # For this test class, we are creating Hermitian spectrums. | ||||
|     # We also want the spectrum to have eigenvalues bounded away from zero. | ||||
| @ -547,9 +548,11 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( | ||||
|   def skip_these_tests(): | ||||
|     return ["cholesky", "eigvalsh"] | ||||
| 
 | ||||
|   def operator_and_matrix( | ||||
|       self, shape_info, dtype, use_placeholder, | ||||
|       ensure_self_adjoint_and_pd=False): | ||||
|   def operator_and_matrix(self, | ||||
|                           shape_info, | ||||
|                           dtype, | ||||
|                           use_placeholder, | ||||
|                           ensure_self_adjoint_and_pd=False): | ||||
|     del ensure_self_adjoint_and_pd | ||||
|     shape = shape_info.shape | ||||
|     # Will be well conditioned enough to get accurate solves. | ||||
| @ -665,11 +668,6 @@ class LinearOperatorCirculant3DTest(test.TestCase): | ||||
|       yield sess | ||||
| 
 | ||||
|   def test_real_spectrum_gives_self_adjoint_operator(self): | ||||
| 
 | ||||
|     if test.is_built_with_rocm(): | ||||
|       # ROCm does not yet support BLAS operations with complex types | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
| 
 | ||||
|     with self.cached_session(): | ||||
|       # This is a real and hermitian spectrum. | ||||
|       spectrum = linear_operator_test_util.random_normal( | ||||
| @ -686,11 +684,6 @@ class LinearOperatorCirculant3DTest(test.TestCase): | ||||
|       self.assertAllClose(matrix, matrix_h) | ||||
| 
 | ||||
|   def test_defining_operator_using_real_convolution_kernel(self): | ||||
| 
 | ||||
|     if test.is_built_with_rocm(): | ||||
|       # ROCm does not yet support BLAS operations with complex types | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
| 
 | ||||
|     with self.cached_session(): | ||||
|       convolution_kernel = linear_operator_test_util.random_normal( | ||||
|           shape=(2, 2, 3, 5), dtype=dtypes.float32) | ||||
| @ -709,11 +702,6 @@ class LinearOperatorCirculant3DTest(test.TestCase): | ||||
|       np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5) | ||||
| 
 | ||||
|   def test_defining_spd_operator_by_taking_real_part(self): | ||||
| 
 | ||||
|     if test.is_built_with_rocm(): | ||||
|       # ROCm does not yet support BLAS operations with complex types | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
| 
 | ||||
|     with self.cached_session():  # Necessary for fft_kernel_label_map | ||||
|       # S is real and positive. | ||||
|       s = linear_operator_test_util.random_uniform( | ||||
|  | ||||
| @ -130,14 +130,12 @@ class LuOpTest(test.TestCase): | ||||
|       for output_idx_type in (dtypes.int32, dtypes.int64): | ||||
|         self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type) | ||||
| 
 | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       for dtype in (np.complex64, np.complex128): | ||||
|         for output_idx_type in (dtypes.int32, dtypes.int64): | ||||
|           complex_data = np.tril(1j * data, -1).astype(dtype) | ||||
|           complex_data += np.triu(-1j * data, 1).astype(dtype) | ||||
|           complex_data += data | ||||
|           self._verifyLu(complex_data, output_idx_type=output_idx_type) | ||||
|     for dtype in (np.complex64, np.complex128): | ||||
|       for output_idx_type in (dtypes.int32, dtypes.int64): | ||||
|         complex_data = np.tril(1j * data, -1).astype(dtype) | ||||
|         complex_data += np.triu(-1j * data, 1).astype(dtype) | ||||
|         complex_data += data | ||||
|         self._verifyLu(complex_data, output_idx_type=output_idx_type) | ||||
| 
 | ||||
|   def testPivoting(self): | ||||
|     # This matrix triggers partial pivoting because the first diagonal entry | ||||
| @ -152,17 +150,15 @@ class LuOpTest(test.TestCase): | ||||
|       # Make sure p_val is not the identity permutation. | ||||
|       self.assertNotAllClose(np.arange(3), p_val) | ||||
| 
 | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       for dtype in (np.complex64, np.complex128): | ||||
|         complex_data = np.tril(1j * data, -1).astype(dtype) | ||||
|         complex_data += np.triu(-1j * data, 1).astype(dtype) | ||||
|         complex_data += data | ||||
|         self._verifyLu(complex_data) | ||||
|         _, p = linalg_ops.lu(data) | ||||
|         p_val = self.evaluate([p]) | ||||
|         # Make sure p_val is not the identity permutation. | ||||
|         self.assertNotAllClose(np.arange(3), p_val) | ||||
|     for dtype in (np.complex64, np.complex128): | ||||
|       complex_data = np.tril(1j * data, -1).astype(dtype) | ||||
|       complex_data += np.triu(-1j * data, 1).astype(dtype) | ||||
|       complex_data += data | ||||
|       self._verifyLu(complex_data) | ||||
|       _, p = linalg_ops.lu(data) | ||||
|       p_val = self.evaluate([p]) | ||||
|       # Make sure p_val is not the identity permutation. | ||||
|       self.assertNotAllClose(np.arange(3), p_val) | ||||
| 
 | ||||
|   def testInvalidMatrix(self): | ||||
|     # LU factorization gives an error when the input is singular. | ||||
| @ -195,13 +191,11 @@ class LuOpTest(test.TestCase): | ||||
|     matrices = np.random.rand(batch_size, 5, 5) | ||||
|     self._verifyLu(matrices) | ||||
| 
 | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       # Generate random complex valued matrices. | ||||
|       np.random.seed(52) | ||||
|       matrices = np.random.rand(batch_size, 5, | ||||
|                                 5) + 1j * np.random.rand(batch_size, 5, 5) | ||||
|       self._verifyLu(matrices) | ||||
|     # Generate random complex valued matrices. | ||||
|     np.random.seed(52) | ||||
|     matrices = np.random.rand(batch_size, 5, | ||||
|                               5) + 1j * np.random.rand(batch_size, 5, 5) | ||||
|     self._verifyLu(matrices) | ||||
| 
 | ||||
|   def testLargeMatrix(self): | ||||
|     # Generate random matrices. | ||||
| @ -210,12 +204,10 @@ class LuOpTest(test.TestCase): | ||||
|     data = np.random.rand(n, n) | ||||
|     self._verifyLu(data) | ||||
| 
 | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       # Generate random complex valued matrices. | ||||
|       np.random.seed(129) | ||||
|       data = np.random.rand(n, n) + 1j * np.random.rand(n, n) | ||||
|       self._verifyLu(data) | ||||
|     # Generate random complex valued matrices. | ||||
|     np.random.seed(129) | ||||
|     data = np.random.rand(n, n) + 1j * np.random.rand(n, n) | ||||
|     self._verifyLu(data) | ||||
| 
 | ||||
|   @test_util.run_v1_only("b/120545219") | ||||
|   def testEmpty(self): | ||||
|  | ||||
| @ -226,10 +226,10 @@ class MatMulInfixOperatorTest(test_lib.TestCase): | ||||
| if __name__ == "__main__": | ||||
|   sizes = [1, 3, 5] | ||||
|   trans_options = [[False, False], [True, False], [False, True]] | ||||
|   dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64] | ||||
|   if not test_lib.is_built_with_rocm(): | ||||
|     # ROCm does not support BLAS operations for complex types | ||||
|     dtypes_to_test += [np.complex64, np.complex128] | ||||
|   dtypes_to_test = [ | ||||
|       np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64, | ||||
|       np.complex128 | ||||
|   ] | ||||
|   # TF2 does not support placeholders under eager so we skip it | ||||
|   for use_static_shape in set([True, tf2.enabled()]): | ||||
|     for dtype in dtypes_to_test: | ||||
|  | ||||
| @ -91,8 +91,6 @@ class ExponentialOpTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_deprecated_v1 | ||||
|   def testNonsymmetricComplex(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     matrix1 = np.array([[1., 2.], [3., 4.]]) | ||||
|     matrix2 = np.array([[1., 3.], [3., 5.]]) | ||||
|     matrix1 = matrix1.astype(np.complex64) | ||||
| @ -114,8 +112,6 @@ class ExponentialOpTest(test.TestCase): | ||||
|     self._verifyExponentialReal(self._makeBatch(matrix1, matrix2)) | ||||
| 
 | ||||
|   def testSymmetricPositiveDefiniteComplex(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     matrix1 = np.array([[2., 1.], [1., 2.]]) | ||||
|     matrix2 = np.array([[3., -1.], [-1., 3.]]) | ||||
|     matrix1 = matrix1.astype(np.complex64) | ||||
| @ -185,8 +181,8 @@ class MatrixExponentialBenchmark(test.Benchmark): | ||||
|     shape = shape[-2:] | ||||
|     assert shape[0] == shape[1] | ||||
|     n = shape[0] | ||||
|     matrix = np.ones(shape).astype(np.float32) / ( | ||||
|         2.0 * n) + np.diag(np.ones(n).astype(np.float32)) | ||||
|     matrix = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( | ||||
|         np.ones(n).astype(np.float32)) | ||||
|     return variables.Variable(np.tile(matrix, batch_shape + (1, 1))) | ||||
| 
 | ||||
|   def benchmarkMatrixExponentialOp(self): | ||||
| @ -201,8 +197,7 @@ class MatrixExponentialBenchmark(test.Benchmark): | ||||
|             sess, | ||||
|             control_flow_ops.group(expm), | ||||
|             min_iters=25, | ||||
|             name="matrix_exponential_cpu_{shape}".format( | ||||
|                 shape=shape)) | ||||
|             name="matrix_exponential_cpu_{shape}".format(shape=shape)) | ||||
| 
 | ||||
|       if test.is_gpu_available(True): | ||||
|         with ops.Graph().as_default(), \ | ||||
| @ -215,8 +210,7 @@ class MatrixExponentialBenchmark(test.Benchmark): | ||||
|               sess, | ||||
|               control_flow_ops.group(expm), | ||||
|               min_iters=25, | ||||
|               name="matrix_exponential_gpu_{shape}".format( | ||||
|                   shape=shape)) | ||||
|               name="matrix_exponential_gpu_{shape}".format(shape=shape)) | ||||
| 
 | ||||
| 
 | ||||
| def _TestRandomSmall(dtype, batch_dims, size): | ||||
| @ -224,9 +218,7 @@ def _TestRandomSmall(dtype, batch_dims, size): | ||||
|   def Test(self): | ||||
|     np.random.seed(42) | ||||
|     shape = batch_dims + (size, size) | ||||
|     matrix = np.random.uniform( | ||||
|         low=-1.0, high=1.0, | ||||
|         size=shape).astype(dtype) | ||||
|     matrix = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype) | ||||
|     self._verifyExponentialReal(matrix) | ||||
| 
 | ||||
|   return Test | ||||
| @ -237,10 +229,9 @@ def _TestL1Norms(dtype, shape, scale): | ||||
|   def Test(self): | ||||
|     np.random.seed(42) | ||||
|     matrix = np.random.uniform( | ||||
|         low=-1.0, high=1.0, | ||||
|         size=np.prod(shape)).reshape(shape).astype(dtype) | ||||
|         low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) | ||||
|     print(dtype, shape, scale, matrix) | ||||
|     l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim-2)) | ||||
|     l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim - 2)) | ||||
|     matrix /= l1_norm | ||||
|     self._verifyExponentialReal(scale * matrix) | ||||
| 
 | ||||
| @ -258,12 +249,12 @@ if __name__ == "__main__": | ||||
|   for shape_ in [(3, 3), (2, 3, 3)]: | ||||
|     for dtype_ in [np.float32, np.complex64]: | ||||
|       for scale_ in [0.1, 1.5, 5.0, 20.0]: | ||||
|         name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*10)) | ||||
|         name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_ * 10)) | ||||
|         setattr(ExponentialOpTest, "testL1Norms_" + name, | ||||
|                 _TestL1Norms(dtype_, shape_, scale_)) | ||||
|     for dtype_ in [np.float64, np.complex128]: | ||||
|       for scale_ in [0.01, 0.2, 0.5, 1.5, 6.0, 25.0]: | ||||
|         name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*100)) | ||||
|         name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_ * 100)) | ||||
|         setattr(ExponentialOpTest, "testL1Norms_" + name, | ||||
|                 _TestL1Norms(dtype_, shape_, scale_)) | ||||
|   test.main() | ||||
|  | ||||
| @ -74,17 +74,14 @@ class InverseOpTest(test.TestCase): | ||||
|     self._verifyInverseReal(matrix2) | ||||
|     # A multidimensional batch of 2x2 matrices | ||||
|     self._verifyInverseReal(self._makeBatch(matrix1, matrix2)) | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       # Complex | ||||
|       matrix1 = matrix1.astype(np.complex64) | ||||
|       matrix1 += 1j * matrix1 | ||||
|       matrix2 = matrix2.astype(np.complex64) | ||||
|       matrix2 += 1j * matrix2 | ||||
|       self._verifyInverseComplex(matrix1) | ||||
|       self._verifyInverseComplex(matrix2) | ||||
|       # Complex batch | ||||
|       self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) | ||||
|     matrix1 = matrix1.astype(np.complex64) | ||||
|     matrix1 += 1j * matrix1 | ||||
|     matrix2 = matrix2.astype(np.complex64) | ||||
|     matrix2 += 1j * matrix2 | ||||
|     self._verifyInverseComplex(matrix1) | ||||
|     self._verifyInverseComplex(matrix2) | ||||
|     # Complex batch | ||||
|     self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) | ||||
| 
 | ||||
|   def testSymmetricPositiveDefinite(self): | ||||
|     # 2x2 matrices | ||||
| @ -94,17 +91,14 @@ class InverseOpTest(test.TestCase): | ||||
|     self._verifyInverseReal(matrix2) | ||||
|     # A multidimensional batch of 2x2 matrices | ||||
|     self._verifyInverseReal(self._makeBatch(matrix1, matrix2)) | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       # Complex | ||||
|       matrix1 = matrix1.astype(np.complex64) | ||||
|       matrix1 += 1j * matrix1 | ||||
|       matrix2 = matrix2.astype(np.complex64) | ||||
|       matrix2 += 1j * matrix2 | ||||
|       self._verifyInverseComplex(matrix1) | ||||
|       self._verifyInverseComplex(matrix2) | ||||
|       # Complex batch | ||||
|       self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) | ||||
|     matrix1 = matrix1.astype(np.complex64) | ||||
|     matrix1 += 1j * matrix1 | ||||
|     matrix2 = matrix2.astype(np.complex64) | ||||
|     matrix2 += 1j * matrix2 | ||||
|     self._verifyInverseComplex(matrix1) | ||||
|     self._verifyInverseComplex(matrix2) | ||||
|     # Complex batch | ||||
|     self._verifyInverseComplex(self._makeBatch(matrix1, matrix2)) | ||||
| 
 | ||||
|   @test_util.deprecated_graph_mode_only | ||||
|   def testNonSquareMatrix(self): | ||||
|  | ||||
| @ -59,8 +59,6 @@ class LogarithmOpTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_v1_only("b/120545219") | ||||
|   def testNonsymmetric(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     # 2x2 matrices | ||||
|     matrix1 = np.array([[1., 2.], [3., 4.]]) | ||||
|     matrix2 = np.array([[1., 3.], [3., 5.]]) | ||||
| @ -75,8 +73,6 @@ class LogarithmOpTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_v1_only("b/120545219") | ||||
|   def testSymmetricPositiveDefinite(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     # 2x2 matrices | ||||
|     matrix1 = np.array([[2., 1.], [1., 2.]]) | ||||
|     matrix2 = np.array([[3., -1.], [-1., 3.]]) | ||||
| @ -111,8 +107,6 @@ class LogarithmOpTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_v1_only("b/120545219") | ||||
|   def testRandomSmallAndLargeComplex64(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     np.random.seed(42) | ||||
|     for batch_dims in [(), (1,), (3,), (2, 2)]: | ||||
|       for size in 8, 31, 32: | ||||
| @ -124,8 +118,6 @@ class LogarithmOpTest(test.TestCase): | ||||
| 
 | ||||
|   @test_util.run_v1_only("b/120545219") | ||||
|   def testRandomSmallAndLargeComplex128(self): | ||||
|     if test.is_built_with_rocm(): | ||||
|       self.skipTest("ROCm does not support BLAS operations for complex types") | ||||
|     np.random.seed(42) | ||||
|     for batch_dims in [(), (1,), (3,), (2, 2)]: | ||||
|       for size in 8, 31, 32: | ||||
| @ -169,8 +161,8 @@ class MatrixLogarithmBenchmark(test.Benchmark): | ||||
|     shape = shape[-2:] | ||||
|     assert shape[0] == shape[1] | ||||
|     n = shape[0] | ||||
|     matrix = np.ones(shape).astype(np.complex64) / ( | ||||
|         2.0 * n) + np.diag(np.ones(n).astype(np.complex64)) | ||||
|     matrix = np.ones(shape).astype(np.complex64) / (2.0 * n) + np.diag( | ||||
|         np.ones(n).astype(np.complex64)) | ||||
|     return variables.Variable(np.tile(matrix, batch_shape + (1, 1))) | ||||
| 
 | ||||
|   def benchmarkMatrixLogarithmOp(self): | ||||
| @ -185,8 +177,7 @@ class MatrixLogarithmBenchmark(test.Benchmark): | ||||
|             sess, | ||||
|             control_flow_ops.group(logm), | ||||
|             min_iters=25, | ||||
|             name="matrix_logarithm_cpu_{shape}".format( | ||||
|                 shape=shape)) | ||||
|             name="matrix_logarithm_cpu_{shape}".format(shape=shape)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -59,16 +59,13 @@ class SquareRootOpTest(test.TestCase): | ||||
|     self._verifySquareRootReal(matrix1) | ||||
|     self._verifySquareRootReal(matrix2) | ||||
|     self._verifySquareRootReal(self._makeBatch(matrix1, matrix2)) | ||||
|     if not test.is_built_with_rocm(): | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       # Complex | ||||
|       matrix1 = matrix1.astype(np.complex64) | ||||
|       matrix2 = matrix2.astype(np.complex64) | ||||
|       matrix1 += 1j * matrix1 | ||||
|       matrix2 += 1j * matrix2 | ||||
|       self._verifySquareRootComplex(matrix1) | ||||
|       self._verifySquareRootComplex(matrix2) | ||||
|       self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2)) | ||||
|     matrix1 = matrix1.astype(np.complex64) | ||||
|     matrix2 = matrix2.astype(np.complex64) | ||||
|     matrix1 += 1j * matrix1 | ||||
|     matrix2 += 1j * matrix2 | ||||
|     self._verifySquareRootComplex(matrix1) | ||||
|     self._verifySquareRootComplex(matrix2) | ||||
|     self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2)) | ||||
| 
 | ||||
|   def testSymmetricPositiveDefinite(self): | ||||
|     matrix1 = np.array([[2., 1.], [1., 2.]]) | ||||
|  | ||||
| @ -240,10 +240,10 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64] | ||||
|   if not test.is_built_with_rocm(): | ||||
|     # ROCm does not support BLAS operations for complex types | ||||
|     dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128] | ||||
|   dtypes_to_test = [ | ||||
|       dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, | ||||
|       dtypes_lib.complex128 | ||||
|   ] | ||||
|   for compute_v in True, False: | ||||
|     for dtype in dtypes_to_test: | ||||
|       for size in 1, 2, 5, 10: | ||||
|  | ||||
| @ -125,7 +125,6 @@ cuda_py_tests( | ||||
|     srcs = ["spectral_ops_test.py"], | ||||
|     python_version = "PY3", | ||||
|     tags = [ | ||||
|         "no_rocm", | ||||
|         "nomac", | ||||
|     ], | ||||
|     deps = [ | ||||
|  | ||||
| @ -370,10 +370,7 @@ class SVDBenchmark(test.Benchmark): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   dtypes_to_test = [np.float32, np.float64] | ||||
|   if not test.is_built_with_rocm(): | ||||
|     # ROCm does not support BLAS operations for complex types | ||||
|     dtypes_to_test += [np.complex64, np.complex128] | ||||
|   dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128] | ||||
|   for compute_uv in False, True: | ||||
|     for full_matrices in False, True: | ||||
|       for dtype in dtypes_to_test: | ||||
| @ -392,7 +389,7 @@ if __name__ == "__main__": | ||||
|   for compute_uv in False, True: | ||||
|     for full_matrices in False, True: | ||||
|       dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] * | ||||
|                 (not compute_uv) * (not test.is_built_with_rocm())) | ||||
|                 (not compute_uv)) | ||||
|       for dtype in dtypes: | ||||
|         mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)] | ||||
|         if not full_matrices or not compute_uv: | ||||
|  | ||||
| @ -221,10 +221,9 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   dtypes_to_test = [np.float16, np.float32, np.float64] | ||||
|   if not test_lib.is_built_with_rocm(): | ||||
|     # ROCm does not support BLAS operations for complex types | ||||
|     dtypes_to_test += [np.complex64, np.complex128] | ||||
|   dtypes_to_test = [ | ||||
|       np.float16, np.float32, np.float64, np.complex64, np.complex128 | ||||
|   ] | ||||
|   for dtype in dtypes_to_test: | ||||
|     for rank_a in 1, 2, 4, 5: | ||||
|       for rank_b in 1, 2, 4, 5: | ||||
|  | ||||
| @ -562,13 +562,7 @@ class EinsumTest(test.TestCase): | ||||
|     self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,)) | ||||
| 
 | ||||
|   def test_dtypes(self): | ||||
|     dtypes = [] | ||||
|     if test.is_built_with_rocm(): | ||||
|       # This test triggers the BLAS op calls on the GPU | ||||
|       # ROCm does not support BLAS operations for complex types | ||||
|       dtypes = [np.float64, np.float32] | ||||
|     else: | ||||
|       dtypes = [np.float64, np.float32, np.complex64, np.complex128] | ||||
|     dtypes = [np.float64, np.float32, np.complex64, np.complex128] | ||||
|     for dtype in dtypes: | ||||
|       self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype) | ||||
|       self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype) | ||||
|  | ||||
| @ -114,10 +114,10 @@ namespace wrap { | ||||
|     __macro(rocblas_zdotc)                 */ \ | ||||
|   __macro(rocblas_sscal)                    \ | ||||
|   __macro(rocblas_dscal)                    \ | ||||
|   /*__macro(rocblas_cscal)                    \
 | ||||
|   __macro(rocblas_cscal)                    \ | ||||
|     __macro(rocblas_csscal)                   \ | ||||
|     __macro(rocblas_zscal)                    \ | ||||
|     __macro(rocblas_zdscal)                */ \ | ||||
|     __macro(rocblas_zdscal)                 \ | ||||
|   __macro(rocblas_saxpy)                    \ | ||||
|   __macro(rocblas_daxpy)                    \ | ||||
|   /*__macro(rocblas_caxpy)                    \
 | ||||
| @ -158,9 +158,9 @@ namespace wrap { | ||||
|     __macro(rocblas_drotmg)                */ \ | ||||
|   __macro(rocblas_sgemv)                    \ | ||||
|   __macro(rocblas_dgemv)                    \ | ||||
|   /*__macro(rocblas_cgemv)                    \
 | ||||
|   __macro(rocblas_cgemv)                    \ | ||||
|     __macro(rocblas_zgemv)                    \ | ||||
|     __macro(rocblas_sgbmv)                    \ | ||||
|   /*  __macro(rocblas_sgbmv)                    \
 | ||||
|     __macro(rocblas_dgbmv)                    \ | ||||
|     __macro(rocblas_cgbmv)                    \ | ||||
|     __macro(rocblas_zgbmv)                    \ | ||||
| @ -231,9 +231,9 @@ namespace wrap { | ||||
|   __macro(rocblas_sgemm)                    \ | ||||
|   __macro(rocblas_dgemm)                    \ | ||||
|   __macro(rocblas_hgemm)                    \ | ||||
|   /*__macro(rocblas_cgemm)                    \
 | ||||
|   __macro(rocblas_cgemm)                    \ | ||||
|     __macro(rocblas_zgemm)                    \ | ||||
|     __macro(rocblas_ssyrk)                    \ | ||||
|   /*  __macro(rocblas_ssyrk)                    \
 | ||||
|     __macro(rocblas_dsyrk)                    \ | ||||
|     __macro(rocblas_csyrk)                    \ | ||||
|     __macro(rocblas_zsyrk)                    \ | ||||
| @ -285,12 +285,37 @@ STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched) | ||||
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched) | ||||
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
 | ||||
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched) | ||||
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched) | ||||
| STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched) | ||||
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
 | ||||
| // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
 | ||||
| ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP) | ||||
| 
 | ||||
| }  // namespace wrap
 | ||||
| 
 | ||||
| template <class T> | ||||
| const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast( | ||||
|     const DeviceMemory<T> &a) { | ||||
|   return reinterpret_cast< | ||||
|       const typename RocBlasTypeConversionHelper<T>::mapped_type *>( | ||||
|       GpuMemory(a)); | ||||
| } | ||||
| template <class T> | ||||
| const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast( | ||||
|     const T &a) { | ||||
|   return reinterpret_cast< | ||||
|       const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a); | ||||
| } | ||||
| template <class T> | ||||
| typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast( | ||||
|     DeviceMemory<T> *a) { | ||||
|   return reinterpret_cast< | ||||
|       typename RocBlasTypeConversionHelper<T>::mapped_type *>( | ||||
|       GpuMemoryMutable(a)); | ||||
| } | ||||
| 
 | ||||
| static void blas_log(const char *c) {} | ||||
| 
 | ||||
| static string ToString(rocblas_status status) { | ||||
|   switch (status) { | ||||
|     case rocblas_status_success: | ||||
| @ -451,6 +476,7 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, | ||||
| bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, | ||||
|                           const DeviceMemory<float> &x, int incx, | ||||
|                           DeviceMemory<float> *y, int incy) { | ||||
|   blas_log("DoBlasAxpy"); | ||||
|   return DoBlasInternal(wrap::rocblas_saxpy, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, &alpha, | ||||
|                         GpuMemory(x), incx, GpuMemoryMutable(y), incy); | ||||
| @ -459,6 +485,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, | ||||
| bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, | ||||
|                           const DeviceMemory<double> &x, int incx, | ||||
|                           DeviceMemory<double> *y, int incy) { | ||||
|   blas_log("DoBlasAxpy"); | ||||
|   return DoBlasInternal(wrap::rocblas_daxpy, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, &alpha, | ||||
|                         GpuMemory(x), incx, GpuMemoryMutable(y), incy); | ||||
| @ -518,6 +545,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, | ||||
|                          const DeviceMemory<float> &x, int incx, | ||||
|                          const DeviceMemory<float> &y, int incy, | ||||
|                          DeviceMemory<float> *result) { | ||||
|   blas_log("DoBlasDot"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count, | ||||
|       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); | ||||
| @ -527,6 +555,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, | ||||
|                          const DeviceMemory<double> &x, int incx, | ||||
|                          const DeviceMemory<double> &y, int incy, | ||||
|                          DeviceMemory<double> *result) { | ||||
|   blas_log("DoBlasDot"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count, | ||||
|       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result)); | ||||
| @ -707,6 +736,7 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, | ||||
|                           DeviceMemory<float> *x, int incx) { | ||||
|   blas_log("DoBlasScal<float>"); | ||||
|   return DoBlasInternal(wrap::rocblas_sscal, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, &alpha, | ||||
|                         GpuMemoryMutable(x), incx); | ||||
| @ -721,32 +751,32 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, | ||||
|                           DeviceMemory<std::complex<float>> *x, int incx) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " | ||||
|              << "for the \"complex<float>\" datatype"; | ||||
|   return false; | ||||
|   return DoBlasInternal(wrap::rocblas_csscal, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, &alpha, | ||||
|                         complex_cast(x), incx); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, | ||||
|                           DeviceMemory<std::complex<double>> *x, int incx) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " | ||||
|              << "for the \"complex<double>\" datatype"; | ||||
|   return false; | ||||
|   return DoBlasInternal(wrap::rocblas_zdscal, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, &alpha, | ||||
|                         complex_cast(x), incx); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, | ||||
|                           std::complex<float> alpha, | ||||
|                           DeviceMemory<std::complex<float>> *x, int incx) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " | ||||
|              << "for the \"complex<float>\" datatype"; | ||||
|   return false; | ||||
|   return DoBlasInternal(wrap::rocblas_cscal, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, | ||||
|                         complex_cast(alpha), complex_cast(x), incx); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, | ||||
|                           std::complex<double> alpha, | ||||
|                           DeviceMemory<std::complex<double>> *x, int incx) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " | ||||
|              << "for the \"complex<double>\" datatype"; | ||||
|   return false; | ||||
|   return DoBlasInternal(wrap::rocblas_zscal, stream, | ||||
|                         true /* = pointer_mode_host */, elem_count, | ||||
|                         complex_cast(alpha), complex_cast(x), incx); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, | ||||
| @ -893,6 +923,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, | ||||
|                           uint64 n, float alpha, const DeviceMemory<float> &a, | ||||
|                           int lda, const DeviceMemory<float> &x, int incx, | ||||
|                           float beta, DeviceMemory<float> *y, int incy) { | ||||
|   blas_log("DoBlasGemv"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), | ||||
| @ -903,6 +934,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, | ||||
|                           uint64 n, double alpha, const DeviceMemory<double> &a, | ||||
|                           int lda, const DeviceMemory<double> &x, int incx, | ||||
|                           double beta, DeviceMemory<double> *y, int incy) { | ||||
|   blas_log("DoBlasGemv"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x), | ||||
| @ -915,9 +947,11 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, | ||||
|                           const DeviceMemory<std::complex<float>> &x, int incx, | ||||
|                           std::complex<float> beta, | ||||
|                           DeviceMemory<std::complex<float>> *y, int incy) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the GEMV operation " | ||||
|              << "for the \"complex<float>\" datatype"; | ||||
|   return false; | ||||
|   blas_log("DoBlasGemv"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_cgemv, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, | ||||
|       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, | ||||
| @ -926,9 +960,11 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, | ||||
|                           const DeviceMemory<std::complex<double>> &x, int incx, | ||||
|                           std::complex<double> beta, | ||||
|                           DeviceMemory<std::complex<double>> *y, int incy) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the GEMV operation " | ||||
|              << "for the \"complex<double>\" datatype"; | ||||
|   return false; | ||||
|   blas_log("DoBlasGemv\n"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_zgemv, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda, | ||||
|       complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, | ||||
| @ -1481,6 +1517,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, | ||||
|                           float alpha, const DeviceMemory<Eigen::half> &a, | ||||
|                           int lda, const DeviceMemory<Eigen::half> &b, int ldb, | ||||
|                           float beta, DeviceMemory<Eigen::half> *c, int ldc) { | ||||
|   blas_log("DoBlasGemm"); | ||||
|   VLOG(1) << absl::StreamFormat( | ||||
|       "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u " | ||||
|       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " | ||||
| @ -1526,6 +1563,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, | ||||
|                           float alpha, const DeviceMemory<float> &a, int lda, | ||||
|                           const DeviceMemory<float> &b, int ldb, float beta, | ||||
|                           DeviceMemory<float> *c, int ldc) { | ||||
|   blas_log("DoBlasGemm"); | ||||
|   VLOG(1) << absl::StreamFormat( | ||||
|       "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u " | ||||
|       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " | ||||
| @ -1565,6 +1603,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, | ||||
|                           double alpha, const DeviceMemory<double> &a, int lda, | ||||
|                           const DeviceMemory<double> &b, int ldb, double beta, | ||||
|                           DeviceMemory<double> *c, int ldc) { | ||||
|   blas_log("DoBlasGemm"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, | ||||
| @ -1578,9 +1617,12 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, | ||||
|                           const DeviceMemory<std::complex<float>> &b, int ldb, | ||||
|                           std::complex<float> beta, | ||||
|                           DeviceMemory<std::complex<float>> *c, int ldc) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the GEMM operation " | ||||
|              << "for the \"complex<float>\" datatype"; | ||||
|   return false; | ||||
|   blas_log("DoBlasGemm"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_cgemm, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, | ||||
|       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb, | ||||
|       complex_cast(beta), complex_cast(c), ldc); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, | ||||
| @ -1590,9 +1632,12 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, | ||||
|                           const DeviceMemory<std::complex<double>> &b, int ldb, | ||||
|                           std::complex<double> beta, | ||||
|                           DeviceMemory<std::complex<double>> *c, int ldc) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the GEMM operation " | ||||
|              << "for the \"complex<double>\" datatype"; | ||||
|   return false; | ||||
|   blas_log("DoBlasGemm"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_zgemm, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, | ||||
|       complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb, | ||||
|       complex_cast(beta), complex_cast(c), ldc); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGemvWithProfiling( | ||||
| @ -1813,6 +1858,56 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm( | ||||
|   return false; | ||||
| } | ||||
| 
 | ||||
| // This copies from source memory: raw_ptrs[i] to target memory:
 | ||||
| // device_memory_ptr at the interval of matrix_byte_size, or vice versa.
 | ||||
| // The below algorithm tries to minimize the number of memcpy by consolidating
 | ||||
| // neighboring memcpy into a single request
 | ||||
| template <typename MAPPED_T> | ||||
| port::Status ReorganizeMemory(Stream *stream, | ||||
|                               DeviceMemory<MAPPED_T> *device_memory, | ||||
|                               const std::vector<MAPPED_T *> &raw_ptrs, | ||||
|                               int batch_count, uint64_t batch_stride, | ||||
|                               bool gather) { | ||||
|   assert(batch_count > 0); | ||||
|   char *device_memory_ptr = static_cast<char *>(device_memory->opaque()); | ||||
|   char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]); | ||||
|   char *dst_ptr = device_memory_ptr; | ||||
|   size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T); | ||||
|   uint64_t cur_stride_size = matrix_byte_size; | ||||
| 
 | ||||
|   for (int i = 1; i < batch_count; ++i) { | ||||
|     if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) { | ||||
|       cur_stride_size += matrix_byte_size; | ||||
|     } else { | ||||
|       DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size); | ||||
|       DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size); | ||||
|       bool a_status = | ||||
|           gather | ||||
|               ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok() | ||||
|               : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok(); | ||||
|       if (!a_status) { | ||||
|         return port::Status( | ||||
|             port::error::INTERNAL, | ||||
|             "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); | ||||
|       } | ||||
|       src_ptr = reinterpret_cast<char *>(raw_ptrs[i]); | ||||
|       dst_ptr = device_memory_ptr + i * matrix_byte_size; | ||||
|       cur_stride_size = matrix_byte_size; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size); | ||||
|   DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size); | ||||
|   bool a_status = | ||||
|       gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok() | ||||
|              : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok(); | ||||
|   if (!a_status) | ||||
|     return port::Status( | ||||
|         port::error::INTERNAL, | ||||
|         "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); | ||||
|   return port::Status::OK(); | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| port::Status ROCMBlas::AllocateStridedBuffer( | ||||
|     const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *> | ||||
| @ -1822,7 +1917,8 @@ port::Status ROCMBlas::AllocateStridedBuffer( | ||||
|     std::unique_ptr<TemporaryDeviceMemory< | ||||
|         typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory, | ||||
|     DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> | ||||
|         *device_memory) { | ||||
|         *device_memory, | ||||
|     bool copy_data, bool &reallocated) { | ||||
|   assert(device_memory != nullptr); | ||||
| 
 | ||||
|   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type; | ||||
| @ -1843,6 +1939,7 @@ port::Status ROCMBlas::AllocateStridedBuffer( | ||||
|   if (!needs_allocate_strided) { | ||||
|     *device_memory = DeviceMemory<MAPPED_T>( | ||||
|         DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size)); | ||||
|     reallocated = false; | ||||
|     return port::Status::OK(); | ||||
|   } | ||||
| 
 | ||||
| @ -1859,19 +1956,11 @@ port::Status ROCMBlas::AllocateStridedBuffer( | ||||
|         DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory()); | ||||
|   } | ||||
| 
 | ||||
|   for (int i = 0; i < batch_count; ++i) { | ||||
|     char *device_memory_ptr = static_cast<char *>(device_memory->opaque()); | ||||
|     DeviceMemoryBase src_mem = DeviceMemoryBase(raw_ptrs[i], matrix_byte_size); | ||||
|     DeviceMemoryBase target_mem = DeviceMemoryBase( | ||||
|         device_memory_ptr + i * matrix_byte_size, matrix_byte_size); | ||||
|     bool a_status = | ||||
|         stream->ThenMemcpy(&target_mem, src_mem, matrix_byte_size).ok(); | ||||
|     if (!a_status) { | ||||
|       return port::Status( | ||||
|           port::error::INTERNAL, | ||||
|           "failed to copy device memory in ROCMBlas::DoBlasGemmBatched"); | ||||
|     } | ||||
|   } | ||||
|   reallocated = true; | ||||
| 
 | ||||
|   if (copy_data) | ||||
|     return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count, | ||||
|                             batch_stride, true); | ||||
|   return port::Status::OK(); | ||||
| } | ||||
| 
 | ||||
| @ -1925,27 +2014,28 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal( | ||||
|   DeviceMemory<MAPPED_T> a; | ||||
|   // Keep the temporary device buffers alive until the function returns
 | ||||
|   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp; | ||||
|   port::Status a_allocation_status = | ||||
|       AllocateStridedBuffer<T>(a_raw_ptrs, batch_count, batch_stride_a, | ||||
|                                scratch_allocator, stream, &a_temp, &a); | ||||
|   bool reallocated_a, reallocated_b, reallocated_c; | ||||
|   port::Status a_allocation_status = AllocateStridedBuffer<T>( | ||||
|       a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream, | ||||
|       &a_temp, &a, true, reallocated_a); | ||||
|   if (a_allocation_status != port::Status::OK()) { | ||||
|     return a_allocation_status; | ||||
|   } | ||||
| 
 | ||||
|   DeviceMemory<MAPPED_T> b; | ||||
|   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp; | ||||
|   port::Status b_allocation_status = | ||||
|       AllocateStridedBuffer<T>(b_raw_ptrs, batch_count, batch_stride_b, | ||||
|                                scratch_allocator, stream, &b_temp, &b); | ||||
|   port::Status b_allocation_status = AllocateStridedBuffer<T>( | ||||
|       b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream, | ||||
|       &b_temp, &b, true, reallocated_b); | ||||
|   if (b_allocation_status != port::Status::OK()) { | ||||
|     return b_allocation_status; | ||||
|   } | ||||
| 
 | ||||
|   DeviceMemory<MAPPED_T> c; | ||||
|   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp; | ||||
|   port::Status c_allocation_status = | ||||
|       AllocateStridedBuffer<T>(c_raw_ptrs, batch_count, batch_stride_c, | ||||
|                                scratch_allocator, stream, &c_temp, &c); | ||||
|   port::Status c_allocation_status = AllocateStridedBuffer<T>( | ||||
|       c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream, | ||||
|       &c_temp, &c, true, reallocated_c);  // can disable copy if beta=0
 | ||||
|   if (c_allocation_status != port::Status::OK()) { | ||||
|     return c_allocation_status; | ||||
|   } | ||||
| @ -1953,19 +2043,20 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal( | ||||
|   MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha); | ||||
|   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta); | ||||
| 
 | ||||
|   bool ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */, | ||||
|                            ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), | ||||
|                            m, n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda, | ||||
|                            batch_stride_a, GpuMemory(b), ldb, batch_stride_b, | ||||
|                            GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc, | ||||
|                            batch_stride_c, batch_count); | ||||
| 
 | ||||
|   if (ok) { | ||||
|     return port::Status::OK(); | ||||
|   } else { | ||||
|   bool ok; | ||||
|   ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */, | ||||
|                       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, | ||||
|                       n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda, | ||||
|                       batch_stride_a, GpuMemory(b), ldb, batch_stride_b, | ||||
|                       GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc, | ||||
|                       batch_stride_c, batch_count); | ||||
|   if (!ok) | ||||
|     return port::Status(port::error::INTERNAL, | ||||
|                         "failed BLAS call, see log for details"); | ||||
|   } | ||||
|   if (reallocated_c) | ||||
|     return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c, | ||||
|                             false); | ||||
|   return port::Status::OK(); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGemmBatched( | ||||
| @ -1975,6 +2066,7 @@ bool ROCMBlas::DoBlasGemmBatched( | ||||
|     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta, | ||||
|     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc, | ||||
|     int batch_count, ScratchAllocator *scratch_allocator) { | ||||
|   blas_log("DoBlasGemmBatched"); | ||||
|   const Eigen::half alpha_half(alpha); | ||||
|   const Eigen::half beta_half(beta); | ||||
| 
 | ||||
| @ -1996,6 +2088,7 @@ bool ROCMBlas::DoBlasGemmBatched( | ||||
|     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta, | ||||
|     const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc, | ||||
|     int batch_count, ScratchAllocator *scratch_allocator) { | ||||
|   blas_log("DoBlasGemmBatched"); | ||||
|   port::Status status = DoBlasGemmBatchedInternal( | ||||
|       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, | ||||
|       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, | ||||
| @ -2013,6 +2106,7 @@ bool ROCMBlas::DoBlasGemmBatched( | ||||
|     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb, | ||||
|     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array, | ||||
|     int ldc, int batch_count, ScratchAllocator *scratch_allocator) { | ||||
|   blas_log("DoBlasGemmBatched"); | ||||
|   port::Status status = DoBlasGemmBatchedInternal( | ||||
|       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, | ||||
|       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, | ||||
| @ -2032,9 +2126,15 @@ bool ROCMBlas::DoBlasGemmBatched( | ||||
|     int ldb, std::complex<float> beta, | ||||
|     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array, | ||||
|     int ldc, int batch_count, ScratchAllocator *scratch_allocator) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation " | ||||
|              << "for the \"complex<float>\" datatype"; | ||||
|   return false; | ||||
|   blas_log("DoBlasGemmBatched"); | ||||
|   port::Status status = DoBlasGemmBatchedInternal( | ||||
|       wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k, | ||||
|       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, | ||||
|       scratch_allocator); | ||||
|   if (!status.ok()) { | ||||
|     LOG(ERROR) << status; | ||||
|   } | ||||
|   return status.ok(); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGemmBatched( | ||||
| @ -2046,9 +2146,15 @@ bool ROCMBlas::DoBlasGemmBatched( | ||||
|     int ldb, std::complex<double> beta, | ||||
|     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array, | ||||
|     int ldc, int batch_count, ScratchAllocator *scratch_allocator) { | ||||
|   LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation " | ||||
|              << "for the \"complex<double>\" datatype"; | ||||
|   return false; | ||||
|   blas_log("DoBlasGemmBatched"); | ||||
|   port::Status status = DoBlasGemmBatchedInternal( | ||||
|       wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, | ||||
|       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, | ||||
|       scratch_allocator); | ||||
|   if (!status.ok()) { | ||||
|     LOG(ERROR) << status; | ||||
|   } | ||||
|   return status.ok(); | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, | ||||
| @ -2296,6 +2402,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, | ||||
|                           blas::Diagonal diag, uint64 m, uint64 n, float alpha, | ||||
|                           const DeviceMemory<float> &a, int lda, | ||||
|                           DeviceMemory<float> *b, int ldb) { | ||||
|   blas_log("DoBlasTrsm"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_strsm, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), | ||||
| @ -2308,6 +2415,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, | ||||
|                           blas::Diagonal diag, uint64 m, uint64 n, double alpha, | ||||
|                           const DeviceMemory<double> &a, int lda, | ||||
|                           DeviceMemory<double> *b, int ldb) { | ||||
|   blas_log("DoBlasTrsm"); | ||||
|   return DoBlasInternal( | ||||
|       wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */, | ||||
|       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), | ||||
| @ -2336,12 +2444,14 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, | ||||
|              << "for the \"complex<double>\" datatype"; | ||||
|   return false; | ||||
| } | ||||
| 
 | ||||
| bool ROCMBlas::DoBlasGemmStridedBatched( | ||||
|     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, | ||||
|     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, | ||||
|     int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, | ||||
|     int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc, | ||||
|     int64 stride_c, int batch_count) { | ||||
|   blas_log("DoBlasGemmStridedBatched"); | ||||
|   const Eigen::half alpha_half(alpha); | ||||
|   const Eigen::half beta_half(beta); | ||||
| 
 | ||||
| @ -2363,6 +2473,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched( | ||||
|     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b, | ||||
|     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c, | ||||
|     int batch_count) { | ||||
|   blas_log("DoBlasGemmStridedBatched"); | ||||
|   return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream, | ||||
|                         false, /* pointer_mode_host */ | ||||
|                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, | ||||
| @ -2376,6 +2487,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched( | ||||
|     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b, | ||||
|     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c, | ||||
|     int batch_count) { | ||||
|   blas_log("DoBlasGemmStridedBatched"); | ||||
|   return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream, | ||||
|                         false, /* pointer_mode_host */ | ||||
|                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, | ||||
|  | ||||
| @ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper<Eigen::half> { | ||||
|   using mapped_type = rocblas_half; | ||||
| }; | ||||
| 
 | ||||
// Specialization for std::complex<float>: its mapped_type is
// rocblas_float_complex. AllocateStridedBuffer<T> uses
// RocBlasTypeConversionHelper<T>::mapped_type as the element type of its raw
// pointer list and device buffers.
| template <> | ||||
| struct RocBlasTypeConversionHelper<std::complex<float>> { | ||||
|   using mapped_type = rocblas_float_complex; | ||||
| }; | ||||
| 
 | ||||
// Specialization for std::complex<double>: its mapped_type is
// rocblas_double_complex. AllocateStridedBuffer<T> uses
// RocBlasTypeConversionHelper<T>::mapped_type as the element type of its raw
// pointer list and device buffers.
| template <> | ||||
| struct RocBlasTypeConversionHelper<std::complex<double>> { | ||||
|   using mapped_type = rocblas_double_complex; | ||||
| }; | ||||
| 
 | ||||
| // Opaque and unique identifier for the rocBLAS plugin.
 | ||||
| extern const PluginId kRocBlasPlugin; | ||||
| 
 | ||||
| @ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport { | ||||
|       std::unique_ptr<TemporaryDeviceMemory< | ||||
|           typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory, | ||||
|       DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type> | ||||
|           *device_memory); | ||||
|           *device_memory, | ||||
|       bool copy_data, bool &reallocated); | ||||
| 
 | ||||
|   // A helper function to implement DoBlasGemmBatched interfaces for generic
 | ||||
|   // types.
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user