Merge pull request #35666 from ROCmSoftwarePlatform:google_upstream_rocblas_complex

PiperOrigin-RevId: 304218949
Change-Id: Ic8b3408a71502444b8caf3846d4662ae21f1c325
TensorFlower Gardener 2020-04-01 10:59:11 -07:00
commit c58518c893
19 changed files with 378 additions and 308 deletions


@@ -3411,7 +3411,6 @@ tf_py_test(
     data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
     shard_count = 20,
     tags = [
-        "no_rocm",  # flaky test
         "no_windows",
     ],
     deps = [


@@ -262,10 +262,9 @@ class BatchMatMulBenchmark(test.Benchmark):
 if __name__ == "__main__":
-  dtypes_to_test = [np.float16, np.float32, np.float64, np.int32]
-  if not test.is_built_with_rocm():
-    # ROCm does not support BLAS operations for complex types
-    dtypes_to_test += [np.complex64, np.complex128]
+  dtypes_to_test = [
+      np.float16, np.float32, np.float64, np.int32, np.complex64, np.complex128
+  ]
   for dtype_ in dtypes_to_test:
     for adjoint_a_ in False, True:
       for adjoint_b_ in False, True:


@@ -183,10 +183,10 @@ def _GetEigTest(dtype_, shape_, compute_v_):
 if __name__ == "__main__":
-  dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
-  if not test.is_built_with_rocm():
-    # ROCm does not support BLAS operations for complex types
-    dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
+  dtypes_to_test = [
+      dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
+      dtypes_lib.complex128
+  ]
   for compute_v in True, False:
     for dtype in dtypes_to_test:
       for size in 1, 2, 5, 10:


@@ -49,6 +49,7 @@ def identicaltest(tc, init1, init2, shape=None):
     init2: An Initializer that generates a tensor of a given shape
     shape: Shape of the tensor to initialize or `None` to use a vector of length
       100.
+
   Returns:
     True or False as determined by test.
   """
@@ -75,6 +76,7 @@ def duplicated_initializer(tc, init, graph_seed, shape=None):
     graph_seed: A graph-level seed to use.
     shape: Shape of the tensor to initialize or `None` to use a vector of length
       100.
+
   Returns:
     True or False as determined by test.
   """
@@ -94,6 +96,7 @@ def _init_sampler(tc, init, num):
     tc: An instance of TensorFlowTestCase.
     init: An Initializer that generates a tensor of a given shape
     num: Size of 1D tensor to create.
+
   Returns:
     Function to generate a random tensor.
   """
@@ -187,8 +190,8 @@ class ConstantInitializersTest(test.TestCase):
     expected = list(value)
     self._testNDimConstantInitializer("list", value, shape, expected)
-    self._testNDimConstantInitializer("ndarray",
-                                      np.asarray(value), shape, expected)
+    self._testNDimConstantInitializer("ndarray", np.asarray(value), shape,
+                                      expected)
     self._testNDimConstantInitializer("2D-ndarray",
                                       np.asarray(value).reshape(tuple(shape)),
                                       shape, expected)
@@ -214,11 +217,11 @@ class ConstantInitializersTest(test.TestCase):
     expected = list(value)
     self._testNDimConstantInitializerLessValues("list", value, shape, expected)
-    self._testNDimConstantInitializerLessValues("ndarray",
-                                                np.asarray(value), shape,
-                                                expected)
-    self._testNDimConstantInitializerLessValues(
-        "2D-ndarray", np.asarray(value).reshape(tuple([2, 3])), shape, expected)
+    self._testNDimConstantInitializerLessValues("ndarray", np.asarray(value),
+                                                shape, expected)
+    self._testNDimConstantInitializerLessValues(
+        "2D-ndarray",
+        np.asarray(value).reshape(tuple([2, 3])), shape, expected)

   def _testNDimConstantInitializerMoreValues(self, value, shape):
     ops.reset_default_graph()
@@ -242,8 +245,8 @@ class ConstantInitializersTest(test.TestCase):
   def testInvalidValueTypeForConstantInitializerCausesTypeError(self):
     c = constant_op.constant([1.0, 2.0, 3.0])
-    with self.assertRaisesRegexp(
-        TypeError, r"Invalid type for initial value: .*Tensor.*"):
+    with self.assertRaisesRegexp(TypeError,
+                                 r"Invalid type for initial value: .*Tensor.*"):
       init_ops.constant_initializer(c, dtype=dtypes.float32)
     v = variables.Variable([3.0, 2.0, 1.0])
     with self.assertRaisesRegexp(
@@ -393,11 +396,11 @@ class VarianceScalingInitializationTest(test.TestCase):
     expect_mean = 0.
     expect_var = 1. / shape[0]
     init = init_ops.variance_scaling_initializer(
-        distribution='truncated_normal')
+        distribution="truncated_normal")
     with self.session(use_gpu=True), \
         test.mock.patch.object(
-            random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
+            random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \
         as mock_truncated_normal:
       x = init(shape).eval()
       self.assertTrue(mock_truncated_normal.called)
@@ -410,11 +413,11 @@ class VarianceScalingInitializationTest(test.TestCase):
     shape = [100, 100]
     expect_mean = 0.
     expect_var = 1. / shape[0]
-    init = init_ops.variance_scaling_initializer(distribution='normal')
+    init = init_ops.variance_scaling_initializer(distribution="normal")
     with self.session(use_gpu=True), \
         test.mock.patch.object(
-            random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
+            random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \
         as mock_truncated_normal:
       x = init(shape).eval()
       self.assertTrue(mock_truncated_normal.called)
@@ -428,11 +431,11 @@ class VarianceScalingInitializationTest(test.TestCase):
     expect_mean = 0.
     expect_var = 1. / shape[0]
     init = init_ops.variance_scaling_initializer(
-        distribution='untruncated_normal')
+        distribution="untruncated_normal")
     with self.session(use_gpu=True), \
         test.mock.patch.object(
-            random_ops, 'random_normal', wraps=random_ops.random_normal) \
+            random_ops, "random_normal", wraps=random_ops.random_normal) \
         as mock_random_normal:
       x = init(shape).eval()
       self.assertTrue(mock_random_normal.called)
@@ -445,7 +448,7 @@ class VarianceScalingInitializationTest(test.TestCase):
     shape = [100, 100]
     expect_mean = 0.
     expect_var = 1. / shape[0]
-    init = init_ops.variance_scaling_initializer(distribution='uniform')
+    init = init_ops.variance_scaling_initializer(distribution="uniform")
     with self.session(use_gpu=True):
       x = init(shape).eval()
@@ -525,17 +528,13 @@ class RangeTest(test.TestCase):
         math_ops.range(zero_float64, zero_int32, 1).dtype, dtypes.float64)
     self.assertEqual(
-        math_ops.range(
-            0, 0, 1, dtype=dtypes.int32).dtype, dtypes.int32)
+        math_ops.range(0, 0, 1, dtype=dtypes.int32).dtype, dtypes.int32)
     self.assertEqual(
-        math_ops.range(
-            0, 0, 1, dtype=dtypes.int64).dtype, dtypes.int64)
+        math_ops.range(0, 0, 1, dtype=dtypes.int64).dtype, dtypes.int64)
     self.assertEqual(
-        math_ops.range(
-            0, 0, 1, dtype=dtypes.float32).dtype, dtypes.float32)
+        math_ops.range(0, 0, 1, dtype=dtypes.float32).dtype, dtypes.float32)
     self.assertEqual(
-        math_ops.range(
-            0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64)
+        math_ops.range(0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64)

   def testMixedDType(self):
     # Test case for GitHub issue 35710
@@ -578,8 +577,8 @@ class LinSpaceTest(test.TestCase):
     self.assertArrayNear(
         self._LinSpace(-1., -5., 3), np.array([-1., -3., -5.]), 1e-5)
     self.assertArrayNear(
-        self._LinSpace(-1., -5., 4),
-        np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5)
+        self._LinSpace(-1., -5., 4), np.array([-1., -7. / 3., -11. / 3.,
+                                               -5.]), 1e-5)

   def testNegativeToPositive(self):
     for self.force_gpu in self._gpu_modes():
@@ -859,7 +858,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
   def testInvalidDataType(self):
     self.assertRaises(
-        ValueError, init_ops.convolutional_delta_orthogonal,
+        ValueError,
+        init_ops.convolutional_delta_orthogonal,
         dtype=dtypes.string)

   def testInvalidShape(self):
@@ -872,8 +872,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
     shape = (3, 3, 10, 10)
     for dtype in [dtypes.float32, dtypes.float64]:
       init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype)
-      init2 = init_ops.convolutional_delta_orthogonal(gain=3.14,
-                                                      seed=1, dtype=dtype)
+      init2 = init_ops.convolutional_delta_orthogonal(
+          gain=3.14, seed=1, dtype=dtype)
       with self.session(graph=ops.Graph(), use_gpu=True):
         t1 = init1(shape).eval()
         t2 = init2(shape).eval()
@@ -896,18 +896,14 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
         else:
           shape = [4, 16, 16, 16, 64]
           convolution = convolutional.conv3d
-        if test.is_built_with_rocm():
-          # This subtest triggers a known bug in ROCm runtime code
-          # The bug has been fixed and will be available in ROCm 2.7
-          # Re-enable this test once ROCm 2.7 is released
-          continue
         inputs = random_ops.random_normal(shape, dtype=dtype)
         inputs_2norm = linalg_ops.norm(inputs)
         outputs = convolution(
-            inputs, padding="same", filters=128,
-            kernel_size=kernel_size, use_bias=False,
+            inputs,
+            padding="same",
+            filters=128,
+            kernel_size=kernel_size,
+            use_bias=False,
             kernel_initializer=init_ops.convolutional_delta_orthogonal(
                 gain=gain))
         outputs_shape = shape[0:-1] + [128]
@@ -931,9 +927,10 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
     tol = 1e-5
     with self.session(use_gpu=True):
       for i in range(count):
-        x = variable_scope.get_variable("{}".format(i), shape=shape,
-                                        initializer=
-                                        init_ops.convolutional_delta_orthogonal)
+        x = variable_scope.get_variable(
+            "{}".format(i),
+            shape=shape,
+            initializer=init_ops.convolutional_delta_orthogonal)
         x.initializer.run()
         y = self.evaluate(x)[1, 1, :, :]
         determinant = np.linalg.det(y)
@@ -971,8 +968,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
   def testInvalidDataType(self):
     self.assertRaises(
-        ValueError, init_ops.convolutional_orthogonal_1d,
-        dtype=dtypes.string)
+        ValueError, init_ops.convolutional_orthogonal_1d, dtype=dtypes.string)

   def testInvalidShape(self):
     init1 = init_ops.convolutional_orthogonal_1d()
@@ -984,8 +980,8 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
     shape = (3, 10, 10)
     for dtype in [dtypes.float32, dtypes.float64]:
       init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
-      init2 = init_ops.convolutional_orthogonal_1d(gain=3.14,
-                                                   seed=1, dtype=dtype)
+      init2 = init_ops.convolutional_orthogonal_1d(
+          gain=3.14, seed=1, dtype=dtype)
       with self.session(graph=ops.Graph(), use_gpu=True):
         t1 = init1(shape).eval()
         t2 = init2(shape).eval()
@@ -1000,9 +996,10 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
     tol = 1e-5
     with self.session(use_gpu=True):
       for i in range(count):
-        x = variable_scope.get_variable("{}".format(i), shape=shape,
-                                        initializer=
-                                        init_ops.convolutional_orthogonal_1d)
+        x = variable_scope.get_variable(
+            "{}".format(i),
+            shape=shape,
+            initializer=init_ops.convolutional_orthogonal_1d)
         x.initializer.run()
         y = np.sum(x.eval(), axis=0)
         determinant = np.linalg.det(y)
@@ -1018,6 +1015,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testShapesValues(self):
+
     def circular_pad(input_, width, kernel_size):
       """Pad input_ for computing (circular) convolution.
@@ -1025,6 +1023,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
         input_: the input tensor
         width: the width of the tensor.
         kernel_size: the kernel size of the filter.
+
       Returns:
         a tensor whose width is (width + kernel_size - 1).
       """
@@ -1053,8 +1052,11 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
       inputs_2norm = linalg_ops.norm(inputs)
       input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
       outputs = convolution(
-          input_with_circular_pad, padding="valid", filters=cout,
-          kernel_size=kernel_size[0], use_bias=False,
+          input_with_circular_pad,
+          padding="valid",
+          filters=cout,
+          kernel_size=kernel_size[0],
+          use_bias=False,
           kernel_initializer=init_ops.convolutional_orthogonal_1d(gain=gain))
       outputs_2norm = linalg_ops.norm(outputs)
       ratio = outputs_2norm / inputs_2norm
@@ -1091,8 +1093,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
   def testInvalidDataType(self):
     self.assertRaises(
-        ValueError, init_ops.convolutional_orthogonal_2d,
-        dtype=dtypes.string)
+        ValueError, init_ops.convolutional_orthogonal_2d, dtype=dtypes.string)

   def testInvalidShape(self):
     init1 = init_ops.convolutional_orthogonal_2d()
@@ -1104,8 +1105,8 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
     shape = (3, 3, 10, 10)
     for dtype in [dtypes.float32, dtypes.float64]:
       init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
-      init2 = init_ops.convolutional_orthogonal_2d(gain=3.14,
-                                                   seed=1, dtype=dtype)
+      init2 = init_ops.convolutional_orthogonal_2d(
+          gain=3.14, seed=1, dtype=dtype)
       with self.session(graph=ops.Graph(), use_gpu=True):
         t1 = init1(shape).eval()
         t2 = init2(shape).eval()
@@ -1113,6 +1114,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testShapesValues(self):
+
     def circular_pad(input_, width, kernel_size):
       """Pad input_ for computing (circular) convolution.
@@ -1120,6 +1122,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
         input_: the input tensor
         width: the width of the tensor.
        kernel_size: the kernel size of the filter.
+
       Returns:
         a tensor whose width is (width + kernel_size - 1).
      """
@@ -1153,8 +1156,11 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
       inputs_2norm = linalg_ops.norm(inputs)
       input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
       outputs = convolution(
-          input_with_circular_pad, padding="valid", filters=cout,
-          kernel_size=kernel_size, use_bias=False,
+          input_with_circular_pad,
+          padding="valid",
+          filters=cout,
+          kernel_size=kernel_size,
+          use_bias=False,
           kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain))
       outputs_2norm = linalg_ops.norm(outputs)
       ratio = outputs_2norm / inputs_2norm
@@ -1191,8 +1197,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
   def testInvalidDataType(self):
     self.assertRaises(
-        ValueError, init_ops.convolutional_orthogonal_3d,
-        dtype=dtypes.string)
+        ValueError, init_ops.convolutional_orthogonal_3d, dtype=dtypes.string)

   def testInvalidShape(self):
     init1 = init_ops.convolutional_orthogonal_3d()
@@ -1204,8 +1209,8 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
     shape = (3, 3, 3, 10, 10)
     for dtype in [dtypes.float32, dtypes.float64]:
       init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
-      init2 = init_ops.convolutional_orthogonal_3d(gain=3.14,
-                                                   seed=1, dtype=dtype)
+      init2 = init_ops.convolutional_orthogonal_3d(
+          gain=3.14, seed=1, dtype=dtype)
       with self.session(graph=ops.Graph(), use_gpu=True):
         t1 = init1(shape).eval()
         t2 = init2(shape).eval()
@@ -1220,9 +1225,10 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
     tol = 1e-5
     with self.session(use_gpu=True):
       for i in range(count):
-        x = variable_scope.get_variable("{}".format(i), shape=shape,
-                                        initializer=
-                                        init_ops.convolutional_orthogonal_3d)
+        x = variable_scope.get_variable(
+            "{}".format(i),
+            shape=shape,
+            initializer=init_ops.convolutional_orthogonal_3d)
         x.initializer.run()
         y = np.sum(x.eval(), axis=(0, 1, 2))
         determinant = np.linalg.det(y)
@@ -1238,6 +1244,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testShapesValues(self):
+
     def circular_pad(input_, width, kernel_size):
       """Padding input_ for computing circular convolution.
@@ -1255,14 +1262,12 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
       tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0, 0],
                                [-1, beginning, -1, -1, -1])
-      tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0],
-                                 [-1, end, -1, -1, -1])
+      tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0], [-1, end, -1, -1, -1])
       tmp = array_ops.concat([tmp_up, input_, tmp_down], 1)

       tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0, 0],
                                  [-1, -1, beginning, -1, -1])
-      tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0],
-                                  [-1, -1, end, -1, -1])
+      tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0], [-1, -1, end, -1, -1])
       tmp = array_ops.concat([tmp_left, tmp, tmp_right], 2)

       tmp_front = array_ops.slice(tmp, [0, 0, 0, width - beginning, 0],
@@ -1284,8 +1289,11 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
       inputs_2norm = linalg_ops.norm(inputs)
       input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
       outputs = convolution(
-          input_with_circular_pad, padding="valid", filters=cout,
-          kernel_size=kernel_size[0], use_bias=False,
+          input_with_circular_pad,
+          padding="valid",
+          filters=cout,
+          kernel_size=kernel_size[0],
+          use_bias=False,
           kernel_initializer=init_ops.convolutional_orthogonal_3d(gain=gain))
       outputs_2norm = linalg_ops.norm(outputs)
       ratio = outputs_2norm / inputs_2norm


@@ -141,8 +141,6 @@ class LinearOperatorAdjointTest(
                 full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))

   def test_matmul_adjoint_complex_operator(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
     matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
     full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
@@ -201,7 +199,8 @@ class LinearOperatorAdjointTest(
   def test_solve_adjoint_complex_operator(self):
     if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
+      self.skipTest("ROCm does not support BLAS solve operations"
+                    " for complex types")
     matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix(
         [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
                             1j * linear_operator_test_util.random_tril_matrix(


@@ -114,18 +114,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
     # real, the matrix will not be real.
     return [dtypes.complex64, dtypes.complex128]

-  def operator_and_matrix(
-      self, shape_info, dtype, use_placeholder,
-      ensure_self_adjoint_and_pd=False):
+  def operator_and_matrix(self,
+                          shape_info,
+                          dtype,
+                          use_placeholder,
+                          ensure_self_adjoint_and_pd=False):
     shape = shape_info.shape
     # For this test class, we are creating real spectrums.
     # We also want the spectrum to have eigenvalues bounded away from zero.
     #
     # spectrum is bounded away from zero.
     spectrum = linear_operator_test_util.random_sign_uniform(
-        shape=self._shape_to_spectrum_shape(shape),
-        minval=1.,
-        maxval=2.)
+        shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
     if ensure_self_adjoint_and_pd:
       spectrum = math_ops.abs(spectrum)
     # If dtype is complex, cast spectrum to complex. The imaginary part will be
@@ -176,9 +176,11 @@ class LinearOperatorCirculantTestHermitianSpectrum(
   zero imaginary part.
   """

-  def operator_and_matrix(
-      self, shape_info, dtype, use_placeholder,
-      ensure_self_adjoint_and_pd=False):
+  def operator_and_matrix(self,
+                          shape_info,
+                          dtype,
+                          use_placeholder,
+                          ensure_self_adjoint_and_pd=False):
     shape = shape_info.shape
     # For this test class, we are creating Hermitian spectrums.
     # We also want the spectrum to have eigenvalues bounded away from zero.
@@ -259,9 +261,11 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
   def skip_these_tests():
     return ["cholesky", "eigvalsh"]

-  def operator_and_matrix(
-      self, shape_info, dtype, use_placeholder,
-      ensure_self_adjoint_and_pd=False):
+  def operator_and_matrix(self,
+                          shape_info,
+                          dtype,
+                          use_placeholder,
+                          ensure_self_adjoint_and_pd=False):
     del ensure_self_adjoint_and_pd
     shape = shape_info.shape
     # Will be well conditioned enough to get accurate solves.
@@ -357,11 +361,6 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
         self.evaluate(operator.assert_non_singular())

   def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
-    if test.is_built_with_rocm():
-      # ROCm does not yet support BLAS operations with complex types.
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64)
     operator = linalg.LinearOperatorCirculant(spectrum)
     with self.cached_session():
@@ -486,9 +485,11 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
   def skip_these_tests():
     return ["cond"]

-  def operator_and_matrix(
-      self, shape_info, dtype, use_placeholder,
-      ensure_self_adjoint_and_pd=False):
+  def operator_and_matrix(self,
+                          shape_info,
+                          dtype,
+                          use_placeholder,
+                          ensure_self_adjoint_and_pd=False):
     shape = shape_info.shape
     # For this test class, we are creating Hermitian spectrums.
     # We also want the spectrum to have eigenvalues bounded away from zero.
@@ -547,9 +548,11 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
   def skip_these_tests():
     return ["cholesky", "eigvalsh"]

-  def operator_and_matrix(
-      self, shape_info, dtype, use_placeholder,
-      ensure_self_adjoint_and_pd=False):
+  def operator_and_matrix(self,
+                          shape_info,
+                          dtype,
+                          use_placeholder,
+                          ensure_self_adjoint_and_pd=False):
     del ensure_self_adjoint_and_pd
     shape = shape_info.shape
     # Will be well conditioned enough to get accurate solves.
@@ -665,11 +668,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
       yield sess

   def test_real_spectrum_gives_self_adjoint_operator(self):
-    if test.is_built_with_rocm():
-      # ROCm does not yet support BLAS operations with complext types
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     with self.cached_session():
       # This is a real and hermitian spectrum.
       spectrum = linear_operator_test_util.random_normal(
@@ -686,11 +684,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
       self.assertAllClose(matrix, matrix_h)

   def test_defining_operator_using_real_convolution_kernel(self):
-    if test.is_built_with_rocm():
-      # ROCm does not yet support BLAS operations with complext types
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     with self.cached_session():
       convolution_kernel = linear_operator_test_util.random_normal(
           shape=(2, 2, 3, 5), dtype=dtypes.float32)
@@ -709,11 +702,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
       np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)

   def test_defining_spd_operator_by_taking_real_part(self):
-    if test.is_built_with_rocm():
-      # ROCm does not yet support BLAS operations with complext types
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     with self.cached_session():  # Necessary for fft_kernel_label_map
       # S is real and positive.
       s = linear_operator_test_util.random_uniform(


@@ -130,14 +130,12 @@ class LuOpTest(test.TestCase):
       for output_idx_type in (dtypes.int32, dtypes.int64):
         self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)

-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      for dtype in (np.complex64, np.complex128):
-        for output_idx_type in (dtypes.int32, dtypes.int64):
-          complex_data = np.tril(1j * data, -1).astype(dtype)
-          complex_data += np.triu(-1j * data, 1).astype(dtype)
-          complex_data += data
-          self._verifyLu(complex_data, output_idx_type=output_idx_type)
+    for dtype in (np.complex64, np.complex128):
+      for output_idx_type in (dtypes.int32, dtypes.int64):
+        complex_data = np.tril(1j * data, -1).astype(dtype)
+        complex_data += np.triu(-1j * data, 1).astype(dtype)
+        complex_data += data
+        self._verifyLu(complex_data, output_idx_type=output_idx_type)

   def testPivoting(self):
     # This matrix triggers partial pivoting because the first diagonal entry
@@ -152,17 +150,15 @@ class LuOpTest(test.TestCase):
     # Make sure p_val is not the identity permutation.
     self.assertNotAllClose(np.arange(3), p_val)

-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      for dtype in (np.complex64, np.complex128):
-        complex_data = np.tril(1j * data, -1).astype(dtype)
-        complex_data += np.triu(-1j * data, 1).astype(dtype)
-        complex_data += data
-        self._verifyLu(complex_data)
-        _, p = linalg_ops.lu(data)
-        p_val = self.evaluate([p])
-        # Make sure p_val is not the identity permutation.
-        self.assertNotAllClose(np.arange(3), p_val)
+    for dtype in (np.complex64, np.complex128):
+      complex_data = np.tril(1j * data, -1).astype(dtype)
+      complex_data += np.triu(-1j * data, 1).astype(dtype)
+      complex_data += data
+      self._verifyLu(complex_data)
+      _, p = linalg_ops.lu(data)
+      p_val = self.evaluate([p])
+      # Make sure p_val is not the identity permutation.
+      self.assertNotAllClose(np.arange(3), p_val)

   def testInvalidMatrix(self):
     # LU factorization gives an error when the input is singular.
@@ -195,13 +191,11 @@ class LuOpTest(test.TestCase):
     matrices = np.random.rand(batch_size, 5, 5)
     self._verifyLu(matrices)

-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      # Generate random complex valued matrices.
-      np.random.seed(52)
-      matrices = np.random.rand(batch_size, 5,
-                                5) + 1j * np.random.rand(batch_size, 5, 5)
-      self._verifyLu(matrices)
+    # Generate random complex valued matrices.
+    np.random.seed(52)
+    matrices = np.random.rand(batch_size, 5,
+                              5) + 1j * np.random.rand(batch_size, 5, 5)
+    self._verifyLu(matrices)

   def testLargeMatrix(self):
     # Generate random matrices.
@@ -210,12 +204,10 @@ class LuOpTest(test.TestCase):
     data = np.random.rand(n, n)
     self._verifyLu(data)

-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      # Generate random complex valued matrices.
-      np.random.seed(129)
-      data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
-      self._verifyLu(data)
+    # Generate random complex valued matrices.
+    np.random.seed(129)
+    data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
+    self._verifyLu(data)

   @test_util.run_v1_only("b/120545219")
   def testEmpty(self):


@@ -226,10 +226,10 @@ class MatMulInfixOperatorTest(test_lib.TestCase):
 if __name__ == "__main__":
   sizes = [1, 3, 5]
   trans_options = [[False, False], [True, False], [False, True]]
-  dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64]
-  if not test_lib.is_built_with_rocm():
-    # ROCm does not support BLAS operations for complex types
-    dtypes_to_test += [np.complex64, np.complex128]
+  dtypes_to_test = [
+      np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64,
+      np.complex128
+  ]
   # TF2 does not support placeholders under eager so we skip it
   for use_static_shape in set([True, tf2.enabled()]):
     for dtype in dtypes_to_test:


@@ -91,8 +91,6 @@ class ExponentialOpTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testNonsymmetricComplex(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     matrix1 = np.array([[1., 2.], [3., 4.]])
     matrix2 = np.array([[1., 3.], [3., 5.]])
     matrix1 = matrix1.astype(np.complex64)
@@ -114,8 +112,6 @@ class ExponentialOpTest(test.TestCase):
     self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))

   def testSymmetricPositiveDefiniteComplex(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     matrix1 = np.array([[2., 1.], [1., 2.]])
     matrix2 = np.array([[3., -1.], [-1., 3.]])
     matrix1 = matrix1.astype(np.complex64)
@@ -185,8 +181,8 @@ class MatrixExponentialBenchmark(test.Benchmark):
     shape = shape[-2:]
     assert shape[0] == shape[1]
     n = shape[0]
-    matrix = np.ones(shape).astype(np.float32) / (
-        2.0 * n) + np.diag(np.ones(n).astype(np.float32))
+    matrix = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
+        np.ones(n).astype(np.float32))
     return variables.Variable(np.tile(matrix, batch_shape + (1, 1)))

   def benchmarkMatrixExponentialOp(self):
@@ -201,8 +197,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
           sess,
           control_flow_ops.group(expm),
           min_iters=25,
-          name="matrix_exponential_cpu_{shape}".format(
-              shape=shape))
+          name="matrix_exponential_cpu_{shape}".format(shape=shape))

     if test.is_gpu_available(True):
       with ops.Graph().as_default(), \
@@ -215,8 +210,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
             sess,
             control_flow_ops.group(expm),
             min_iters=25,
-            name="matrix_exponential_gpu_{shape}".format(
-                shape=shape))
+            name="matrix_exponential_gpu_{shape}".format(shape=shape))


 def _TestRandomSmall(dtype, batch_dims, size):
@@ -224,9 +218,7 @@ def _TestRandomSmall(dtype, batch_dims, size):
   def Test(self):
     np.random.seed(42)
     shape = batch_dims + (size, size)
-    matrix = np.random.uniform(
-        low=-1.0, high=1.0,
-        size=shape).astype(dtype)
+    matrix = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
     self._verifyExponentialReal(matrix)

   return Test
@@ -237,10 +229,9 @@ def _TestL1Norms(dtype, shape, scale):
   def Test(self):
     np.random.seed(42)
     matrix = np.random.uniform(
-        low=-1.0, high=1.0,
-        size=np.prod(shape)).reshape(shape).astype(dtype)
+        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
     print(dtype, shape, scale, matrix)
-    l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim-2))
+    l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim - 2))
     matrix /= l1_norm
     self._verifyExponentialReal(scale * matrix)
@@ -258,12 +249,12 @@ if __name__ == "__main__":
   for shape_ in [(3, 3), (2, 3, 3)]:
     for dtype_ in [np.float32, np.complex64]:
       for scale_ in [0.1, 1.5, 5.0, 20.0]:
-        name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*10))
+        name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_ * 10))
         setattr(ExponentialOpTest, "testL1Norms_" + name,
                 _TestL1Norms(dtype_, shape_, scale_))
     for dtype_ in [np.float64, np.complex128]:
       for scale_ in [0.01, 0.2, 0.5, 1.5, 6.0, 25.0]:
-        name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*100))
+        name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_ * 100))
         setattr(ExponentialOpTest, "testL1Norms_" + name,
                 _TestL1Norms(dtype_, shape_, scale_))
   test.main()


@@ -74,17 +74,14 @@ class InverseOpTest(test.TestCase):
     self._verifyInverseReal(matrix2)
     # A multidimensional batch of 2x2 matrices
     self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      # Complex
-      matrix1 = matrix1.astype(np.complex64)
-      matrix1 += 1j * matrix1
-      matrix2 = matrix2.astype(np.complex64)
-      matrix2 += 1j * matrix2
-      self._verifyInverseComplex(matrix1)
-      self._verifyInverseComplex(matrix2)
-      # Complex batch
-      self._verifyInverseComplex(self._makeBatch(matrix1, matrix2))
+    matrix1 = matrix1.astype(np.complex64)
+    matrix1 += 1j * matrix1
+    matrix2 = matrix2.astype(np.complex64)
+    matrix2 += 1j * matrix2
+    self._verifyInverseComplex(matrix1)
+    self._verifyInverseComplex(matrix2)
+    # Complex batch
+    self._verifyInverseComplex(self._makeBatch(matrix1, matrix2))

   def testSymmetricPositiveDefinite(self):
     # 2x2 matrices
@@ -94,17 +91,14 @@ class InverseOpTest(test.TestCase):
     self._verifyInverseReal(matrix2)
     # A multidimensional batch of 2x2 matrices
     self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      # Complex
-      matrix1 = matrix1.astype(np.complex64)
-      matrix1 += 1j * matrix1
-      matrix2 = matrix2.astype(np.complex64)
-      matrix2 += 1j * matrix2
-      self._verifyInverseComplex(matrix1)
-      self._verifyInverseComplex(matrix2)
-      # Complex batch
-      self._verifyInverseComplex(self._makeBatch(matrix1, matrix2))
+    matrix1 = matrix1.astype(np.complex64)
+    matrix1 += 1j * matrix1
+    matrix2 = matrix2.astype(np.complex64)
+    matrix2 += 1j * matrix2
+    self._verifyInverseComplex(matrix1)
+    self._verifyInverseComplex(matrix2)
+    # Complex batch
+    self._verifyInverseComplex(self._makeBatch(matrix1, matrix2))

   @test_util.deprecated_graph_mode_only
   def testNonSquareMatrix(self):


@@ -59,8 +59,6 @@ class LogarithmOpTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testNonsymmetric(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     # 2x2 matrices
     matrix1 = np.array([[1., 2.], [3., 4.]])
     matrix2 = np.array([[1., 3.], [3., 5.]])
@@ -75,8 +73,6 @@ class LogarithmOpTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testSymmetricPositiveDefinite(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     # 2x2 matrices
     matrix1 = np.array([[2., 1.], [1., 2.]])
     matrix2 = np.array([[3., -1.], [-1., 3.]])
@@ -111,8 +107,6 @@ class LogarithmOpTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testRandomSmallAndLargeComplex64(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     np.random.seed(42)
     for batch_dims in [(), (1,), (3,), (2, 2)]:
       for size in 8, 31, 32:
@@ -124,8 +118,6 @@ class LogarithmOpTest(test.TestCase):
   @test_util.run_v1_only("b/120545219")
   def testRandomSmallAndLargeComplex128(self):
-    if test.is_built_with_rocm():
-      self.skipTest("ROCm does not support BLAS operations for complex types")
     np.random.seed(42)
     for batch_dims in [(), (1,), (3,), (2, 2)]:
       for size in 8, 31, 32:
@@ -169,8 +161,8 @@ class MatrixLogarithmBenchmark(test.Benchmark):
     shape = shape[-2:]
     assert shape[0] == shape[1]
     n = shape[0]
-    matrix = np.ones(shape).astype(np.complex64) / (
-        2.0 * n) + np.diag(np.ones(n).astype(np.complex64))
+    matrix = np.ones(shape).astype(np.complex64) / (2.0 * n) + np.diag(
+        np.ones(n).astype(np.complex64))
     return variables.Variable(np.tile(matrix, batch_shape + (1, 1)))

   def benchmarkMatrixLogarithmOp(self):
@@ -185,8 +177,7 @@ class MatrixLogarithmBenchmark(test.Benchmark):
           sess,
           control_flow_ops.group(logm),
           min_iters=25,
-          name="matrix_logarithm_cpu_{shape}".format(
-              shape=shape))
+          name="matrix_logarithm_cpu_{shape}".format(shape=shape))


 if __name__ == "__main__":


@@ -59,16 +59,13 @@ class SquareRootOpTest(test.TestCase):
     self._verifySquareRootReal(matrix1)
     self._verifySquareRootReal(matrix2)
     self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
-    if not test.is_built_with_rocm():
-      # ROCm does not support BLAS operations for complex types
-      # Complex
-      matrix1 = matrix1.astype(np.complex64)
-      matrix2 = matrix2.astype(np.complex64)
-      matrix1 += 1j * matrix1
-      matrix2 += 1j * matrix2
-      self._verifySquareRootComplex(matrix1)
-      self._verifySquareRootComplex(matrix2)
-      self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2))
+    matrix1 = matrix1.astype(np.complex64)
+    matrix2 = matrix2.astype(np.complex64)
+    matrix1 += 1j * matrix1
+    matrix2 += 1j * matrix2
+    self._verifySquareRootComplex(matrix1)
+    self._verifySquareRootComplex(matrix2)
+    self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2))

   def testSymmetricPositiveDefinite(self):
     matrix1 = np.array([[2., 1.], [1., 2.]])


@@ -240,10 +240,10 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_):
 if __name__ == "__main__":
-  dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
-  if not test.is_built_with_rocm():
-    # ROCm does not support BLAS operations for complex types
-    dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
+  dtypes_to_test = [
+      dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
+      dtypes_lib.complex128
+  ]
   for compute_v in True, False:
     for dtype in dtypes_to_test:
       for size in 1, 2, 5, 10:


@@ -125,7 +125,6 @@ cuda_py_tests(
     srcs = ["spectral_ops_test.py"],
     python_version = "PY3",
     tags = [
-        "no_rocm",
        "nomac",
     ],
     deps = [

View File

@@ -370,10 +370,7 @@ class SVDBenchmark(test.Benchmark):
 if __name__ == "__main__":
-  dtypes_to_test = [np.float32, np.float64]
-  if not test.is_built_with_rocm():
-    # ROCm does not support BLAS operations for complex types
-    dtypes_to_test += [np.complex64, np.complex128]
+  dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128]
   for compute_uv in False, True:
     for full_matrices in False, True:
       for dtype in dtypes_to_test:
@@ -392,7 +389,7 @@ if __name__ == "__main__":
   for compute_uv in False, True:
     for full_matrices in False, True:
       dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] *
-                (not compute_uv) * (not test.is_built_with_rocm()))
+                (not compute_uv))
       for dtype in dtypes:
         mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)]
         if not full_matrices or not compute_uv:


@@ -221,10 +221,9 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
 if __name__ == "__main__":
-  dtypes_to_test = [np.float16, np.float32, np.float64]
-  if not test_lib.is_built_with_rocm():
-    # ROCm does not support BLAS operations for complex types
-    dtypes_to_test += [np.complex64, np.complex128]
+  dtypes_to_test = [
+      np.float16, np.float32, np.float64, np.complex64, np.complex128
+  ]
   for dtype in dtypes_to_test:
     for rank_a in 1, 2, 4, 5:
       for rank_b in 1, 2, 4, 5:


@@ -562,13 +562,7 @@ class EinsumTest(test.TestCase):
     self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))

   def test_dtypes(self):
-    dtypes = []
-    if test.is_built_with_rocm():
-      # This test triggers the BLAS op calls on the GPU
-      # ROCm does not support BLAS operations for complex types
-      dtypes = [np.float64, np.float32]
-    else:
-      dtypes = [np.float64, np.float32, np.complex64, np.complex128]
+    dtypes = [np.float64, np.float32, np.complex64, np.complex128]
     for dtype in dtypes:
       self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
       self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)


@@ -114,10 +114,10 @@ namespace wrap {
   __macro(rocblas_zdotc) */        \
   __macro(rocblas_sscal)           \
   __macro(rocblas_dscal)           \
-  /*__macro(rocblas_cscal)         \
+  __macro(rocblas_cscal)           \
   __macro(rocblas_csscal)          \
   __macro(rocblas_zscal)           \
-  __macro(rocblas_zdscal) */       \
+  __macro(rocblas_zdscal)          \
   __macro(rocblas_saxpy)           \
   __macro(rocblas_daxpy)           \
   /*__macro(rocblas_caxpy)         \
@@ -158,9 +158,9 @@ namespace wrap {
   __macro(rocblas_drotmg) */       \
   __macro(rocblas_sgemv)           \
   __macro(rocblas_dgemv)           \
-  /*__macro(rocblas_cgemv)         \
+  __macro(rocblas_cgemv)           \
   __macro(rocblas_zgemv)           \
-  __macro(rocblas_sgbmv)           \
+  /* __macro(rocblas_sgbmv)        \
   __macro(rocblas_dgbmv)           \
   __macro(rocblas_cgbmv)           \
   __macro(rocblas_zgbmv)           \
@@ -231,9 +231,9 @@ namespace wrap {
   __macro(rocblas_sgemm)           \
   __macro(rocblas_dgemm)           \
   __macro(rocblas_hgemm)           \
-  /*__macro(rocblas_cgemm)         \
+  __macro(rocblas_cgemm)           \
   __macro(rocblas_zgemm)           \
-  __macro(rocblas_ssyrk)           \
+  /* __macro(rocblas_ssyrk)        \
   __macro(rocblas_dsyrk)           \
   __macro(rocblas_csyrk)           \
   __macro(rocblas_zsyrk)           \
@@ -285,12 +285,37 @@ STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
 STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
 STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
+STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched)
+STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched)
 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
 // STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
 ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)

 }  // namespace wrap

+template <class T>
+const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
+    const DeviceMemory<T> &a) {
+  return reinterpret_cast<
+      const typename RocBlasTypeConversionHelper<T>::mapped_type *>(
+      GpuMemory(a));
+}
+
+template <class T>
+const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
+    const T &a) {
+  return reinterpret_cast<
+      const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a);
+}
+
+template <class T>
+typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
+    DeviceMemory<T> *a) {
+  return reinterpret_cast<
+      typename RocBlasTypeConversionHelper<T>::mapped_type *>(
+      GpuMemoryMutable(a));
+}
+
+static void blas_log(const char *c) {}
+
 static string ToString(rocblas_status status) {
   switch (status) {
     case rocblas_status_success:
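The complex_cast helpers added above exist so that std::complex scalars and DeviceMemory buffers of std::complex can be handed directly to the rocBLAS entry points, which take rocBLAS's own complex structs. The snippet below is an illustrative, self-contained sketch only: it assumes that RocBlasTypeConversionHelper (defined in the accompanying rocm_blas.h, which is not part of this diff) maps std::complex<float> and std::complex<double> to layout-compatible rocBLAS complex types; hypothetical stand-in structs are used here in place of the real rocblas_float_complex / rocblas_double_complex.

// Illustrative sketch; the real trait and the rocBLAS complex types live in
// rocm_blas.h and the rocBLAS headers. It demonstrates the layout-compatibility
// assumption that makes the reinterpret_cast inside complex_cast legitimate:
// std::complex<T> is two contiguous Ts, and so is the struct it is mapped to.
#include <complex>

struct FloatComplexLike { float x, y; };    // hypothetical stand-in for rocblas_float_complex
struct DoubleComplexLike { double x, y; };  // hypothetical stand-in for rocblas_double_complex

template <class T>
struct ConversionHelperSketch { using mapped_type = T; };  // real types map to themselves

template <>
struct ConversionHelperSketch<std::complex<float>> { using mapped_type = FloatComplexLike; };

template <>
struct ConversionHelperSketch<std::complex<double>> { using mapped_type = DoubleComplexLike; };

static_assert(sizeof(std::complex<float>) ==
                  sizeof(ConversionHelperSketch<std::complex<float>>::mapped_type),
              "std::complex<float> must match the mapped complex type in size");
static_assert(sizeof(std::complex<double>) ==
                  sizeof(ConversionHelperSketch<std::complex<double>>::mapped_type),
              "std::complex<double> must match the mapped complex type in size");

int main() { return 0; }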
@@ -451,6 +476,7 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<float> *y, int incy) {
+  blas_log("DoBlasAxpy");
   return DoBlasInternal(wrap::rocblas_saxpy, stream,
                         true /* = pointer_mode_host */, elem_count, &alpha,
                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
@@ -459,6 +485,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
 bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<double> *y, int incy) {
+  blas_log("DoBlasAxpy");
   return DoBlasInternal(wrap::rocblas_daxpy, stream,
                         true /* = pointer_mode_host */, elem_count, &alpha,
                         GpuMemory(x), incx, GpuMemoryMutable(y), incy);
@@ -518,6 +545,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          const DeviceMemory<float> &y, int incy,
                          DeviceMemory<float> *result) {
+  blas_log("DoBlasDot");
   return DoBlasInternal(
       wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count,
       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
@@ -527,6 +555,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          const DeviceMemory<double> &y, int incy,
                          DeviceMemory<double> *result) {
+  blas_log("DoBlasDot");
   return DoBlasInternal(
       wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count,
       GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
@@ -707,6 +736,7 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
 bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                           DeviceMemory<float> *x, int incx) {
+  blas_log("DoBlasScal<float>");
   return DoBlasInternal(wrap::rocblas_sscal, stream,
                         true /* = pointer_mode_host */, elem_count, &alpha,
                         GpuMemoryMutable(x), incx);
@ -721,32 +751,32 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
DeviceMemory<std::complex<float>> *x, int incx) { DeviceMemory<std::complex<float>> *x, int incx) {
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " return DoBlasInternal(wrap::rocblas_csscal, stream,
<< "for the \"complex<float>\" datatype"; true /* = pointer_mode_host */, elem_count, &alpha,
return false; complex_cast(x), incx);
} }
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
DeviceMemory<std::complex<double>> *x, int incx) { DeviceMemory<std::complex<double>> *x, int incx) {
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " return DoBlasInternal(wrap::rocblas_zdscal, stream,
<< "for the \"complex<double>\" datatype"; true /* = pointer_mode_host */, elem_count, &alpha,
return false; complex_cast(x), incx);
} }
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
std::complex<float> alpha, std::complex<float> alpha,
DeviceMemory<std::complex<float>> *x, int incx) { DeviceMemory<std::complex<float>> *x, int incx) {
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " return DoBlasInternal(wrap::rocblas_cscal, stream,
<< "for the \"complex<float>\" datatype"; true /* = pointer_mode_host */, elem_count,
return false; complex_cast(alpha), complex_cast(x), incx);
} }
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
std::complex<double> alpha, std::complex<double> alpha,
DeviceMemory<std::complex<double>> *x, int incx) { DeviceMemory<std::complex<double>> *x, int incx) {
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation " return DoBlasInternal(wrap::rocblas_zscal, stream,
<< "for the \"complex<double>\" datatype"; true /* = pointer_mode_host */, elem_count,
return false; complex_cast(alpha), complex_cast(x), incx);
} }
bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
@@ -893,6 +923,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                           uint64 n, float alpha, const DeviceMemory<float> &a,
                           int lda, const DeviceMemory<float> &x, int incx,
                           float beta, DeviceMemory<float> *y, int incy) {
+  blas_log("DoBlasGemv");
   return DoBlasInternal(
       wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */,
       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
@@ -903,6 +934,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                           uint64 n, double alpha, const DeviceMemory<double> &a,
                           int lda, const DeviceMemory<double> &x, int incx,
                           double beta, DeviceMemory<double> *y, int incy) {
+  blas_log("DoBlasGemv");
   return DoBlasInternal(
       wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */,
       ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
@@ -915,9 +947,11 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  blas_log("DoBlasGemv");
+  return DoBlasInternal(
+      wrap::rocblas_cgemv, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
@@ -926,9 +960,11 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
                           const DeviceMemory<std::complex<double>> &x, int incx,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *y, int incy) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  blas_log("DoBlasGemv");
+  return DoBlasInternal(
+      wrap::rocblas_zgemv, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
+      complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
 }
 bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
@@ -1481,6 +1517,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                           float alpha, const DeviceMemory<Eigen::half> &a,
                           int lda, const DeviceMemory<Eigen::half> &b, int ldb,
                           float beta, DeviceMemory<Eigen::half> *c, int ldc) {
+  blas_log("DoBlasGemm");
   VLOG(1) << absl::StreamFormat(
       "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
@@ -1526,6 +1563,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                           float alpha, const DeviceMemory<float> &a, int lda,
                           const DeviceMemory<float> &b, int ldb, float beta,
                           DeviceMemory<float> *c, int ldc) {
+  blas_log("DoBlasGemm");
   VLOG(1) << absl::StreamFormat(
       "doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
@@ -1565,6 +1603,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                           double alpha, const DeviceMemory<double> &a, int lda,
                           const DeviceMemory<double> &b, int ldb, double beta,
                           DeviceMemory<double> *c, int ldc) {
+  blas_log("DoBlasGemm");
   return DoBlasInternal(
       wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
       ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
@@ -1578,9 +1617,12 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           std::complex<float> beta,
                           DeviceMemory<std::complex<float>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  blas_log("DoBlasGemm");
+  return DoBlasInternal(
+      wrap::rocblas_cgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
@@ -1590,9 +1632,12 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                           const DeviceMemory<std::complex<double>> &b, int ldb,
                           std::complex<double> beta,
                           DeviceMemory<std::complex<double>> *c, int ldc) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  blas_log("DoBlasGemm");
+  return DoBlasInternal(
+      wrap::rocblas_zgemm, stream, true /* = pointer_mode_host */,
+      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
+      complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
+      complex_cast(beta), complex_cast(c), ldc);
 }
 bool ROCMBlas::DoBlasGemvWithProfiling(
@@ -1813,6 +1858,56 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
   return false;
 }
+// Copies a batch of matrices between the scattered source buffers raw_ptrs[i]
+// and the contiguous buffer device_memory_ptr (laid out at strides of
+// matrix_byte_size), in either direction depending on `gather`. The algorithm
+// below minimizes the number of memcpy requests by consolidating copies of
+// physically adjacent matrices into a single transfer; a host-side sketch of
+// the consolidation idea follows this function.
+template <typename MAPPED_T>
+port::Status ReorganizeMemory(Stream *stream,
+                              DeviceMemory<MAPPED_T> *device_memory,
+                              const std::vector<MAPPED_T *> &raw_ptrs,
+                              int batch_count, uint64_t batch_stride,
+                              bool gather) {
+  assert(batch_count > 0);
+  char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
+  char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
+  char *dst_ptr = device_memory_ptr;
+  size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
+  uint64_t cur_stride_size = matrix_byte_size;
+  for (int i = 1; i < batch_count; ++i) {
+    if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
+      cur_stride_size += matrix_byte_size;
+    } else {
+      DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
+      DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
+      bool a_status =
+          gather
+              ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
+              : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
+      if (!a_status) {
+        return port::Status(
+            port::error::INTERNAL,
+            "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
+      }
+      src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
+      dst_ptr = device_memory_ptr + i * matrix_byte_size;
+      cur_stride_size = matrix_byte_size;
+    }
+  }
+  DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
+  DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
+  bool a_status =
+      gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
             : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
+  if (!a_status)
+    return port::Status(
+        port::error::INTERNAL,
+        "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
+  return port::Status::OK();
+}
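To make the consolidation behaviour above concrete, here is a minimal host-only sketch. It is not part of the change and is independent of the StreamExecutor types: it only groups a batch of equally sized matrices into the fewest contiguous copy ranges, mirroring the loop in ReorganizeMemory. The pointer layout and 256-byte matrix size are made up for illustration.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

struct CopyRange {
  const char *src;    // first byte of a contiguous run of source matrices
  std::size_t bytes;  // total bytes covered by the run
};

// Groups per-matrix pointers into contiguous runs: a new range is started only
// when raw_ptrs[i] does not begin exactly where the previous run ends.
std::vector<CopyRange> ConsolidateRanges(
    const std::vector<const char *> &raw_ptrs, std::size_t matrix_byte_size) {
  assert(!raw_ptrs.empty());
  std::vector<CopyRange> ranges{{raw_ptrs[0], matrix_byte_size}};
  for (std::size_t i = 1; i < raw_ptrs.size(); ++i) {
    CopyRange &last = ranges.back();
    if (raw_ptrs[i] == last.src + last.bytes) {
      last.bytes += matrix_byte_size;  // neighbouring matrix: extend the run
    } else {
      ranges.push_back({raw_ptrs[i], matrix_byte_size});  // gap: new request
    }
  }
  return ranges;
}

int main() {
  // Three matrices laid out back to back plus one stray matrix elsewhere: the
  // first three collapse into one copy request, the fourth gets its own.
  std::vector<char> contiguous(3 * 256);
  std::vector<char> stray(256);
  std::vector<const char *> ptrs = {contiguous.data(), contiguous.data() + 256,
                                    contiguous.data() + 512, stray.data()};
  for (const CopyRange &r : ConsolidateRanges(ptrs, 256)) {
    std::printf("copy %zu bytes starting at %p\n", r.bytes,
                static_cast<const void *>(r.src));
  }
  return 0;
}

The real ReorganizeMemory additionally issues each consolidated range as a ThenMemcpy on the Stream, in either direction (gather or scatter), and propagates failures as a port::Status.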
 template <typename T>
 port::Status ROCMBlas::AllocateStridedBuffer(
     const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
@@ -1822,7 +1917,8 @@ port::Status ROCMBlas::AllocateStridedBuffer(
     std::unique_ptr<TemporaryDeviceMemory<
         typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
     DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
-        *device_memory) {
+        *device_memory,
+    bool copy_data, bool &reallocated) {
   assert(device_memory != nullptr);
   using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
@@ -1843,6 +1939,7 @@ port::Status ROCMBlas::AllocateStridedBuffer(
   if (!needs_allocate_strided) {
     *device_memory = DeviceMemory<MAPPED_T>(
         DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
+    reallocated = false;
     return port::Status::OK();
   }
@@ -1859,19 +1956,11 @@ port::Status ROCMBlas::AllocateStridedBuffer(
         DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
   }
-  for (int i = 0; i < batch_count; ++i) {
-    char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
-    DeviceMemoryBase src_mem = DeviceMemoryBase(raw_ptrs[i], matrix_byte_size);
-    DeviceMemoryBase target_mem = DeviceMemoryBase(
-        device_memory_ptr + i * matrix_byte_size, matrix_byte_size);
-    bool a_status =
-        stream->ThenMemcpy(&target_mem, src_mem, matrix_byte_size).ok();
-    if (!a_status) {
-      return port::Status(
-          port::error::INTERNAL,
-          "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
-    }
-  }
+  reallocated = true;
+  if (copy_data)
+    return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
+                            batch_stride, true);
   return port::Status::OK();
 }
@@ -1925,27 +2014,28 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
   DeviceMemory<MAPPED_T> a;
   // Make sure the temporary memory is in scope before the function returns.
   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
-  port::Status a_allocation_status =
-      AllocateStridedBuffer<T>(a_raw_ptrs, batch_count, batch_stride_a,
-                               scratch_allocator, stream, &a_temp, &a);
+  bool reallocated_a, reallocated_b, reallocated_c;
+  port::Status a_allocation_status = AllocateStridedBuffer<T>(
+      a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream,
+      &a_temp, &a, true, reallocated_a);
   if (a_allocation_status != port::Status::OK()) {
     return a_allocation_status;
   }
   DeviceMemory<MAPPED_T> b;
   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
-  port::Status b_allocation_status =
-      AllocateStridedBuffer<T>(b_raw_ptrs, batch_count, batch_stride_b,
-                               scratch_allocator, stream, &b_temp, &b);
+  port::Status b_allocation_status = AllocateStridedBuffer<T>(
+      b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream,
+      &b_temp, &b, true, reallocated_b);
   if (b_allocation_status != port::Status::OK()) {
     return b_allocation_status;
   }
   DeviceMemory<MAPPED_T> c;
   std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
-  port::Status c_allocation_status =
-      AllocateStridedBuffer<T>(c_raw_ptrs, batch_count, batch_stride_c,
-                               scratch_allocator, stream, &c_temp, &c);
+  port::Status c_allocation_status = AllocateStridedBuffer<T>(
+      c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream,
+      &c_temp, &c, true, reallocated_c);  // gather copy could be skipped if beta == 0
   if (c_allocation_status != port::Status::OK()) {
     return c_allocation_status;
   }
@@ -1953,19 +2043,20 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
   MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
   MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
-  bool ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
-                           ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
-                           m, n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
-                           batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
-                           GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
-                           batch_stride_c, batch_count);
-  if (ok) {
-    return port::Status::OK();
-  } else {
+  bool ok;
+  ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
+                      ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
+                      n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
+                      batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
+                      GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
+                      batch_stride_c, batch_count);
+  if (!ok)
     return port::Status(port::error::INTERNAL,
                         "failed BLAS call, see log for details");
-  }
+  if (reallocated_c)
+    return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c,
+                            false);
+  return port::Status::OK();
 }
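The inline note above ("gather copy could be skipped if beta == 0") points at an optimization the change leaves on the table: when beta is zero, the GEMM never reads C's prior contents, so gathering C into the strided scratch buffer is pure overhead. The sketch below is an assumption about how that flag could be derived; NeedsGatherOfC is a hypothetical name, not part of this commit.

#include <complex>
#include <cstdio>

// Hypothetical helper (illustration only): the gather of C is needed only when
// C's existing values contribute to alpha * A * B + beta * C, i.e. when beta
// is non-zero.
template <typename T>
bool NeedsGatherOfC(T beta) {
  return beta != T(0);
}

int main() {
  std::printf("%d %d\n",
              NeedsGatherOfC(0.0f),                             // 0: copy can be skipped
              NeedsGatherOfC(std::complex<float>(2.0f, 0.0f))); // 1: copy needed
  return 0;
}

In DoBlasGemmBatchedInternal above, this would amount to passing NeedsGatherOfC(beta) instead of the hard-coded true as the copy_data argument for C.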
 bool ROCMBlas::DoBlasGemmBatched(
@@ -1975,6 +2066,7 @@ bool ROCMBlas::DoBlasGemmBatched(
     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
     int batch_count, ScratchAllocator *scratch_allocator) {
+  blas_log("DoBlasGemmBatched");
   const Eigen::half alpha_half(alpha);
   const Eigen::half beta_half(beta);
@@ -1996,6 +2088,7 @@ bool ROCMBlas::DoBlasGemmBatched(
     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
     const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
     int batch_count, ScratchAllocator *scratch_allocator) {
+  blas_log("DoBlasGemmBatched");
   port::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
@@ -2013,6 +2106,7 @@ bool ROCMBlas::DoBlasGemmBatched(
     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+  blas_log("DoBlasGemmBatched");
   port::Status status = DoBlasGemmBatchedInternal(
       wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
       alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
@@ -2032,9 +2126,15 @@ bool ROCMBlas::DoBlasGemmBatched(
     int ldb, std::complex<float> beta,
     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
-             << "for the \"complex<float>\" datatype";
-  return false;
+  blas_log("DoBlasGemmBatched");
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
+      alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+      scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
 }
 bool ROCMBlas::DoBlasGemmBatched(
@@ -2046,9 +2146,15 @@ bool ROCMBlas::DoBlasGemmBatched(
     int ldb, std::complex<double> beta,
     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
-  LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
-             << "for the \"complex<double>\" datatype";
-  return false;
+  blas_log("DoBlasGemmBatched");
+  port::Status status = DoBlasGemmBatchedInternal(
+      wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
+      alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+      scratch_allocator);
+  if (!status.ok()) {
+    LOG(ERROR) << status;
+  }
+  return status.ok();
 }
 bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
@@ -2296,6 +2402,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                           const DeviceMemory<float> &a, int lda,
                           DeviceMemory<float> *b, int ldb) {
+  blas_log("DoBlasTrsm");
   return DoBlasInternal(
       wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
@@ -2308,6 +2415,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                           const DeviceMemory<double> &a, int lda,
                           DeviceMemory<double> *b, int ldb) {
+  blas_log("DoBlasTrsm");
   return DoBlasInternal(
       wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
       ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
@@ -2336,12 +2444,14 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
              << "for the \"complex<double>\" datatype";
   return false;
 }
 bool ROCMBlas::DoBlasGemmStridedBatched(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
     int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
     int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
     int64 stride_c, int batch_count) {
+  blas_log("DoBlasGemmStridedBatched");
   const Eigen::half alpha_half(alpha);
   const Eigen::half beta_half(beta);
@@ -2363,6 +2473,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
     int batch_count) {
+  blas_log("DoBlasGemmStridedBatched");
   return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
                         false, /* pointer_mode_host */
                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
@@ -2376,6 +2487,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
     int batch_count) {
+  blas_log("DoBlasGemmStridedBatched");
   return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
                         false, /* pointer_mode_host */
                         ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,

View File

@@ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper<Eigen::half> {
   using mapped_type = rocblas_half;
 };
+template <>
+struct RocBlasTypeConversionHelper<std::complex<float>> {
+  using mapped_type = rocblas_float_complex;
+};
+template <>
+struct RocBlasTypeConversionHelper<std::complex<double>> {
+  using mapped_type = rocblas_double_complex;
+};
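These trait specializations (together with complex_cast in rocm_blas.cc) rest on the assumption that the rocBLAS complex types are layout-compatible with std::complex, since device pointers are reinterpreted rather than converted. The checks below are not part of the commit; they are a minimal compile-time sanity sketch that a translation unit including the rocBLAS headers could add to make that assumption explicit.

#include <complex>
// Requires the rocBLAS header that defines rocblas_float_complex and
// rocblas_double_complex (e.g. rocblas.h); shown purely as an illustration.
static_assert(sizeof(rocblas_float_complex) == sizeof(std::complex<float>),
              "rocblas_float_complex must have the size of std::complex<float>");
static_assert(sizeof(rocblas_double_complex) == sizeof(std::complex<double>),
              "rocblas_double_complex must have the size of std::complex<double>");
// Alignment and member order (real part first, then imaginary) would deserve
// the same kind of check if the rocBLAS definition ever changes.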
 // Opaque and unique identifier for the rocBLAS plugin.
 extern const PluginId kRocBlasPlugin;
@@ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport {
       std::unique_ptr<TemporaryDeviceMemory<
           typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
       DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
-          *device_memory);
+          *device_memory,
+      bool copy_data, bool &reallocated);
   // A helper function to implement DoBlasGemmBatched interfaces for generic
   // types.