Merge pull request #35666 from ROCmSoftwarePlatform:google_upstream_rocblas_complex
PiperOrigin-RevId: 304218949
Change-Id: Ic8b3408a71502444b8caf3846d4662ae21f1c325
@@ -3411,7 +3411,6 @@ tf_py_test(
    data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
    shard_count = 20,
    tags = [
        "no_rocm",  # flaky test
        "no_windows",
    ],
    deps = [
@@ -262,10 +262,9 @@ class BatchMatMulBenchmark(test.Benchmark):


if __name__ == "__main__":
  dtypes_to_test = [np.float16, np.float32, np.float64, np.int32]
  if not test.is_built_with_rocm():
    # ROCm does not support BLAS operations for complex types
    dtypes_to_test += [np.complex64, np.complex128]
  dtypes_to_test = [
      np.float16, np.float32, np.float64, np.int32, np.complex64, np.complex128
  ]
  for dtype_ in dtypes_to_test:
    for adjoint_a_ in False, True:
      for adjoint_b_ in False, True:
@@ -183,10 +183,10 @@ def _GetEigTest(dtype_, shape_, compute_v_):


if __name__ == "__main__":
  dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
  if not test.is_built_with_rocm():
    # ROCm does not support BLAS operations for complex types
    dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
  dtypes_to_test = [
      dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
      dtypes_lib.complex128
  ]
  for compute_v in True, False:
    for dtype in dtypes_to_test:
      for size in 1, 2, 5, 10:
@ -49,6 +49,7 @@ def identicaltest(tc, init1, init2, shape=None):
|
|||
init2: An Initializer that generates a tensor of a given shape
|
||||
shape: Shape of the tensor to initialize or `None` to use a vector of length
|
||||
100.
|
||||
|
||||
Returns:
|
||||
True or False as determined by test.
|
||||
"""
|
||||
|
@ -75,6 +76,7 @@ def duplicated_initializer(tc, init, graph_seed, shape=None):
|
|||
graph_seed: A graph-level seed to use.
|
||||
shape: Shape of the tensor to initialize or `None` to use a vector of length
|
||||
100.
|
||||
|
||||
Returns:
|
||||
True or False as determined by test.
|
||||
"""
|
||||
|
@ -94,6 +96,7 @@ def _init_sampler(tc, init, num):
|
|||
tc: An instance of TensorFlowTestCase.
|
||||
init: An Initializer that generates a tensor of a given shape
|
||||
num: Size of 1D tensor to create.
|
||||
|
||||
Returns:
|
||||
Function to generate a random tensor.
|
||||
"""
|
||||
|
@ -187,8 +190,8 @@ class ConstantInitializersTest(test.TestCase):
|
|||
expected = list(value)
|
||||
|
||||
self._testNDimConstantInitializer("list", value, shape, expected)
|
||||
self._testNDimConstantInitializer("ndarray",
|
||||
np.asarray(value), shape, expected)
|
||||
self._testNDimConstantInitializer("ndarray", np.asarray(value), shape,
|
||||
expected)
|
||||
self._testNDimConstantInitializer("2D-ndarray",
|
||||
np.asarray(value).reshape(tuple(shape)),
|
||||
shape, expected)
|
||||
|
@ -214,11 +217,11 @@ class ConstantInitializersTest(test.TestCase):
|
|||
expected = list(value)
|
||||
|
||||
self._testNDimConstantInitializerLessValues("list", value, shape, expected)
|
||||
self._testNDimConstantInitializerLessValues("ndarray",
|
||||
np.asarray(value), shape,
|
||||
expected)
|
||||
self._testNDimConstantInitializerLessValues("ndarray", np.asarray(value),
|
||||
shape, expected)
|
||||
self._testNDimConstantInitializerLessValues(
|
||||
"2D-ndarray", np.asarray(value).reshape(tuple([2, 3])), shape, expected)
|
||||
"2D-ndarray",
|
||||
np.asarray(value).reshape(tuple([2, 3])), shape, expected)
|
||||
|
||||
def _testNDimConstantInitializerMoreValues(self, value, shape):
|
||||
ops.reset_default_graph()
|
||||
|
@ -242,8 +245,8 @@ class ConstantInitializersTest(test.TestCase):
|
|||
|
||||
def testInvalidValueTypeForConstantInitializerCausesTypeError(self):
|
||||
c = constant_op.constant([1.0, 2.0, 3.0])
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r"Invalid type for initial value: .*Tensor.*"):
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
r"Invalid type for initial value: .*Tensor.*"):
|
||||
init_ops.constant_initializer(c, dtype=dtypes.float32)
|
||||
v = variables.Variable([3.0, 2.0, 1.0])
|
||||
with self.assertRaisesRegexp(
|
||||
|
@ -393,11 +396,11 @@ class VarianceScalingInitializationTest(test.TestCase):
|
|||
expect_mean = 0.
|
||||
expect_var = 1. / shape[0]
|
||||
init = init_ops.variance_scaling_initializer(
|
||||
distribution='truncated_normal')
|
||||
distribution="truncated_normal")
|
||||
|
||||
with self.session(use_gpu=True), \
|
||||
test.mock.patch.object(
|
||||
random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
|
||||
random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \
|
||||
as mock_truncated_normal:
|
||||
x = init(shape).eval()
|
||||
self.assertTrue(mock_truncated_normal.called)
|
||||
|
@ -410,11 +413,11 @@ class VarianceScalingInitializationTest(test.TestCase):
|
|||
shape = [100, 100]
|
||||
expect_mean = 0.
|
||||
expect_var = 1. / shape[0]
|
||||
init = init_ops.variance_scaling_initializer(distribution='normal')
|
||||
init = init_ops.variance_scaling_initializer(distribution="normal")
|
||||
|
||||
with self.session(use_gpu=True), \
|
||||
test.mock.patch.object(
|
||||
random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
|
||||
random_ops, "truncated_normal", wraps=random_ops.truncated_normal) \
|
||||
as mock_truncated_normal:
|
||||
x = init(shape).eval()
|
||||
self.assertTrue(mock_truncated_normal.called)
|
||||
|
@ -428,11 +431,11 @@ class VarianceScalingInitializationTest(test.TestCase):
|
|||
expect_mean = 0.
|
||||
expect_var = 1. / shape[0]
|
||||
init = init_ops.variance_scaling_initializer(
|
||||
distribution='untruncated_normal')
|
||||
distribution="untruncated_normal")
|
||||
|
||||
with self.session(use_gpu=True), \
|
||||
test.mock.patch.object(
|
||||
random_ops, 'random_normal', wraps=random_ops.random_normal) \
|
||||
random_ops, "random_normal", wraps=random_ops.random_normal) \
|
||||
as mock_random_normal:
|
||||
x = init(shape).eval()
|
||||
self.assertTrue(mock_random_normal.called)
|
||||
|
@ -445,7 +448,7 @@ class VarianceScalingInitializationTest(test.TestCase):
|
|||
shape = [100, 100]
|
||||
expect_mean = 0.
|
||||
expect_var = 1. / shape[0]
|
||||
init = init_ops.variance_scaling_initializer(distribution='uniform')
|
||||
init = init_ops.variance_scaling_initializer(distribution="uniform")
|
||||
|
||||
with self.session(use_gpu=True):
|
||||
x = init(shape).eval()
|
||||
|
@ -525,17 +528,13 @@ class RangeTest(test.TestCase):
|
|||
math_ops.range(zero_float64, zero_int32, 1).dtype, dtypes.float64)
|
||||
|
||||
self.assertEqual(
|
||||
math_ops.range(
|
||||
0, 0, 1, dtype=dtypes.int32).dtype, dtypes.int32)
|
||||
math_ops.range(0, 0, 1, dtype=dtypes.int32).dtype, dtypes.int32)
|
||||
self.assertEqual(
|
||||
math_ops.range(
|
||||
0, 0, 1, dtype=dtypes.int64).dtype, dtypes.int64)
|
||||
math_ops.range(0, 0, 1, dtype=dtypes.int64).dtype, dtypes.int64)
|
||||
self.assertEqual(
|
||||
math_ops.range(
|
||||
0, 0, 1, dtype=dtypes.float32).dtype, dtypes.float32)
|
||||
math_ops.range(0, 0, 1, dtype=dtypes.float32).dtype, dtypes.float32)
|
||||
self.assertEqual(
|
||||
math_ops.range(
|
||||
0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64)
|
||||
math_ops.range(0, 0, 1, dtype=dtypes.float64).dtype, dtypes.float64)
|
||||
|
||||
def testMixedDType(self):
|
||||
# Test case for GitHub issue 35710
|
||||
|
@ -578,8 +577,8 @@ class LinSpaceTest(test.TestCase):
|
|||
self.assertArrayNear(
|
||||
self._LinSpace(-1., -5., 3), np.array([-1., -3., -5.]), 1e-5)
|
||||
self.assertArrayNear(
|
||||
self._LinSpace(-1., -5., 4),
|
||||
np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5)
|
||||
self._LinSpace(-1., -5., 4), np.array([-1., -7. / 3., -11. / 3.,
|
||||
-5.]), 1e-5)
|
||||
|
||||
def testNegativeToPositive(self):
|
||||
for self.force_gpu in self._gpu_modes():
|
||||
|
@ -859,7 +858,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
|
|||
|
||||
def testInvalidDataType(self):
|
||||
self.assertRaises(
|
||||
ValueError, init_ops.convolutional_delta_orthogonal,
|
||||
ValueError,
|
||||
init_ops.convolutional_delta_orthogonal,
|
||||
dtype=dtypes.string)
|
||||
|
||||
def testInvalidShape(self):
|
||||
|
@ -872,8 +872,8 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
|
|||
shape = (3, 3, 10, 10)
|
||||
for dtype in [dtypes.float32, dtypes.float64]:
|
||||
init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_delta_orthogonal(gain=3.14,
|
||||
seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_delta_orthogonal(
|
||||
gain=3.14, seed=1, dtype=dtype)
|
||||
with self.session(graph=ops.Graph(), use_gpu=True):
|
||||
t1 = init1(shape).eval()
|
||||
t2 = init2(shape).eval()
|
||||
|
@ -896,18 +896,14 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
|
|||
else:
|
||||
shape = [4, 16, 16, 16, 64]
|
||||
convolution = convolutional.conv3d
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# This subtest triggers a known bug in ROCm runtime code
|
||||
# The bug has been fixed and will be available in ROCm 2.7
|
||||
# Re-enable this test once ROCm 2.7 is released
|
||||
continue
|
||||
|
||||
inputs = random_ops.random_normal(shape, dtype=dtype)
|
||||
inputs_2norm = linalg_ops.norm(inputs)
|
||||
outputs = convolution(
|
||||
inputs, padding="same", filters=128,
|
||||
kernel_size=kernel_size, use_bias=False,
|
||||
inputs,
|
||||
padding="same",
|
||||
filters=128,
|
||||
kernel_size=kernel_size,
|
||||
use_bias=False,
|
||||
kernel_initializer=init_ops.convolutional_delta_orthogonal(
|
||||
gain=gain))
|
||||
outputs_shape = shape[0:-1] + [128]
|
||||
|
@ -931,9 +927,10 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
|
|||
tol = 1e-5
|
||||
with self.session(use_gpu=True):
|
||||
for i in range(count):
|
||||
x = variable_scope.get_variable("{}".format(i), shape=shape,
|
||||
initializer=
|
||||
init_ops.convolutional_delta_orthogonal)
|
||||
x = variable_scope.get_variable(
|
||||
"{}".format(i),
|
||||
shape=shape,
|
||||
initializer=init_ops.convolutional_delta_orthogonal)
|
||||
x.initializer.run()
|
||||
y = self.evaluate(x)[1, 1, :, :]
|
||||
determinant = np.linalg.det(y)
|
||||
|
@ -971,8 +968,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
|||
|
||||
def testInvalidDataType(self):
|
||||
self.assertRaises(
|
||||
ValueError, init_ops.convolutional_orthogonal_1d,
|
||||
dtype=dtypes.string)
|
||||
ValueError, init_ops.convolutional_orthogonal_1d, dtype=dtypes.string)
|
||||
|
||||
def testInvalidShape(self):
|
||||
init1 = init_ops.convolutional_orthogonal_1d()
|
||||
|
@ -984,8 +980,8 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
|||
shape = (3, 10, 10)
|
||||
for dtype in [dtypes.float32, dtypes.float64]:
|
||||
init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_orthogonal_1d(gain=3.14,
|
||||
seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_orthogonal_1d(
|
||||
gain=3.14, seed=1, dtype=dtype)
|
||||
with self.session(graph=ops.Graph(), use_gpu=True):
|
||||
t1 = init1(shape).eval()
|
||||
t2 = init2(shape).eval()
|
||||
|
@ -1000,9 +996,10 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
|||
tol = 1e-5
|
||||
with self.session(use_gpu=True):
|
||||
for i in range(count):
|
||||
x = variable_scope.get_variable("{}".format(i), shape=shape,
|
||||
initializer=
|
||||
init_ops.convolutional_orthogonal_1d)
|
||||
x = variable_scope.get_variable(
|
||||
"{}".format(i),
|
||||
shape=shape,
|
||||
initializer=init_ops.convolutional_orthogonal_1d)
|
||||
x.initializer.run()
|
||||
y = np.sum(x.eval(), axis=0)
|
||||
determinant = np.linalg.det(y)
|
||||
|
@ -1018,6 +1015,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapesValues(self):
|
||||
|
||||
def circular_pad(input_, width, kernel_size):
|
||||
"""Pad input_ for computing (circular) convolution.
|
||||
|
||||
|
@ -1025,6 +1023,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
|||
input_: the input tensor
|
||||
width: the width of the tensor.
|
||||
kernel_size: the kernel size of the filter.
|
||||
|
||||
Returns:
|
||||
a tensor whose width is (width + kernel_size - 1).
|
||||
"""
|
||||
|
@ -1053,8 +1052,11 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
|
|||
inputs_2norm = linalg_ops.norm(inputs)
|
||||
input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
|
||||
outputs = convolution(
|
||||
input_with_circular_pad, padding="valid", filters=cout,
|
||||
kernel_size=kernel_size[0], use_bias=False,
|
||||
input_with_circular_pad,
|
||||
padding="valid",
|
||||
filters=cout,
|
||||
kernel_size=kernel_size[0],
|
||||
use_bias=False,
|
||||
kernel_initializer=init_ops.convolutional_orthogonal_1d(gain=gain))
|
||||
outputs_2norm = linalg_ops.norm(outputs)
|
||||
ratio = outputs_2norm / inputs_2norm
|
||||
|
@ -1091,8 +1093,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
|
|||
|
||||
def testInvalidDataType(self):
|
||||
self.assertRaises(
|
||||
ValueError, init_ops.convolutional_orthogonal_2d,
|
||||
dtype=dtypes.string)
|
||||
ValueError, init_ops.convolutional_orthogonal_2d, dtype=dtypes.string)
|
||||
|
||||
def testInvalidShape(self):
|
||||
init1 = init_ops.convolutional_orthogonal_2d()
|
||||
|
@ -1104,8 +1105,8 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
|
|||
shape = (3, 3, 10, 10)
|
||||
for dtype in [dtypes.float32, dtypes.float64]:
|
||||
init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_orthogonal_2d(gain=3.14,
|
||||
seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_orthogonal_2d(
|
||||
gain=3.14, seed=1, dtype=dtype)
|
||||
with self.session(graph=ops.Graph(), use_gpu=True):
|
||||
t1 = init1(shape).eval()
|
||||
t2 = init2(shape).eval()
|
||||
|
@ -1113,6 +1114,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapesValues(self):
|
||||
|
||||
def circular_pad(input_, width, kernel_size):
|
||||
"""Pad input_ for computing (circular) convolution.
|
||||
|
||||
|
@ -1120,6 +1122,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
|
|||
input_: the input tensor
|
||||
width: the width of the tensor.
|
||||
kernel_size: the kernel size of the filter.
|
||||
|
||||
Returns:
|
||||
a tensor whose width is (width + kernel_size - 1).
|
||||
"""
|
||||
|
@ -1153,8 +1156,11 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
|
|||
inputs_2norm = linalg_ops.norm(inputs)
|
||||
input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
|
||||
outputs = convolution(
|
||||
input_with_circular_pad, padding="valid", filters=cout,
|
||||
kernel_size=kernel_size, use_bias=False,
|
||||
input_with_circular_pad,
|
||||
padding="valid",
|
||||
filters=cout,
|
||||
kernel_size=kernel_size,
|
||||
use_bias=False,
|
||||
kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain))
|
||||
outputs_2norm = linalg_ops.norm(outputs)
|
||||
ratio = outputs_2norm / inputs_2norm
|
||||
|
@ -1191,8 +1197,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
|||
|
||||
def testInvalidDataType(self):
|
||||
self.assertRaises(
|
||||
ValueError, init_ops.convolutional_orthogonal_3d,
|
||||
dtype=dtypes.string)
|
||||
ValueError, init_ops.convolutional_orthogonal_3d, dtype=dtypes.string)
|
||||
|
||||
def testInvalidShape(self):
|
||||
init1 = init_ops.convolutional_orthogonal_3d()
|
||||
|
@ -1204,8 +1209,8 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
|||
shape = (3, 3, 3, 10, 10)
|
||||
for dtype in [dtypes.float32, dtypes.float64]:
|
||||
init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_orthogonal_3d(gain=3.14,
|
||||
seed=1, dtype=dtype)
|
||||
init2 = init_ops.convolutional_orthogonal_3d(
|
||||
gain=3.14, seed=1, dtype=dtype)
|
||||
with self.session(graph=ops.Graph(), use_gpu=True):
|
||||
t1 = init1(shape).eval()
|
||||
t2 = init2(shape).eval()
|
||||
|
@ -1220,9 +1225,10 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
|||
tol = 1e-5
|
||||
with self.session(use_gpu=True):
|
||||
for i in range(count):
|
||||
x = variable_scope.get_variable("{}".format(i), shape=shape,
|
||||
initializer=
|
||||
init_ops.convolutional_orthogonal_3d)
|
||||
x = variable_scope.get_variable(
|
||||
"{}".format(i),
|
||||
shape=shape,
|
||||
initializer=init_ops.convolutional_orthogonal_3d)
|
||||
x.initializer.run()
|
||||
y = np.sum(x.eval(), axis=(0, 1, 2))
|
||||
determinant = np.linalg.det(y)
|
||||
|
@ -1238,6 +1244,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapesValues(self):
|
||||
|
||||
def circular_pad(input_, width, kernel_size):
|
||||
"""Padding input_ for computing circular convolution.
|
||||
|
||||
|
@ -1255,14 +1262,12 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
|||
|
||||
tmp_up = array_ops.slice(input_, [0, width - beginning, 0, 0, 0],
|
||||
[-1, beginning, -1, -1, -1])
|
||||
tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0],
|
||||
[-1, end, -1, -1, -1])
|
||||
tmp_down = array_ops.slice(input_, [0, 0, 0, 0, 0], [-1, end, -1, -1, -1])
|
||||
tmp = array_ops.concat([tmp_up, input_, tmp_down], 1)
|
||||
|
||||
tmp_left = array_ops.slice(tmp, [0, 0, width - beginning, 0, 0],
|
||||
[-1, -1, beginning, -1, -1])
|
||||
tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0],
|
||||
[-1, -1, end, -1, -1])
|
||||
tmp_right = array_ops.slice(tmp, [0, 0, 0, 0, 0], [-1, -1, end, -1, -1])
|
||||
tmp = array_ops.concat([tmp_left, tmp, tmp_right], 2)
|
||||
|
||||
tmp_front = array_ops.slice(tmp, [0, 0, 0, width - beginning, 0],
|
||||
|
@ -1284,8 +1289,11 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
|
|||
inputs_2norm = linalg_ops.norm(inputs)
|
||||
input_with_circular_pad = circular_pad(inputs, shape[1], kernel_size[0])
|
||||
outputs = convolution(
|
||||
input_with_circular_pad, padding="valid", filters=cout,
|
||||
kernel_size=kernel_size[0], use_bias=False,
|
||||
input_with_circular_pad,
|
||||
padding="valid",
|
||||
filters=cout,
|
||||
kernel_size=kernel_size[0],
|
||||
use_bias=False,
|
||||
kernel_initializer=init_ops.convolutional_orthogonal_3d(gain=gain))
|
||||
outputs_2norm = linalg_ops.norm(outputs)
|
||||
ratio = outputs_2norm / inputs_2norm
|
||||
|
|
|
@ -141,8 +141,6 @@ class LinearOperatorAdjointTest(
|
|||
full_matrix2, adjoint=True, adjoint_arg=True).to_dense()))
|
||||
|
||||
def test_matmul_adjoint_complex_operator(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||
matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
|
||||
full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
|
||||
|
@ -201,7 +199,8 @@ class LinearOperatorAdjointTest(
|
|||
|
||||
def test_solve_adjoint_complex_operator(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
self.skipTest("ROCm does not support BLAS solve operations"
|
||||
" for complex types")
|
||||
matrix1 = self.evaluate(linear_operator_test_util.random_tril_matrix(
|
||||
[4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
|
||||
1j * linear_operator_test_util.random_tril_matrix(
|
||||
|
|
|
@ -114,8 +114,10 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
|
|||
# real, the matrix will not be real.
|
||||
return [dtypes.complex64, dtypes.complex128]
|
||||
|
||||
def operator_and_matrix(
|
||||
self, shape_info, dtype, use_placeholder,
|
||||
def operator_and_matrix(self,
|
||||
shape_info,
|
||||
dtype,
|
||||
use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
shape = shape_info.shape
|
||||
# For this test class, we are creating real spectrums.
|
||||
|
@ -123,9 +125,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
|
|||
#
|
||||
# spectrum is bounded away from zero.
|
||||
spectrum = linear_operator_test_util.random_sign_uniform(
|
||||
shape=self._shape_to_spectrum_shape(shape),
|
||||
minval=1.,
|
||||
maxval=2.)
|
||||
shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
|
||||
if ensure_self_adjoint_and_pd:
|
||||
spectrum = math_ops.abs(spectrum)
|
||||
# If dtype is complex, cast spectrum to complex. The imaginary part will be
|
||||
|
@ -176,8 +176,10 @@ class LinearOperatorCirculantTestHermitianSpectrum(
|
|||
zero imaginary part.
|
||||
"""
|
||||
|
||||
def operator_and_matrix(
|
||||
self, shape_info, dtype, use_placeholder,
|
||||
def operator_and_matrix(self,
|
||||
shape_info,
|
||||
dtype,
|
||||
use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
shape = shape_info.shape
|
||||
# For this test class, we are creating Hermitian spectrums.
|
||||
|
@ -259,8 +261,10 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
|
|||
def skip_these_tests():
|
||||
return ["cholesky", "eigvalsh"]
|
||||
|
||||
def operator_and_matrix(
|
||||
self, shape_info, dtype, use_placeholder,
|
||||
def operator_and_matrix(self,
|
||||
shape_info,
|
||||
dtype,
|
||||
use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
del ensure_self_adjoint_and_pd
|
||||
shape = shape_info.shape
|
||||
|
@ -357,11 +361,6 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
|
|||
self.evaluate(operator.assert_non_singular())
|
||||
|
||||
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# ROCm does not yet support BLAS operations with complex types.
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
|
||||
spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64)
|
||||
operator = linalg.LinearOperatorCirculant(spectrum)
|
||||
with self.cached_session():
|
||||
|
@ -486,8 +485,10 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
|
|||
def skip_these_tests():
|
||||
return ["cond"]
|
||||
|
||||
def operator_and_matrix(
|
||||
self, shape_info, dtype, use_placeholder,
|
||||
def operator_and_matrix(self,
|
||||
shape_info,
|
||||
dtype,
|
||||
use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
shape = shape_info.shape
|
||||
# For this test class, we are creating Hermitian spectrums.
|
||||
|
@ -547,8 +548,10 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
|
|||
def skip_these_tests():
|
||||
return ["cholesky", "eigvalsh"]
|
||||
|
||||
def operator_and_matrix(
|
||||
self, shape_info, dtype, use_placeholder,
|
||||
def operator_and_matrix(self,
|
||||
shape_info,
|
||||
dtype,
|
||||
use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
del ensure_self_adjoint_and_pd
|
||||
shape = shape_info.shape
|
||||
|
@ -665,11 +668,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||
yield sess
|
||||
|
||||
def test_real_spectrum_gives_self_adjoint_operator(self):
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# ROCm does not yet support BLAS operations with complex types
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
|
||||
with self.cached_session():
|
||||
# This is a real and hermitian spectrum.
|
||||
spectrum = linear_operator_test_util.random_normal(
|
||||
|
@ -686,11 +684,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||
self.assertAllClose(matrix, matrix_h)
|
||||
|
||||
def test_defining_operator_using_real_convolution_kernel(self):
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# ROCm does not yet support BLAS operations with complex types
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
|
||||
with self.cached_session():
|
||||
convolution_kernel = linear_operator_test_util.random_normal(
|
||||
shape=(2, 2, 3, 5), dtype=dtypes.float32)
|
||||
|
@ -709,11 +702,6 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)
|
||||
|
||||
def test_defining_spd_operator_by_taking_real_part(self):
|
||||
|
||||
if test.is_built_with_rocm():
|
||||
# ROCm does not yet support BLAS operations with complex types
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
|
||||
with self.cached_session(): # Necessary for fft_kernel_label_map
|
||||
# S is real and positive.
|
||||
s = linear_operator_test_util.random_uniform(
|
||||
|
|
|
@ -130,8 +130,6 @@ class LuOpTest(test.TestCase):
|
|||
for output_idx_type in (dtypes.int32, dtypes.int64):
|
||||
self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)
|
||||
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
for dtype in (np.complex64, np.complex128):
|
||||
for output_idx_type in (dtypes.int32, dtypes.int64):
|
||||
complex_data = np.tril(1j * data, -1).astype(dtype)
|
||||
|
@ -152,8 +150,6 @@ class LuOpTest(test.TestCase):
|
|||
# Make sure p_val is not the identity permutation.
|
||||
self.assertNotAllClose(np.arange(3), p_val)
|
||||
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
for dtype in (np.complex64, np.complex128):
|
||||
complex_data = np.tril(1j * data, -1).astype(dtype)
|
||||
complex_data += np.triu(-1j * data, 1).astype(dtype)
|
||||
|
@ -195,8 +191,6 @@ class LuOpTest(test.TestCase):
|
|||
matrices = np.random.rand(batch_size, 5, 5)
|
||||
self._verifyLu(matrices)
|
||||
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
# Generate random complex valued matrices.
|
||||
np.random.seed(52)
|
||||
matrices = np.random.rand(batch_size, 5,
|
||||
|
@ -210,8 +204,6 @@ class LuOpTest(test.TestCase):
|
|||
data = np.random.rand(n, n)
|
||||
self._verifyLu(data)
|
||||
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
# Generate random complex valued matrices.
|
||||
np.random.seed(129)
|
||||
data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
|
||||
|
|
|
@ -226,10 +226,10 @@ class MatMulInfixOperatorTest(test_lib.TestCase):
|
|||
if __name__ == "__main__":
|
||||
sizes = [1, 3, 5]
|
||||
trans_options = [[False, False], [True, False], [False, True]]
|
||||
dtypes_to_test = [np.int32, np.int64, np.float16, np.float32, np.float64]
|
||||
if not test_lib.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
dtypes_to_test += [np.complex64, np.complex128]
|
||||
dtypes_to_test = [
|
||||
np.int32, np.int64, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128
|
||||
]
|
||||
# TF2 does not support placeholders under eager so we skip it
|
||||
for use_static_shape in set([True, tf2.enabled()]):
|
||||
for dtype in dtypes_to_test:
|
||||
|
|
|
@ -91,8 +91,6 @@ class ExponentialOpTest(test.TestCase):
|
|||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNonsymmetricComplex(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
matrix1 = np.array([[1., 2.], [3., 4.]])
|
||||
matrix2 = np.array([[1., 3.], [3., 5.]])
|
||||
matrix1 = matrix1.astype(np.complex64)
|
||||
|
@ -114,8 +112,6 @@ class ExponentialOpTest(test.TestCase):
|
|||
self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
|
||||
|
||||
def testSymmetricPositiveDefiniteComplex(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
matrix1 = np.array([[2., 1.], [1., 2.]])
|
||||
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
||||
matrix1 = matrix1.astype(np.complex64)
|
||||
|
@ -185,8 +181,8 @@ class MatrixExponentialBenchmark(test.Benchmark):
|
|||
shape = shape[-2:]
|
||||
assert shape[0] == shape[1]
|
||||
n = shape[0]
|
||||
matrix = np.ones(shape).astype(np.float32) / (
|
||||
2.0 * n) + np.diag(np.ones(n).astype(np.float32))
|
||||
matrix = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
|
||||
np.ones(n).astype(np.float32))
|
||||
return variables.Variable(np.tile(matrix, batch_shape + (1, 1)))
|
||||
|
||||
def benchmarkMatrixExponentialOp(self):
|
||||
|
@ -201,8 +197,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
|
|||
sess,
|
||||
control_flow_ops.group(expm),
|
||||
min_iters=25,
|
||||
name="matrix_exponential_cpu_{shape}".format(
|
||||
shape=shape))
|
||||
name="matrix_exponential_cpu_{shape}".format(shape=shape))
|
||||
|
||||
if test.is_gpu_available(True):
|
||||
with ops.Graph().as_default(), \
|
||||
|
@ -215,8 +210,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
|
|||
sess,
|
||||
control_flow_ops.group(expm),
|
||||
min_iters=25,
|
||||
name="matrix_exponential_gpu_{shape}".format(
|
||||
shape=shape))
|
||||
name="matrix_exponential_gpu_{shape}".format(shape=shape))
|
||||
|
||||
|
||||
def _TestRandomSmall(dtype, batch_dims, size):
|
||||
|
@ -224,9 +218,7 @@ def _TestRandomSmall(dtype, batch_dims, size):
|
|||
def Test(self):
|
||||
np.random.seed(42)
|
||||
shape = batch_dims + (size, size)
|
||||
matrix = np.random.uniform(
|
||||
low=-1.0, high=1.0,
|
||||
size=shape).astype(dtype)
|
||||
matrix = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype)
|
||||
self._verifyExponentialReal(matrix)
|
||||
|
||||
return Test
|
||||
|
@ -237,8 +229,7 @@ def _TestL1Norms(dtype, shape, scale):
|
|||
def Test(self):
|
||||
np.random.seed(42)
|
||||
matrix = np.random.uniform(
|
||||
low=-1.0, high=1.0,
|
||||
size=np.prod(shape)).reshape(shape).astype(dtype)
|
||||
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
|
||||
print(dtype, shape, scale, matrix)
|
||||
l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim - 2))
|
||||
matrix /= l1_norm
|
||||
|
|
|
@ -74,9 +74,6 @@ class InverseOpTest(test.TestCase):
|
|||
self._verifyInverseReal(matrix2)
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
# Complex
|
||||
matrix1 = matrix1.astype(np.complex64)
|
||||
matrix1 += 1j * matrix1
|
||||
matrix2 = matrix2.astype(np.complex64)
|
||||
|
@ -94,9 +91,6 @@ class InverseOpTest(test.TestCase):
|
|||
self._verifyInverseReal(matrix2)
|
||||
# A multidimensional batch of 2x2 matrices
|
||||
self._verifyInverseReal(self._makeBatch(matrix1, matrix2))
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
# Complex
|
||||
matrix1 = matrix1.astype(np.complex64)
|
||||
matrix1 += 1j * matrix1
|
||||
matrix2 = matrix2.astype(np.complex64)
|
||||
|
|
|
@ -59,8 +59,6 @@ class LogarithmOpTest(test.TestCase):
|
|||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testNonsymmetric(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
# 2x2 matrices
|
||||
matrix1 = np.array([[1., 2.], [3., 4.]])
|
||||
matrix2 = np.array([[1., 3.], [3., 5.]])
|
||||
|
@ -75,8 +73,6 @@ class LogarithmOpTest(test.TestCase):
|
|||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testSymmetricPositiveDefinite(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
# 2x2 matrices
|
||||
matrix1 = np.array([[2., 1.], [1., 2.]])
|
||||
matrix2 = np.array([[3., -1.], [-1., 3.]])
|
||||
|
@ -111,8 +107,6 @@ class LogarithmOpTest(test.TestCase):
|
|||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRandomSmallAndLargeComplex64(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
np.random.seed(42)
|
||||
for batch_dims in [(), (1,), (3,), (2, 2)]:
|
||||
for size in 8, 31, 32:
|
||||
|
@ -124,8 +118,6 @@ class LogarithmOpTest(test.TestCase):
|
|||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRandomSmallAndLargeComplex128(self):
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest("ROCm does not support BLAS operations for complex types")
|
||||
np.random.seed(42)
|
||||
for batch_dims in [(), (1,), (3,), (2, 2)]:
|
||||
for size in 8, 31, 32:
|
||||
|
@ -169,8 +161,8 @@ class MatrixLogarithmBenchmark(test.Benchmark):
|
|||
shape = shape[-2:]
|
||||
assert shape[0] == shape[1]
|
||||
n = shape[0]
|
||||
matrix = np.ones(shape).astype(np.complex64) / (
|
||||
2.0 * n) + np.diag(np.ones(n).astype(np.complex64))
|
||||
matrix = np.ones(shape).astype(np.complex64) / (2.0 * n) + np.diag(
|
||||
np.ones(n).astype(np.complex64))
|
||||
return variables.Variable(np.tile(matrix, batch_shape + (1, 1)))
|
||||
|
||||
def benchmarkMatrixLogarithmOp(self):
|
||||
|
@ -185,8 +177,7 @@ class MatrixLogarithmBenchmark(test.Benchmark):
|
|||
sess,
|
||||
control_flow_ops.group(logm),
|
||||
min_iters=25,
|
||||
name="matrix_logarithm_cpu_{shape}".format(
|
||||
shape=shape))
|
||||
name="matrix_logarithm_cpu_{shape}".format(shape=shape))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -59,9 +59,6 @@ class SquareRootOpTest(test.TestCase):
|
|||
self._verifySquareRootReal(matrix1)
|
||||
self._verifySquareRootReal(matrix2)
|
||||
self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
# Complex
|
||||
matrix1 = matrix1.astype(np.complex64)
|
||||
matrix2 = matrix2.astype(np.complex64)
|
||||
matrix1 += 1j * matrix1
|
||||
|
|
|
@ -240,10 +240,10 @@ def _GetSelfAdjointEigGradTest(dtype_, shape_, compute_v_):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
|
||||
dtypes_to_test = [
|
||||
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64,
|
||||
dtypes_lib.complex128
|
||||
]
|
||||
for compute_v in True, False:
|
||||
for dtype in dtypes_to_test:
|
||||
for size in 1, 2, 5, 10:
|
||||
|
|
|
@ -125,7 +125,6 @@ cuda_py_tests(
|
|||
srcs = ["spectral_ops_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
|
|
|
@ -370,10 +370,7 @@ class SVDBenchmark(test.Benchmark):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtypes_to_test = [np.float32, np.float64]
|
||||
if not test.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
dtypes_to_test += [np.complex64, np.complex128]
|
||||
dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128]
|
||||
for compute_uv in False, True:
|
||||
for full_matrices in False, True:
|
||||
for dtype in dtypes_to_test:
|
||||
|
@ -392,7 +389,7 @@ if __name__ == "__main__":
|
|||
for compute_uv in False, True:
|
||||
for full_matrices in False, True:
|
||||
dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] *
|
||||
(not compute_uv) * (not test.is_built_with_rocm()))
|
||||
(not compute_uv))
|
||||
for dtype in dtypes:
|
||||
mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)]
|
||||
if not full_matrices or not compute_uv:
|
||||
|
|
|
@ -221,10 +221,9 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtypes_to_test = [np.float16, np.float32, np.float64]
|
||||
if not test_lib.is_built_with_rocm():
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
dtypes_to_test += [np.complex64, np.complex128]
|
||||
dtypes_to_test = [
|
||||
np.float16, np.float32, np.float64, np.complex64, np.complex128
|
||||
]
|
||||
for dtype in dtypes_to_test:
|
||||
for rank_a in 1, 2, 4, 5:
|
||||
for rank_b in 1, 2, 4, 5:
|
||||
|
|
|
@ -562,12 +562,6 @@ class EinsumTest(test.TestCase):
|
|||
self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
|
||||
|
||||
def test_dtypes(self):
|
||||
dtypes = []
|
||||
if test.is_built_with_rocm():
|
||||
# This test triggers the BLAS op calls on the GPU
|
||||
# ROCm does not support BLAS operations for complex types
|
||||
dtypes = [np.float64, np.float32]
|
||||
else:
|
||||
dtypes = [np.float64, np.float32, np.complex64, np.complex128]
|
||||
for dtype in dtypes:
|
||||
self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
|
||||
|
|
|
@ -114,10 +114,10 @@ namespace wrap {
|
|||
__macro(rocblas_zdotc) */ \
|
||||
__macro(rocblas_sscal) \
|
||||
__macro(rocblas_dscal) \
|
||||
/*__macro(rocblas_cscal) \
|
||||
__macro(rocblas_cscal) \
|
||||
__macro(rocblas_csscal) \
|
||||
__macro(rocblas_zscal) \
|
||||
__macro(rocblas_zdscal) */ \
|
||||
__macro(rocblas_zdscal) \
|
||||
__macro(rocblas_saxpy) \
|
||||
__macro(rocblas_daxpy) \
|
||||
/*__macro(rocblas_caxpy) \
|
||||
|
@ -158,9 +158,9 @@ namespace wrap {
|
|||
__macro(rocblas_drotmg) */ \
|
||||
__macro(rocblas_sgemv) \
|
||||
__macro(rocblas_dgemv) \
|
||||
/*__macro(rocblas_cgemv) \
|
||||
__macro(rocblas_cgemv) \
|
||||
__macro(rocblas_zgemv) \
|
||||
__macro(rocblas_sgbmv) \
|
||||
/* __macro(rocblas_sgbmv) \
|
||||
__macro(rocblas_dgbmv) \
|
||||
__macro(rocblas_cgbmv) \
|
||||
__macro(rocblas_zgbmv) \
|
||||
|
@ -231,9 +231,9 @@ namespace wrap {
|
|||
__macro(rocblas_sgemm) \
|
||||
__macro(rocblas_dgemm) \
|
||||
__macro(rocblas_hgemm) \
|
||||
/*__macro(rocblas_cgemm) \
|
||||
__macro(rocblas_cgemm) \
|
||||
__macro(rocblas_zgemm) \
|
||||
__macro(rocblas_ssyrk) \
|
||||
/* __macro(rocblas_ssyrk) \
|
||||
__macro(rocblas_dsyrk) \
|
||||
__macro(rocblas_csyrk) \
|
||||
__macro(rocblas_zsyrk) \
|
||||
|
@@ -285,12 +285,37 @@ STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_hgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_sgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_dgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_strided_batched)
STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_strided_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_cgemm_batched)
// STREAM_EXECUTOR_ROCBLAS_WRAP(rocblas_zgemm_batched)
ROCBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_ROCBLAS_V2_WRAP)

}  // namespace wrap

template <class T>
const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
    const DeviceMemory<T> &a) {
  return reinterpret_cast<
      const typename RocBlasTypeConversionHelper<T>::mapped_type *>(
      GpuMemory(a));
}
template <class T>
const typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
    const T &a) {
  return reinterpret_cast<
      const typename RocBlasTypeConversionHelper<T>::mapped_type *>(&a);
}
template <class T>
typename RocBlasTypeConversionHelper<T>::mapped_type *complex_cast(
    DeviceMemory<T> *a) {
  return reinterpret_cast<
      typename RocBlasTypeConversionHelper<T>::mapped_type *>(
      GpuMemoryMutable(a));
}

static void blas_log(const char *c) {}

static string ToString(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:
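For readers unfamiliar with the pattern, the complex_cast helpers above only work because std::complex<T> and the corresponding rocBLAS complex structs are layout-compatible (two contiguous floats or doubles). Below is a minimal, self-contained sketch of that assumption; the struct name is a hypothetical stand-in for rocblas_float_complex, not the real rocBLAS type:

#include <complex>
#include <cstdio>

// Hypothetical stand-in for rocblas_float_complex: two contiguous floats.
struct FloatComplexLike {
  float x;
  float y;
};

// Mirrors the complex_cast idea: reinterpret a std::complex<float> scalar as
// the BLAS-library complex type without copying it.
const FloatComplexLike *complex_cast_sketch(const std::complex<float> &a) {
  static_assert(sizeof(FloatComplexLike) == sizeof(std::complex<float>),
                "layout compatibility is what makes the cast legal in practice");
  return reinterpret_cast<const FloatComplexLike *>(&a);
}

int main() {
  std::complex<float> alpha(1.5f, -2.0f);
  const FloatComplexLike *p = complex_cast_sketch(alpha);
  std::printf("re=%f im=%f\n", p->x, p->y);  // prints 1.5 and -2.0
  return 0;
}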
@ -451,6 +476,7 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count,
|
|||
bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
|
||||
const DeviceMemory<float> &x, int incx,
|
||||
DeviceMemory<float> *y, int incy) {
|
||||
blas_log("DoBlasAxpy");
|
||||
return DoBlasInternal(wrap::rocblas_saxpy, stream,
|
||||
true /* = pointer_mode_host */, elem_count, &alpha,
|
||||
GpuMemory(x), incx, GpuMemoryMutable(y), incy);
|
||||
|
@ -459,6 +485,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
|
|||
bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
|
||||
const DeviceMemory<double> &x, int incx,
|
||||
DeviceMemory<double> *y, int incy) {
|
||||
blas_log("DoBlasAxpy");
|
||||
return DoBlasInternal(wrap::rocblas_daxpy, stream,
|
||||
true /* = pointer_mode_host */, elem_count, &alpha,
|
||||
GpuMemory(x), incx, GpuMemoryMutable(y), incy);
|
||||
|
@ -518,6 +545,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
|
|||
const DeviceMemory<float> &x, int incx,
|
||||
const DeviceMemory<float> &y, int incy,
|
||||
DeviceMemory<float> *result) {
|
||||
blas_log("DoBlasDot");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count,
|
||||
GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
|
||||
|
@ -527,6 +555,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count,
|
|||
const DeviceMemory<double> &x, int incx,
|
||||
const DeviceMemory<double> &y, int incy,
|
||||
DeviceMemory<double> *result) {
|
||||
blas_log("DoBlasDot");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count,
|
||||
GpuMemory(x), incx, GpuMemory(y), incy, GpuMemoryMutable(result));
|
||||
|
@ -707,6 +736,7 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
|
|||
|
||||
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
|
||||
DeviceMemory<float> *x, int incx) {
|
||||
blas_log("DoBlasScal<float>");
|
||||
return DoBlasInternal(wrap::rocblas_sscal, stream,
|
||||
true /* = pointer_mode_host */, elem_count, &alpha,
|
||||
GpuMemoryMutable(x), incx);
|
||||
|
@ -721,32 +751,32 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
|
|||
|
||||
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
|
||||
DeviceMemory<std::complex<float>> *x, int incx) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
|
||||
<< "for the \"complex<float>\" datatype";
|
||||
return false;
|
||||
return DoBlasInternal(wrap::rocblas_csscal, stream,
|
||||
true /* = pointer_mode_host */, elem_count, &alpha,
|
||||
complex_cast(x), incx);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
|
||||
DeviceMemory<std::complex<double>> *x, int incx) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
|
||||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
return DoBlasInternal(wrap::rocblas_zdscal, stream,
|
||||
true /* = pointer_mode_host */, elem_count, &alpha,
|
||||
complex_cast(x), incx);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
|
||||
std::complex<float> alpha,
|
||||
DeviceMemory<std::complex<float>> *x, int incx) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
|
||||
<< "for the \"complex<float>\" datatype";
|
||||
return false;
|
||||
return DoBlasInternal(wrap::rocblas_cscal, stream,
|
||||
true /* = pointer_mode_host */, elem_count,
|
||||
complex_cast(alpha), complex_cast(x), incx);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count,
|
||||
std::complex<double> alpha,
|
||||
DeviceMemory<std::complex<double>> *x, int incx) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the SCAL operation "
|
||||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
return DoBlasInternal(wrap::rocblas_zscal, stream,
|
||||
true /* = pointer_mode_host */, elem_count,
|
||||
complex_cast(alpha), complex_cast(x), incx);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count,
|
||||
|
@ -893,6 +923,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
uint64 n, float alpha, const DeviceMemory<float> &a,
|
||||
int lda, const DeviceMemory<float> &x, int incx,
|
||||
float beta, DeviceMemory<float> *y, int incy) {
|
||||
blas_log("DoBlasGemv");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
|
||||
|
@ -903,6 +934,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
uint64 n, double alpha, const DeviceMemory<double> &a,
|
||||
int lda, const DeviceMemory<double> &x, int incx,
|
||||
double beta, DeviceMemory<double> *y, int incy) {
|
||||
blas_log("DoBlasGemv");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(trans), m, n, &alpha, GpuMemory(a), lda, GpuMemory(x),
|
||||
|
@ -915,9 +947,11 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
const DeviceMemory<std::complex<float>> &x, int incx,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *y, int incy) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
|
||||
<< "for the \"complex<float>\" datatype";
|
||||
return false;
|
||||
blas_log("DoBlasGemv");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_cgemv, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
|
||||
complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
||||
|
@ -926,9 +960,11 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
|
|||
const DeviceMemory<std::complex<double>> &x, int incx,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *y, int incy) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the GEMV operation "
|
||||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
blas_log("DoBlasGemv\n");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_zgemv, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(trans), m, n, complex_cast(alpha), complex_cast(a), lda,
|
||||
complex_cast(x), incx, complex_cast(beta), complex_cast(y), incy);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
|
||||
|
@ -1481,6 +1517,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
float alpha, const DeviceMemory<Eigen::half> &a,
|
||||
int lda, const DeviceMemory<Eigen::half> &b, int ldb,
|
||||
float beta, DeviceMemory<Eigen::half> *c, int ldc) {
|
||||
blas_log("DoBlasGemm");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
|
@ -1526,6 +1563,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
float alpha, const DeviceMemory<float> &a, int lda,
|
||||
const DeviceMemory<float> &b, int ldb, float beta,
|
||||
DeviceMemory<float> *c, int ldc) {
|
||||
blas_log("DoBlasGemm");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
|
||||
|
@ -1565,6 +1603,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
double alpha, const DeviceMemory<double> &a, int lda,
|
||||
const DeviceMemory<double> &b, int ldb, double beta,
|
||||
DeviceMemory<double> *c, int ldc) {
|
||||
blas_log("DoBlasGemm");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha,
|
||||
|
@ -1578,9 +1617,12 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
||||
std::complex<float> beta,
|
||||
DeviceMemory<std::complex<float>> *c, int ldc) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
|
||||
<< "for the \"complex<float>\" datatype";
|
||||
return false;
|
||||
blas_log("DoBlasGemm");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_cgemm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||
complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
|
||||
complex_cast(beta), complex_cast(c), ldc);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||
|
@ -1590,9 +1632,12 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
||||
std::complex<double> beta,
|
||||
DeviceMemory<std::complex<double>> *c, int ldc) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the GEMM operation "
|
||||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
blas_log("DoBlasGemm");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_zgemm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k,
|
||||
complex_cast(alpha), complex_cast(a), lda, complex_cast(b), ldb,
|
||||
complex_cast(beta), complex_cast(c), ldc);
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemvWithProfiling(
|
||||
|
@@ -1813,6 +1858,56 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
  return false;
}

// This copies from source memory: raw_ptrs[i] to target memory:
// device_memory_ptr at the interval of matrix_byte_size, or vice versa.
// The below algorithm tries to minimize the number of memcpy by consolidating
// neighboring memcpy into a single request
template <typename MAPPED_T>
port::Status ReorganizeMemory(Stream *stream,
                              DeviceMemory<MAPPED_T> *device_memory,
                              const std::vector<MAPPED_T *> &raw_ptrs,
                              int batch_count, uint64_t batch_stride,
                              bool gather) {
  assert(batch_count > 0);
  char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
  char *src_ptr = reinterpret_cast<char *>(raw_ptrs[0]);
  char *dst_ptr = device_memory_ptr;
  size_t matrix_byte_size = batch_stride * sizeof(MAPPED_T);
  uint64_t cur_stride_size = matrix_byte_size;

  for (int i = 1; i < batch_count; ++i) {
    if (reinterpret_cast<char *>(raw_ptrs[i]) == src_ptr + cur_stride_size) {
      cur_stride_size += matrix_byte_size;
    } else {
      DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
      DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
      bool a_status =
          gather
              ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
              : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
      if (!a_status) {
        return port::Status(
            port::error::INTERNAL,
            "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
      }
      src_ptr = reinterpret_cast<char *>(raw_ptrs[i]);
      dst_ptr = device_memory_ptr + i * matrix_byte_size;
      cur_stride_size = matrix_byte_size;
    }
  }

  DeviceMemoryBase src_mem = DeviceMemoryBase(src_ptr, cur_stride_size);
  DeviceMemoryBase target_mem = DeviceMemoryBase(dst_ptr, cur_stride_size);
  bool a_status =
      gather ? stream->ThenMemcpy(&target_mem, src_mem, cur_stride_size).ok()
             : stream->ThenMemcpy(&src_mem, target_mem, cur_stride_size).ok();
  if (!a_status)
    return port::Status(
        port::error::INTERNAL,
        "failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
  return port::Status::OK();
}

template <typename T>
port::Status ROCMBlas::AllocateStridedBuffer(
    const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
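To make the consolidation logic in ReorganizeMemory easier to follow, here is a simplified host-only sketch of the same idea, using plain memcpy in place of stream->ThenMemcpy and illustrative names throughout; it gathers a batch of equally sized matrices into one contiguous buffer while merging adjacent source matrices into a single copy:

#include <cstring>
#include <vector>

// Gather src.size() matrices of matrix_bytes each into dst. Source matrices
// that happen to be contiguous in memory are copied with a single memcpy.
void GatherMatrices(char *dst, const std::vector<char *> &src,
                    size_t matrix_bytes) {
  char *run_start = src[0];         // start of the current contiguous run
  size_t run_bytes = matrix_bytes;  // bytes accumulated in the current run
  char *out = dst;                  // destination of the current run
  for (size_t i = 1; i < src.size(); ++i) {
    if (src[i] == run_start + run_bytes) {
      run_bytes += matrix_bytes;    // extend the run instead of copying now
    } else {
      std::memcpy(out, run_start, run_bytes);  // flush the finished run
      out = dst + i * matrix_bytes;
      run_start = src[i];
      run_bytes = matrix_bytes;
    }
  }
  std::memcpy(out, run_start, run_bytes);      // flush the final run
}

The device-side version in the diff walks the pointers the same way but issues stream memcpys, and with gather == false it runs the copies in the opposite direction to scatter results back to the original buffers.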
@ -1822,7 +1917,8 @@ port::Status ROCMBlas::AllocateStridedBuffer(
|
|||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory) {
|
||||
*device_memory,
|
||||
bool copy_data, bool &reallocated) {
|
||||
assert(device_memory != nullptr);
|
||||
|
||||
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
|
||||
|
@ -1843,6 +1939,7 @@ port::Status ROCMBlas::AllocateStridedBuffer(
|
|||
if (!needs_allocate_strided) {
|
||||
*device_memory = DeviceMemory<MAPPED_T>(
|
||||
DeviceMemoryBase(raw_ptrs[0], matrix_batch_byte_size));
|
||||
reallocated = false;
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1859,19 +1956,11 @@ port::Status ROCMBlas::AllocateStridedBuffer(
|
|||
DeviceMemory<MAPPED_T>(*(*temp_memory)->mutable_device_memory());
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch_count; ++i) {
|
||||
char *device_memory_ptr = static_cast<char *>(device_memory->opaque());
|
||||
DeviceMemoryBase src_mem = DeviceMemoryBase(raw_ptrs[i], matrix_byte_size);
|
||||
DeviceMemoryBase target_mem = DeviceMemoryBase(
|
||||
device_memory_ptr + i * matrix_byte_size, matrix_byte_size);
|
||||
bool a_status =
|
||||
stream->ThenMemcpy(&target_mem, src_mem, matrix_byte_size).ok();
|
||||
if (!a_status) {
|
||||
return port::Status(
|
||||
port::error::INTERNAL,
|
||||
"failed to copy device memory in ROCMBlas::DoBlasGemmBatched");
|
||||
}
|
||||
}
|
||||
reallocated = true;
|
||||
|
||||
if (copy_data)
|
||||
return ReorganizeMemory(stream, device_memory, raw_ptrs, batch_count,
|
||||
batch_stride, true);
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1925,27 +2014,28 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
|
|||
DeviceMemory<MAPPED_T> a;
|
||||
// Make sure the temporary memory is in scope before the function returns
|
||||
std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> a_temp;
|
||||
port::Status a_allocation_status =
|
||||
AllocateStridedBuffer<T>(a_raw_ptrs, batch_count, batch_stride_a,
|
||||
scratch_allocator, stream, &a_temp, &a);
|
||||
bool reallocated_a, reallocated_b, reallocated_c;
|
||||
port::Status a_allocation_status = AllocateStridedBuffer<T>(
|
||||
a_raw_ptrs, batch_count, batch_stride_a, scratch_allocator, stream,
|
||||
&a_temp, &a, true, reallocated_a);
|
||||
if (a_allocation_status != port::Status::OK()) {
|
||||
return a_allocation_status;
|
||||
}
|
||||
|
||||
DeviceMemory<MAPPED_T> b;
|
||||
std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> b_temp;
|
||||
port::Status b_allocation_status =
|
||||
AllocateStridedBuffer<T>(b_raw_ptrs, batch_count, batch_stride_b,
|
||||
scratch_allocator, stream, &b_temp, &b);
|
||||
port::Status b_allocation_status = AllocateStridedBuffer<T>(
|
||||
b_raw_ptrs, batch_count, batch_stride_b, scratch_allocator, stream,
|
||||
&b_temp, &b, true, reallocated_b);
|
||||
if (b_allocation_status != port::Status::OK()) {
|
||||
return b_allocation_status;
|
||||
}
|
||||
|
||||
DeviceMemory<MAPPED_T> c;
|
||||
std::unique_ptr<TemporaryDeviceMemory<MAPPED_T>> c_temp;
|
||||
port::Status c_allocation_status =
|
||||
AllocateStridedBuffer<T>(c_raw_ptrs, batch_count, batch_stride_c,
|
||||
scratch_allocator, stream, &c_temp, &c);
|
||||
port::Status c_allocation_status = AllocateStridedBuffer<T>(
|
||||
c_raw_ptrs, batch_count, batch_stride_c, scratch_allocator, stream,
|
||||
&c_temp, &c, true, reallocated_c); // can disable copy if beta=0
|
||||
if (c_allocation_status != port::Status::OK()) {
|
||||
return c_allocation_status;
|
||||
}
|
||||
|
@ -1953,19 +2043,20 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal(
|
|||
MAPPED_T *alpha_ptr = reinterpret_cast<MAPPED_T *>(&alpha);
|
||||
MAPPED_T *beta_ptr = reinterpret_cast<MAPPED_T *>(&beta);
|
||||
|
||||
bool ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb),
|
||||
m, n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
|
||||
bool ok;
|
||||
ok = DoBlasInternal(rocblas_func, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
n, k, GpuComplex(alpha_ptr), GpuMemory(a), lda,
|
||||
batch_stride_a, GpuMemory(b), ldb, batch_stride_b,
|
||||
GpuComplex(beta_ptr), GpuMemoryMutable(&c), ldc,
|
||||
batch_stride_c, batch_count);
|
||||
|
||||
if (ok) {
|
||||
return port::Status::OK();
|
||||
} else {
|
||||
if (!ok)
|
||||
return port::Status(port::error::INTERNAL,
|
||||
"failed BLAS call, see log for details");
|
||||
}
|
||||
if (reallocated_c)
|
||||
return ReorganizeMemory(stream, &c, c_raw_ptrs, batch_count, batch_stride_c,
|
||||
false);
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmBatched(
|
||||
|
@ -1975,6 +2066,7 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
|
||||
const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
|
||||
int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
const Eigen::half alpha_half(alpha);
|
||||
const Eigen::half beta_half(beta);
|
||||
|
||||
|
@ -1996,6 +2088,7 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
|
||||
const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
|
||||
int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
port::Status status = DoBlasGemmBatchedInternal(
|
||||
wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k,
|
||||
alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
|
||||
|
@ -2013,6 +2106,7 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
|
||||
double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
port::Status status = DoBlasGemmBatchedInternal(
|
||||
wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k,
|
||||
alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
|
||||
|
@ -2032,9 +2126,15 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
int ldb, std::complex<float> beta,
|
||||
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
|
||||
<< "for the \"complex<float>\" datatype";
|
||||
return false;
|
||||
blas_log("DoBlasGemmBatched");
|
||||
port::Status status = DoBlasGemmBatchedInternal(
|
||||
wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
|
||||
alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
return status.ok();
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmBatched(
|
||||
|
@ -2046,9 +2146,15 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
int ldb, std::complex<double> beta,
|
||||
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
LOG(ERROR) << "rocBLAS does not currently support the GEMMBatched operation "
|
||||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
blas_log("DoBlasGemmBatched");
|
||||
port::Status status = DoBlasGemmBatchedInternal(
|
||||
wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
|
||||
alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
return status.ok();
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side,
|
||||
|
@ -2296,6 +2402,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||
blas::Diagonal diag, uint64 m, uint64 n, float alpha,
|
||||
const DeviceMemory<float> &a, int lda,
|
||||
DeviceMemory<float> *b, int ldb) {
|
||||
blas_log("DoBlasTrsm");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_strsm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
||||
|
@ -2308,6 +2415,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||
blas::Diagonal diag, uint64 m, uint64 n, double alpha,
|
||||
const DeviceMemory<double> &a, int lda,
|
||||
DeviceMemory<double> *b, int ldb) {
|
||||
blas_log("DoBlasTrsm");
|
||||
return DoBlasInternal(
|
||||
wrap::rocblas_dtrsm, stream, true /* = pointer_mode_host */,
|
||||
ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa),
|
||||
|
@ -2336,12 +2444,14 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side,
|
|||
<< "for the \"complex<double>\" datatype";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ROCMBlas::DoBlasGemmStridedBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||
uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
|
||||
int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
|
||||
int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
|
||||
int64 stride_c, int batch_count) {
|
||||
blas_log("DoBlasGemmStridedBatched");
|
||||
const Eigen::half alpha_half(alpha);
|
||||
const Eigen::half beta_half(beta);
|
||||
|
||||
|
@ -2363,6 +2473,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||
int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
|
||||
float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
blas_log("DoBlasGemmStridedBatched");
|
||||
return DoBlasInternal(wrap::rocblas_sgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
|
@ -2376,6 +2487,7 @@ bool ROCMBlas::DoBlasGemmStridedBatched(
|
|||
int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
|
||||
double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
|
||||
int batch_count) {
|
||||
blas_log("DoBlasGemmStridedBatched");
|
||||
return DoBlasInternal(wrap::rocblas_dgemm_strided_batched, stream,
|
||||
false, /* pointer_mode_host */
|
||||
ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m,
|
||||
|
|
|
@@ -45,6 +45,16 @@ struct RocBlasTypeConversionHelper<Eigen::half> {
  using mapped_type = rocblas_half;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<float>> {
  using mapped_type = rocblas_float_complex;
};

template <>
struct RocBlasTypeConversionHelper<std::complex<double>> {
  using mapped_type = rocblas_double_complex;
};

// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;

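As a side note, a trait like RocBlasTypeConversionHelper is typically consumed by generic code that needs the library-facing type for a given element type. A minimal illustration, with simplified stand-in types rather than the real rocBLAS definitions:

#include <complex>

// Simplified stand-ins for the rocBLAS complex types (illustrative only).
struct FloatComplexLike { float x, y; };
struct DoubleComplexLike { double x, y; };

// Same shape as RocBlasTypeConversionHelper: map an element type to the type
// the BLAS library's C API expects.
template <class T>
struct TypeConversionHelper {
  using mapped_type = T;
};
template <>
struct TypeConversionHelper<std::complex<float>> {
  using mapped_type = FloatComplexLike;
};
template <>
struct TypeConversionHelper<std::complex<double>> {
  using mapped_type = DoubleComplexLike;
};

// Generic code is then written once against the mapped type.
template <class T>
using mapped_t = typename TypeConversionHelper<T>::mapped_type;

static_assert(sizeof(mapped_t<std::complex<float>>) ==
                  sizeof(std::complex<float>),
              "the cast helpers assume layout compatibility");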
@ -121,7 +131,8 @@ class ROCMBlas : public blas::BlasSupport {
|
|||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>> *temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory);
|
||||
*device_memory,
|
||||
bool copy_data, bool &reallocated);
|
||||
|
||||
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||
// types.
|
||||