[XLA:CPU] [XLA:GPU] Adds compiler support for C64 primitive type, including relevant elementwise unary and binary op lowering for CPU and GPU.
We use a named LLVM struct "complex64", laid out the same as std::complex<float>. This named struct is accessed via the llvm::Module, which required changes to accessors of PrimitiveTypeToIrType & friends. Ops that require atan2 (in particular, angle and log) are only supported on GPU at this point. LLVM lacks a CPU intrinsic for atan or atan2, whereas libdevice provides this for GPU. PiperOrigin-RevId: 173676849
This commit is contained in:
parent
4ae245a7db
commit
4198e27be8
@ -144,8 +144,8 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
|
||||
Node* a = ops::SourceOp(
|
||||
"Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_COMPLEX64)
|
||||
.WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
|
||||
.WithAttr("dtype", DT_COMPLEX128)
|
||||
.WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
|
||||
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
|
||||
TF_EXPECT_OK(builder.ToGraph(graph.get()));
|
||||
|
@ -50,8 +50,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 5> kAllXlaCpuTypes = {
|
||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
||||
constexpr std::array<DataType, 6> kAllXlaCpuTypes = {
|
||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
|
||||
|
@ -55,8 +55,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 5> kAllXlaGpuTypes = {
|
||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
||||
constexpr std::array<DataType, 6> kAllXlaGpuTypes = {
|
||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
|
||||
|
@ -23,6 +23,10 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
|
||||
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
|
||||
generate_backend_suites()
|
||||
|
||||
@ -581,11 +585,12 @@ cc_library(
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "randomized_tests",
|
||||
size = "large",
|
||||
# This test is randomized, so only run it if explicitly requested.
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
] + tf_cuda_tests_tags(),
|
||||
deps = [":randomized_tests_library"],
|
||||
)
|
||||
|
||||
|
@ -46,7 +46,9 @@ class ArgMinMaxTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
def testArgMinMax(self):
|
||||
for dtype in self.numeric_types:
|
||||
# Complex numbers do not support argmin/argmax.
|
||||
minmax_types = set(self.numeric_types) - set(self.complex_types)
|
||||
for dtype in minmax_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
|
||||
np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
|
||||
|
@ -94,6 +94,15 @@ class BinaryOpsTest(XLATestCase):
|
||||
dtype(4),
|
||||
expected=np.array([[16], [81]], dtype=dtype))
|
||||
|
||||
atan2_supported = self.device == "XLA_GPU"
|
||||
if atan2_supported:
|
||||
self._testBinary(
|
||||
math_ops.atan2,
|
||||
np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
|
||||
np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
|
||||
expected=np.array(
|
||||
[0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._reciprocal_grad,
|
||||
np.array([4, -3, -2, 1], dtype=dtype),
|
||||
@ -259,37 +268,38 @@ class BinaryOpsTest(XLATestCase):
|
||||
dtype(7),
|
||||
expected=np.array([[-6], [-5]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([10, 20], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([5, 20], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[10], [7]], dtype=dtype))
|
||||
if dtype not in self.complex_types: # min/max not supported for complex
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([10, 20], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([5, 20], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.maximum,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[10], [7]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
np.array([1, 20], dtype=dtype),
|
||||
np.array([10, 2], dtype=dtype),
|
||||
expected=np.array([1, 2], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([1, 5], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[7], [2]], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
np.array([1, 20], dtype=dtype),
|
||||
np.array([10, 2], dtype=dtype),
|
||||
expected=np.array([1, 2], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
dtype(5),
|
||||
np.array([1, 20], dtype=dtype),
|
||||
expected=np.array([1, 5], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.minimum,
|
||||
np.array([[10], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[7], [2]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.multiply,
|
||||
@ -307,21 +317,23 @@ class BinaryOpsTest(XLATestCase):
|
||||
dtype(7),
|
||||
expected=np.array([[70], [14]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([81, 324], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
dtype(5),
|
||||
np.array([1, 2], dtype=dtype),
|
||||
expected=np.array([16, 9], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
np.array([[1], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[36], [25]], dtype=dtype))
|
||||
# Complex support for squared_difference is incidental, see b/68205550
|
||||
if dtype not in self.complex_types:
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
np.array([1, 2], dtype=dtype),
|
||||
np.array([10, 20], dtype=dtype),
|
||||
expected=np.array([81, 324], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
dtype(5),
|
||||
np.array([1, 2], dtype=dtype),
|
||||
expected=np.array([16, 9], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.squared_difference,
|
||||
np.array([[1], [2]], dtype=dtype),
|
||||
dtype(7),
|
||||
expected=np.array([[36], [25]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
nn_ops.bias_add,
|
||||
@ -334,6 +346,139 @@ class BinaryOpsTest(XLATestCase):
|
||||
np.array([2, -1], dtype=dtype),
|
||||
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
|
||||
|
||||
def testComplexOps(self):
|
||||
for dtype in self.complex_types:
|
||||
ctypes = {np.complex64: np.float32}
|
||||
self._testBinary(
|
||||
math_ops.complex,
|
||||
np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),
|
||||
np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]),
|
||||
expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
|
||||
np.array(
|
||||
[[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]],
|
||||
dtype=dtype),
|
||||
np.array(
|
||||
[[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype),
|
||||
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._real_div,
|
||||
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
|
||||
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
|
||||
expected=np.array(
|
||||
[
|
||||
1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
|
||||
float("inf")
|
||||
],
|
||||
dtype=dtype))
|
||||
|
||||
# TODO(b/65408531): support+test pow for cplx
|
||||
|
||||
lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
|
||||
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
|
||||
self._testBinary(
|
||||
gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs)
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
|
||||
|
||||
# TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs))
|
||||
|
||||
def testComplexMath(self):
|
||||
for dtype in self.complex_types:
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([1 + 3j, 2 + 7j], dtype=dtype),
|
||||
np.array([10 - 4j, 20 + 17j], dtype=dtype),
|
||||
expected=np.array([11 - 1j, 22 + 24j], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
dtype(5 - 7j),
|
||||
np.array([1 + 2j, 2 + 4j], dtype=dtype),
|
||||
expected=np.array([6 - 5j, 7 - 3j], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.add,
|
||||
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
|
||||
dtype(7 + 5j),
|
||||
expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.subtract,
|
||||
np.array([1 + 3j, 2 + 7j], dtype=dtype),
|
||||
np.array([10 - 4j, 20 + 17j], dtype=dtype),
|
||||
expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.subtract,
|
||||
dtype(5 - 7j),
|
||||
np.array([1 + 2j, 2 + 4j], dtype=dtype),
|
||||
expected=np.array([4 - 9j, 3 - 11j], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.subtract,
|
||||
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
|
||||
dtype(7 + 5j),
|
||||
expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.multiply,
|
||||
np.array([1 + 3j, 2 + 7j], dtype=dtype),
|
||||
np.array([10 - 4j, 20 + 17j], dtype=dtype),
|
||||
expected=np.array(
|
||||
[(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.multiply,
|
||||
dtype(5 - 7j),
|
||||
np.array([1 + 2j, 2 + 4j], dtype=dtype),
|
||||
expected=np.array(
|
||||
[(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.multiply,
|
||||
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
|
||||
dtype(7 + 5j),
|
||||
expected=np.array(
|
||||
[[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
math_ops.div,
|
||||
np.array([8 - 1j, 2 + 16j], dtype=dtype),
|
||||
np.array([2 + 4j, 4 - 8j], dtype=dtype),
|
||||
expected=np.array(
|
||||
[(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.div,
|
||||
dtype(1 + 2j),
|
||||
np.array([2 + 4j, 4 - 8j], dtype=dtype),
|
||||
expected=np.array(
|
||||
[(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype))
|
||||
self._testBinary(
|
||||
math_ops.div,
|
||||
np.array([2 + 4j, 4 - 8j], dtype=dtype),
|
||||
dtype(1 + 2j),
|
||||
expected=np.array(
|
||||
[(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype))
|
||||
|
||||
# TODO(b/68205550): math_ops.squared_difference shouldn't be supported.
|
||||
|
||||
self._testBinary(
|
||||
nn_ops.bias_add,
|
||||
np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype),
|
||||
np.array([2 + 6j, -1 - 3j], dtype=dtype),
|
||||
expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype))
|
||||
self._testBinary(
|
||||
nn_ops.bias_add,
|
||||
np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype),
|
||||
np.array([2 + 1j, -1 + 2j], dtype=dtype),
|
||||
expected=np.array(
|
||||
[[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype))
|
||||
|
||||
def _testDivision(self, dtype):
|
||||
"""Test cases for division operators."""
|
||||
self._testBinary(
|
||||
@ -352,18 +497,19 @@ class BinaryOpsTest(XLATestCase):
|
||||
dtype(2),
|
||||
expected=np.array([[5], [2]], dtype=dtype))
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops._floor_div,
|
||||
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, 2, -4], dtype=dtype),
|
||||
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
|
||||
if dtype not in self.complex_types: # floordiv unsupported for complex.
|
||||
self._testBinary(
|
||||
gen_math_ops._floor_div,
|
||||
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
||||
np.array([2, -2, 7, 2, -4], dtype=dtype),
|
||||
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
|
||||
|
||||
def testIntDivision(self):
|
||||
for dtype in self.int_types:
|
||||
self._testDivision(dtype)
|
||||
|
||||
def testFloatDivision(self):
|
||||
for dtype in self.float_types:
|
||||
for dtype in self.float_types + self.complex_types:
|
||||
self._testDivision(dtype)
|
||||
|
||||
def _testRemainder(self, dtype):
|
||||
|
@ -49,11 +49,15 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
|
||||
backend_deps = []
|
||||
backend_data = []
|
||||
if backend == "cpu":
|
||||
backend_args += ["--test_device=XLA_CPU",
|
||||
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
|
||||
backend_args += [
|
||||
"--test_device=XLA_CPU",
|
||||
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
|
||||
]
|
||||
elif backend == "gpu":
|
||||
backend_args += ["--test_device=XLA_GPU",
|
||||
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
|
||||
backend_args += [
|
||||
"--test_device=XLA_GPU",
|
||||
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
|
||||
]
|
||||
backend_tags += ["requires-gpu-sm35"]
|
||||
elif backend in plugins:
|
||||
backend_args += ["--test_device=" + plugins[backend]["device"],
|
||||
|
@ -30,8 +30,6 @@ from tensorflow.python.platform import test
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_TEST_TYPES = [dtypes.float32]
|
||||
|
||||
|
||||
class GatherTest(xla_test.XLATestCase):
|
||||
|
||||
@ -46,7 +44,7 @@ class GatherTest(xla_test.XLATestCase):
|
||||
def testScalar1D(self):
|
||||
with self.test_session() as session, self.test_scope():
|
||||
data = np.array([0, 1, 2, 3, 7, 5])
|
||||
for dtype in _TEST_TYPES:
|
||||
for dtype in self.all_tf_types:
|
||||
for indices in 4, [1, 2, 2, 4, 5]:
|
||||
params_np = self._buildParams(data, dtype)
|
||||
params = array_ops.placeholder(dtype=dtype)
|
||||
@ -60,7 +58,7 @@ class GatherTest(xla_test.XLATestCase):
|
||||
with self.test_session() as session, self.test_scope():
|
||||
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
|
||||
[12, 13, 14]])
|
||||
for dtype in _TEST_TYPES:
|
||||
for dtype in self.all_tf_types:
|
||||
for axis in 0, 1, -1:
|
||||
params_np = self._buildParams(data, dtype)
|
||||
params = array_ops.placeholder(dtype=dtype)
|
||||
@ -74,7 +72,7 @@ class GatherTest(xla_test.XLATestCase):
|
||||
with self.test_session() as session, self.test_scope():
|
||||
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
|
||||
[12, 13, 14]])
|
||||
for dtype in _TEST_TYPES:
|
||||
for dtype in self.all_tf_types:
|
||||
for axis in 0, 1, -1:
|
||||
params_np = self._buildParams(data, dtype)
|
||||
params = array_ops.placeholder(dtype=dtype)
|
||||
@ -94,7 +92,7 @@ class GatherTest(xla_test.XLATestCase):
|
||||
[12, 13, 14]])
|
||||
# The indices must be in bounds for any axis.
|
||||
indices_np = np.array([0, 1, 0, 2])
|
||||
for dtype in _TEST_TYPES:
|
||||
for dtype in self.all_tf_types:
|
||||
for axis in 0, 1, -1:
|
||||
params_np = self._buildParams(data, dtype)
|
||||
params = array_ops.placeholder(dtype=dtype)
|
||||
@ -112,7 +110,7 @@ class GatherTest(xla_test.XLATestCase):
|
||||
"""Check that scalar and empty indices shapes work as well."""
|
||||
shape = (2, 1, 3, 2)
|
||||
for indices_shape in (), (0,), (2, 0), (2, 3):
|
||||
for dtype in _TEST_TYPES:
|
||||
for dtype in self.all_tf_types:
|
||||
for axis in 0, 1, 2, 3, -1, -2:
|
||||
params = self._buildParams(np.random.randn(*shape), dtype)
|
||||
indices = np.random.randint(shape[axis], size=indices_shape)
|
||||
|
@ -68,6 +68,26 @@ class NAryOpsTest(XLATestCase):
|
||||
np.array([42], dtype=np.float32)],
|
||||
expected=np.array([48], dtype=np.float32))
|
||||
|
||||
def testComplex(self):
|
||||
for dtype in self.complex_types:
|
||||
self._testNAry(
|
||||
math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
|
||||
expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))
|
||||
|
||||
self._testNAry(
|
||||
math_ops.add_n, [
|
||||
np.array([1 + 2j, 2 - 3j], dtype=dtype),
|
||||
np.array([10j, 20], dtype=dtype)
|
||||
],
|
||||
expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))
|
||||
self._testNAry(
|
||||
math_ops.add_n, [
|
||||
np.array([-4, 5j], dtype=dtype),
|
||||
np.array([2 + 10j, -2], dtype=dtype),
|
||||
np.array([42j, 3 + 3j], dtype=dtype)
|
||||
],
|
||||
expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))
|
||||
|
||||
@unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
|
||||
def testIdentityN(self):
|
||||
self._testNAryLists(array_ops.identity_n,
|
||||
|
@ -29,6 +29,9 @@ from tensorflow.python.platform import googletest
|
||||
class RandomOpsTest(XLATestCase):
|
||||
"""Test cases for random-number generating operators."""
|
||||
|
||||
def _random_types(self):
|
||||
return set(self.numeric_types) - set(self.complex_types)
|
||||
|
||||
def _testRngIsNotConstant(self, rng, dtype):
|
||||
# Tests that 'rng' does not always return the same value.
|
||||
with self.test_session() as sess:
|
||||
@ -51,7 +54,8 @@ class RandomOpsTest(XLATestCase):
|
||||
def rng(dtype):
|
||||
return random_ops.random_uniform(shape=[2], dtype=dtype,
|
||||
maxval=1000000)
|
||||
for dtype in self.numeric_types:
|
||||
|
||||
for dtype in self._random_types():
|
||||
self._testRngIsNotConstant(rng, dtype)
|
||||
|
||||
def testRandomNormalIsNotConstant(self):
|
||||
@ -63,7 +67,7 @@ class RandomOpsTest(XLATestCase):
|
||||
self._testRngIsNotConstant(rng, dtype)
|
||||
|
||||
def testRandomUniformIsInRange(self):
|
||||
for dtype in self.numeric_types:
|
||||
for dtype in self._random_types():
|
||||
with self.test_session() as sess:
|
||||
with self.test_scope():
|
||||
x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
|
||||
|
@ -75,7 +75,7 @@ namespace {
|
||||
// Command line flags: see main() below.
|
||||
int64 tf_xla_random_seed = 0;
|
||||
int32 tf_xla_test_repetitions = 20;
|
||||
int64 tf_xla_max_tensor_size = 100000LL;
|
||||
int64 tf_xla_max_tensor_size = 10000LL;
|
||||
string* tf_xla_test_device_ptr; // initial value set in main()
|
||||
bool tf_xla_test_use_jit = true;
|
||||
|
||||
@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) {
|
||||
return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
|
||||
}
|
||||
|
||||
constexpr std::array<DataType, 3> kAllXlaTypes = {
|
||||
{DT_INT32, DT_FLOAT, DT_BOOL}};
|
||||
constexpr std::array<DataType, 4> kAllXlaTypes = {
|
||||
{DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}};
|
||||
|
||||
// An OpTestBuilder is a graph builder class that takes as input an operator to
|
||||
// test, its inputs and attributes, and builds a graph that executes the
|
||||
@ -449,6 +449,13 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
|
||||
});
|
||||
break;
|
||||
}
|
||||
case DT_COMPLEX64: {
|
||||
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
|
||||
test::FillFn<complex64>(&tensor, [this, &distribution](int i) {
|
||||
return complex64(distribution(generator()), distribution(generator()));
|
||||
});
|
||||
break;
|
||||
}
|
||||
case DT_INT32: {
|
||||
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
|
||||
test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
|
||||
@ -624,11 +631,47 @@ std::vector<int32> OpTest::AsInt32s(const std::vector<int64>& int64s) {
|
||||
|
||||
// Functions for comparing tensors.
|
||||
|
||||
template <typename T>
|
||||
double Abs(T x) {
|
||||
return std::fabs(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
double Abs<complex64>(complex64 x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IsClose(const T& x, const T& y, double atol, double rtol) {
|
||||
if (std::isnan(x) && std::isnan(y)) return true;
|
||||
if (x == y) return true; // Allow inf == inf.
|
||||
return fabs(x - y) < atol + rtol * fabs(x);
|
||||
return Abs(x - y) < atol + rtol * Abs(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
|
||||
double rtol) {
|
||||
if (std::isnan(x.real()) && std::isnan(y.real())) {
|
||||
if (std::isnan(x.imag()) && std::isnan(y.imag())) {
|
||||
return true;
|
||||
}
|
||||
if (x.imag() == y.imag()) return true; // Allow inf == inf.
|
||||
return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag());
|
||||
} else if (std::isnan(x.imag()) && std::isnan(y.imag())) {
|
||||
if (x.real() == y.real()) return true; // Allow inf == inf.
|
||||
return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real());
|
||||
}
|
||||
if (x == y) return true; // Allow inf == inf.
|
||||
return Abs(x - y) < atol + rtol * Abs(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
string Str(T x) {
|
||||
return strings::StrCat(x);
|
||||
}
|
||||
template <>
|
||||
string Str<complex64>(complex64 x) {
|
||||
return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -639,9 +682,10 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
|
||||
for (int i = 0; i < Tx.size(); ++i) {
|
||||
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
|
||||
return errors::InvalidArgument(strings::StrCat(
|
||||
i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i),
|
||||
". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol,
|
||||
" rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i))));
|
||||
i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
|
||||
Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
|
||||
"atol = ", atol, " rtol = ", rtol,
|
||||
" tol = ", atol + rtol * Abs(Tx(i))));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -683,6 +727,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
|
||||
return TensorsAreCloseImpl<float>(a, b, atol, rtol);
|
||||
case DT_DOUBLE:
|
||||
return TensorsAreCloseImpl<double>(a, b, atol, rtol);
|
||||
case DT_COMPLEX64:
|
||||
return TensorsAreCloseImpl<complex64>(a, b, atol, rtol);
|
||||
case DT_INT32:
|
||||
return TensorsAreEqualImpl<int32>(a, b);
|
||||
case DT_INT64:
|
||||
@ -822,7 +868,7 @@ Tensor AsIntTensor(DataType dtype, const std::vector<int64>& values) {
|
||||
|
||||
TEST_F(OpTest, Abs) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Abs").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -837,7 +883,7 @@ TEST_F(OpTest, Acosh) {
|
||||
|
||||
TEST_F(OpTest, Add) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -848,7 +894,7 @@ TEST_F(OpTest, Add) {
|
||||
|
||||
TEST_F(OpTest, AddN) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
int n = std::uniform_int_distribution<int>(1, 5)(generator());
|
||||
|
||||
auto shape = RandomDims();
|
||||
@ -890,9 +936,10 @@ TEST_F(OpTest, Any) {
|
||||
TEST_F(OpTest, ApproximateEqual) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(type, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
@ -1038,6 +1085,7 @@ TEST_F(OpTest, AvgPool3DGrad) {
|
||||
|
||||
TEST_F(OpTest, BatchMatMul) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
std::vector<int64> output_dims = RandomDims(2, 5, 0, 7);
|
||||
int64 ndims = output_dims.size();
|
||||
int64 inner_dim = RandomDim();
|
||||
@ -1056,9 +1104,9 @@ TEST_F(OpTest, BatchMatMul) {
|
||||
}
|
||||
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
|
||||
.RandomInput(DT_FLOAT, x_dims)
|
||||
.RandomInput(DT_FLOAT, y_dims)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, x_dims)
|
||||
.RandomInput(type, y_dims)
|
||||
.Attr("T", type)
|
||||
.Attr("adj_x", adj_x)
|
||||
.Attr("adj_y", adj_y));
|
||||
});
|
||||
@ -1090,10 +1138,11 @@ TEST_F(OpTest, BatchToSpace) {
|
||||
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
|
||||
TensorShape({num_block_dims, 2})));
|
||||
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
|
||||
.RandomInput(DT_FLOAT, input_dims)
|
||||
.RandomInput(type, input_dims)
|
||||
.Input(crops)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("T", type)
|
||||
.Attr("block_size", block_size));
|
||||
});
|
||||
}
|
||||
@ -1127,13 +1176,14 @@ TEST_F(OpTest, BatchToSpaceND) {
|
||||
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
|
||||
TensorShape({num_block_dims, 2})));
|
||||
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("BatchToSpaceND")
|
||||
.RandomInput(DT_FLOAT, input_dims)
|
||||
.RandomInput(type, input_dims)
|
||||
.Input(test::AsTensor<int32>(
|
||||
std::vector<int32>(block_dims.begin(), block_dims.end())))
|
||||
.Input(crops)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -1142,18 +1192,20 @@ TEST_F(OpTest, BiasAdd) {
|
||||
auto x_dims = RandomDims(2, kDefaultMaxRank);
|
||||
auto y_dims = {x_dims[x_dims.size() - 1]};
|
||||
// TODO(phawkins): test both data formats.
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
|
||||
.RandomInput(DT_FLOAT, x_dims)
|
||||
.RandomInput(DT_FLOAT, y_dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, x_dims)
|
||||
.RandomInput(type, y_dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, BiasAddGrad) {
|
||||
Repeatedly([this]() {
|
||||
// TODO(phawkins): test both data formats.
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("BiasAddGrad").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -1161,10 +1213,11 @@ TEST_F(OpTest, BiasAddV1) {
|
||||
Repeatedly([this]() {
|
||||
auto x_dims = RandomDims(2, kDefaultMaxRank);
|
||||
auto y_dims = {x_dims[x_dims.size() - 1]};
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
|
||||
.RandomInput(DT_FLOAT, x_dims)
|
||||
.RandomInput(DT_FLOAT, y_dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, x_dims)
|
||||
.RandomInput(type, y_dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -1221,8 +1274,8 @@ TEST_F(OpTest, BroadcastGradientArgs) {
|
||||
TEST_F(OpTest, Cast) {
|
||||
Repeatedly([this]() {
|
||||
DataType src_type, dst_type;
|
||||
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
|
||||
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
|
||||
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
|
||||
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
|
||||
.RandomInput(src_type)
|
||||
.Attr("SrcT", src_type)
|
||||
@ -1293,11 +1346,12 @@ TEST_F(OpTest, Conv2D) {
|
||||
|
||||
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
|
||||
features_in, features_out};
|
||||
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Conv2D")
|
||||
.RandomInput(DT_FLOAT, data_dims)
|
||||
.RandomInput(DT_FLOAT, kernel_dims)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, data_dims)
|
||||
.RandomInput(type, kernel_dims)
|
||||
.Attr("T", type)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||
.Attr("data_format", "NHWC"));
|
||||
@ -1317,12 +1371,13 @@ TEST_F(OpTest, Conv2DBackpropFilter) {
|
||||
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
||||
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
|
||||
{d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
|
||||
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Conv2DBackpropFilter")
|
||||
.RandomInput(DT_FLOAT, activations)
|
||||
.RandomInput(type, activations)
|
||||
.Input(kernel_shape)
|
||||
.RandomInput(DT_FLOAT, backprop)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, backprop)
|
||||
.Attr("T", type)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||
.Attr("data_format", "NHWC"));
|
||||
@ -1342,12 +1397,13 @@ TEST_F(OpTest, Conv2DBackpropInput) {
|
||||
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||
features_in, features_out};
|
||||
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Conv2DBackpropInput")
|
||||
.Input(in_shape)
|
||||
.RandomInput(DT_FLOAT, kernel)
|
||||
.RandomInput(DT_FLOAT, backprop)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, kernel)
|
||||
.RandomInput(type, backprop)
|
||||
.Attr("T", type)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||
.Attr("data_format", "NHWC"));
|
||||
@ -1365,11 +1421,12 @@ TEST_F(OpTest, Conv3D) {
|
||||
|
||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||
d.kernel_dims[2], features_in, features_out};
|
||||
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Conv3D")
|
||||
.RandomInput(DT_FLOAT, data)
|
||||
.RandomInput(DT_FLOAT, kernel)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, data)
|
||||
.RandomInput(type, kernel)
|
||||
.Attr("T", type)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||
});
|
||||
@ -1389,12 +1446,13 @@ TEST_F(OpTest, Conv3DBackpropFilter) {
|
||||
Tensor kernel_shape = test::AsTensor<int32>(
|
||||
AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
|
||||
features_in, features_out}));
|
||||
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Conv3DBackpropFilterV2")
|
||||
.RandomInput(DT_FLOAT, activations)
|
||||
.RandomInput(type, activations)
|
||||
.Input(kernel_shape)
|
||||
.RandomInput(DT_FLOAT, backprop)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, backprop)
|
||||
.Attr("T", type)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||
});
|
||||
@ -1413,17 +1471,34 @@ TEST_F(OpTest, Conv3DBackpropInput) {
|
||||
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||
d.kernel_dims[2], features_in, features_out};
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Conv3DBackpropInputV2")
|
||||
.Input(in_shape)
|
||||
.RandomInput(DT_FLOAT, kernel)
|
||||
.RandomInput(DT_FLOAT, backprop)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, kernel)
|
||||
.RandomInput(type, backprop)
|
||||
.Attr("T", type)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Cos) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Cos").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Cosh) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Cosh").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, DepthToSpace) {
|
||||
Repeatedly([this]() {
|
||||
int64 block = RandomDim(2, 5);
|
||||
@ -1431,14 +1506,16 @@ TEST_F(OpTest, DepthToSpace) {
|
||||
input_dims[1] = (input_dims[1] + (block - 1)) / block;
|
||||
input_dims[2] = (input_dims[2] + (block - 1)) / block;
|
||||
input_dims[3] *= block * block;
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace")
|
||||
.RandomInput(DT_FLOAT, input_dims)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, input_dims)
|
||||
.Attr("T", type)
|
||||
.Attr("block_size", block));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, DepthwiseConv2DNative) {
|
||||
if (1) return;
|
||||
Repeatedly([this]() {
|
||||
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
||||
std::uniform_int_distribution<int> random_int(1, 5);
|
||||
@ -1449,17 +1526,20 @@ TEST_F(OpTest, DepthwiseConv2DNative) {
|
||||
|
||||
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
|
||||
features_in, depth_multiplier};
|
||||
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
|
||||
strides[2] = strides[1]; // Current impl only supports equal strides
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("DepthwiseConv2dNative")
|
||||
.RandomInput(DT_FLOAT, input_dims)
|
||||
.RandomInput(DT_FLOAT, kernel_dims)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("strides", strides)
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
|
||||
if (1) return;
|
||||
Repeatedly([this]() {
|
||||
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
||||
std::uniform_int_distribution<int> random_int(1, 5);
|
||||
@ -1472,33 +1552,22 @@ TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
|
||||
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
|
||||
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
|
||||
{d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
|
||||
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
|
||||
strides[2] = strides[1]; // Current impl only supports equal strides
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
|
||||
.RandomInput(DT_FLOAT, activations)
|
||||
.Input(kernel_shape)
|
||||
.RandomInput(DT_FLOAT, backprop)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("strides", strides)
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||
.Attr("data_format", "NHWC"));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Cos) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Cos").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Cosh) {
|
||||
Repeatedly([this]() {
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Cosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
|
||||
if (1) return;
|
||||
Repeatedly([this]() {
|
||||
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
||||
std::uniform_int_distribution<int> random_int(1, 5);
|
||||
@ -1511,21 +1580,24 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
|
||||
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
|
||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||
features_in, depth_multiplier};
|
||||
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
|
||||
strides[2] = strides[1]; // Current impl only supports equal strides
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
|
||||
.Input(in_shape)
|
||||
.RandomInput(DT_FLOAT, kernel)
|
||||
.RandomInput(DT_FLOAT, backprop)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||
.Attr("strides", strides)
|
||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||
.Attr("data_format", "NHWC"));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Diag) {
|
||||
if (1) return;
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
std::vector<int64> dims;
|
||||
// Diag causes a quadratic blowup in output size.
|
||||
int64 size;
|
||||
@ -1540,7 +1612,7 @@ TEST_F(OpTest, Diag) {
|
||||
|
||||
TEST_F(OpTest, DiagPart) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
auto dims = RandomDims(1, 3);
|
||||
// Duplicate the random dims.
|
||||
std::vector<int64> doubled_dims(dims.size() * 2);
|
||||
@ -1554,7 +1626,7 @@ TEST_F(OpTest, DiagPart) {
|
||||
|
||||
TEST_F(OpTest, Div) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -1650,7 +1722,7 @@ TEST_F(OpTest, SeluGrad) {
|
||||
|
||||
TEST_F(OpTest, Equal) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -1661,15 +1733,17 @@ TEST_F(OpTest, Equal) {
|
||||
|
||||
TEST_F(OpTest, Exp) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Exp").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Expm1) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Expm1").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Expm1").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -1809,15 +1883,17 @@ TEST_F(OpTest, LinSpace) {
|
||||
|
||||
TEST_F(OpTest, Log) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Log").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Log1p) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Log1p").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Log1p").RandomInput(type).Attr("T", DT_FLOAT));
|
||||
});
|
||||
}
|
||||
|
||||
@ -1914,10 +1990,11 @@ TEST_F(OpTest, MatMul) {
|
||||
std::swap(b_dims[0], b_dims[1]);
|
||||
}
|
||||
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
|
||||
.RandomInput(DT_FLOAT, a_dims)
|
||||
.RandomInput(DT_FLOAT, b_dims)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.RandomInput(type, a_dims)
|
||||
.RandomInput(type, b_dims)
|
||||
.Attr("T", type)
|
||||
.Attr("transpose_a", transpose_a)
|
||||
.Attr("transpose_b", transpose_b));
|
||||
});
|
||||
@ -1925,7 +2002,7 @@ TEST_F(OpTest, MatMul) {
|
||||
|
||||
TEST_F(OpTest, MatrixDiag) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
|
||||
.RandomInput(type, RandomDims(1))
|
||||
.Attr("T", type));
|
||||
@ -1934,7 +2011,7 @@ TEST_F(OpTest, MatrixDiag) {
|
||||
|
||||
TEST_F(OpTest, MatrixDiagPart) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
|
||||
.RandomInput(type, RandomDims(2))
|
||||
.Attr("T", type));
|
||||
@ -2025,7 +2102,7 @@ TEST_F(OpTest, MaxPool3D) {
|
||||
|
||||
TEST_F(OpTest, Mean) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
// TODO(phawkins): CPU and XLA differ output for reducing across a
|
||||
// size-0 dimension (nan vs 0). For now, require size >= 1.
|
||||
std::vector<int64> data_dims = RandomDims(0, kDefaultMaxRank, 1);
|
||||
@ -2076,7 +2153,7 @@ TEST_F(OpTest, Mod) {
|
||||
|
||||
TEST_F(OpTest, Mul) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -2087,7 +2164,7 @@ TEST_F(OpTest, Mul) {
|
||||
|
||||
TEST_F(OpTest, Neg) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Neg").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -2095,7 +2172,7 @@ TEST_F(OpTest, Neg) {
|
||||
|
||||
TEST_F(OpTest, NotEqual) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -2136,7 +2213,7 @@ TEST_F(OpTest, OneHot) {
|
||||
|
||||
TEST_F(OpTest, OnesLike) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -2195,16 +2272,17 @@ TEST_F(OpTest, Pow) {
|
||||
// nontermination.
|
||||
Repeatedly([this]() {
|
||||
auto dims = BroadcastableDims();
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow")
|
||||
.RandomInput(DT_FLOAT, dims.first)
|
||||
.RandomInput(DT_FLOAT, dims.second)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, dims.first)
|
||||
.RandomInput(type, dims.second)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Prod) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
std::vector<int64> data_dims = RandomDims();
|
||||
Tensor indices = RandomReductionIndices(data_dims.size());
|
||||
bool keep_dims = Choose<bool>({false, true});
|
||||
@ -2238,7 +2316,7 @@ TEST_F(OpTest, Range) {
|
||||
|
||||
TEST_F(OpTest, Rank) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Rank").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -2246,7 +2324,7 @@ TEST_F(OpTest, Rank) {
|
||||
|
||||
TEST_F(OpTest, RealDiv) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = DT_FLOAT;
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -2257,18 +2335,20 @@ TEST_F(OpTest, RealDiv) {
|
||||
|
||||
TEST_F(OpTest, Reciprocal) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Reciprocal").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, ReciprocalGrad) {
|
||||
Repeatedly([this]() {
|
||||
std::vector<int64> dims = RandomDims();
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(type, dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
TEST_F(OpTest, Relu) {
|
||||
@ -2335,24 +2415,24 @@ TEST_F(OpTest, Reshape) {
|
||||
TEST_F(OpTest, Reverse) {
|
||||
Repeatedly([this]() {
|
||||
std::vector<int64> dims = RandomDims(1);
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
int64 rank = dims.size();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(DT_BOOL, {rank})
|
||||
.Attr("T", DT_FLOAT));
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, ReverseV2) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
std::vector<int64> data_dims = RandomDims();
|
||||
Tensor indices = RandomReductionIndices(data_dims.size());
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
|
||||
.RandomInput(type, data_dims)
|
||||
.Input(indices)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -2372,18 +2452,20 @@ TEST_F(OpTest, Round) {
|
||||
|
||||
TEST_F(OpTest, Rsqrt) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Rsqrt").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, RsqrtGrad) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(type, dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -2411,24 +2493,26 @@ TEST_F(OpTest, ShapeN) {
|
||||
|
||||
TEST_F(OpTest, Sigmoid) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Sigmoid").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, SigmoidGrad) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(type, dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Sign) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Sign").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -2436,21 +2520,23 @@ TEST_F(OpTest, Sign) {
|
||||
|
||||
TEST_F(OpTest, Sin) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Sin").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Sin").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Sinh) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Sinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Sinh").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Size) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Size").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -2562,10 +2648,11 @@ TEST_F(OpTest, SpaceToBatch) {
|
||||
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
|
||||
TensorShape({num_block_dims, 2})));
|
||||
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
|
||||
.RandomInput(DT_FLOAT, input_dims)
|
||||
.RandomInput(type, input_dims)
|
||||
.Input(paddings)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("T", type)
|
||||
.Attr("block_size", block_size));
|
||||
});
|
||||
}
|
||||
@ -2603,13 +2690,14 @@ TEST_F(OpTest, SpaceToBatchND) {
|
||||
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
|
||||
TensorShape({num_block_dims, 2})));
|
||||
|
||||
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("SpaceToBatchND")
|
||||
.RandomInput(DT_FLOAT, input_dims)
|
||||
.RandomInput(type, input_dims)
|
||||
.Input(test::AsTensor<int32>(
|
||||
std::vector<int32>(block_dims.begin(), block_dims.end())))
|
||||
.Input(paddings)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -2699,18 +2787,20 @@ TEST_F(OpTest, Split) {
|
||||
|
||||
TEST_F(OpTest, Sqrt) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Sqrt").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, SqrtGrad) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(type, dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -2726,7 +2816,7 @@ TEST_F(OpTest, SquaredDifference) {
|
||||
|
||||
TEST_F(OpTest, Square) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Square").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
@ -2752,7 +2842,7 @@ TEST_F(OpTest, Squeeze) {
|
||||
|
||||
TEST_F(OpTest, Sub) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
auto dims = BroadcastableDims();
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
|
||||
.RandomInput(type, dims.first)
|
||||
@ -2763,7 +2853,7 @@ TEST_F(OpTest, Sub) {
|
||||
|
||||
TEST_F(OpTest, Sum) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
std::vector<int64> data_dims = RandomDims();
|
||||
Tensor indices = RandomReductionIndices(data_dims.size());
|
||||
bool keep_dims = Choose<bool>({false, true});
|
||||
@ -2875,25 +2965,28 @@ TEST_F(OpTest, StridedSliceGrad) {
|
||||
|
||||
TEST_F(OpTest, Tan) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Tan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Tan").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, Tanh) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
||||
OpTestBuilder("Tanh").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(OpTest, TanhGrad) {
|
||||
Repeatedly([this]() {
|
||||
auto dims = RandomDims();
|
||||
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.RandomInput(DT_FLOAT, dims)
|
||||
.Attr("T", DT_FLOAT));
|
||||
.RandomInput(type, dims)
|
||||
.RandomInput(type, dims)
|
||||
.Attr("T", type));
|
||||
});
|
||||
}
|
||||
|
||||
@ -2951,7 +3044,7 @@ TEST_F(OpTest, TruncateMod) {
|
||||
|
||||
TEST_F(OpTest, ZerosLike) {
|
||||
Repeatedly([this]() {
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type));
|
||||
});
|
||||
|
@ -328,6 +328,131 @@ class UnaryOpsTest(XLATestCase):
|
||||
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
|
||||
expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
|
||||
|
||||
def testComplexOps(self):
|
||||
for dtype in self.complex_types:
|
||||
# TODO(b/65408531): math_ops.acosh (needs pow)
|
||||
# TODO(b/65408531): math_ops.asinh (needs pow)
|
||||
|
||||
# TODO(b/65408531): Wider support for log (needs atan2).
|
||||
atan2_supported = self.device == "XLA_GPU"
|
||||
if atan2_supported:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.atanh,
|
||||
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
|
||||
expected=np.arctanh(
|
||||
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.cosh,
|
||||
np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype),
|
||||
expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.sinh,
|
||||
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
|
||||
expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.exp,
|
||||
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
|
||||
expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.expm1,
|
||||
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
|
||||
expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.reciprocal,
|
||||
np.array([[1, 2j, 2 + 3j]], dtype=dtype),
|
||||
expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype))
|
||||
|
||||
if atan2_supported:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.log,
|
||||
np.array([[5j, 3 - 2j]], dtype=dtype),
|
||||
expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.sin,
|
||||
np.array([[5j, 3 - 2j]], dtype=dtype),
|
||||
expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.cos,
|
||||
np.array([[5j, 3 - 2j]], dtype=dtype),
|
||||
expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype)))
|
||||
|
||||
# TODO(b/34703906): improve log1p implementation and make tolerance
|
||||
# tighter.
|
||||
if atan2_supported: # TODO(b/34703906): log support
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.log1p,
|
||||
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
|
||||
expected=np.log1p(
|
||||
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
|
||||
|
||||
# TODO(b/34703906): math_ops.rsqrt (needs pow)
|
||||
|
||||
# TODO(b/34703906): math_ops.sigmoid (needs tanh)
|
||||
|
||||
# TODO(b/34703906): math_ops.sqrt (needs pow)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.tan,
|
||||
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
|
||||
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
|
||||
|
||||
# TODO(b/34703906): math_ops.tanh (as itself)
|
||||
|
||||
ctypes = {np.complex64: np.float32}
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.abs,
|
||||
np.array([[3 - 4j, -1j, np.inf]], dtype=dtype),
|
||||
expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype]))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.negative,
|
||||
np.array([[-1 + 2j, -3j]], dtype=dtype),
|
||||
expected=np.array([[1 - 2j, 3j]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.square,
|
||||
np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype),
|
||||
expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.zeros_like,
|
||||
np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype),
|
||||
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
array_ops.ones_like,
|
||||
np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype),
|
||||
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
|
||||
|
||||
if atan2_supported: # TODO(b/34703906): atan2 support
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.angle,
|
||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||
expected=np.angle(
|
||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype)))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.conj,
|
||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||
expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.imag,
|
||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||
expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype]))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.real,
|
||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
|
||||
|
||||
def testIntOps(self):
|
||||
for dtype in self.int_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
@ -399,11 +524,14 @@ class UnaryOpsTest(XLATestCase):
|
||||
|
||||
def testCast(self):
|
||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||
types = [dtypes.bool, dtypes.int32, dtypes.float32]
|
||||
types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types
|
||||
for shape in shapes:
|
||||
for src_type in types:
|
||||
for dst_type in types:
|
||||
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
|
||||
if src_type in self.complex_tf_types:
|
||||
src += (np.arange(np.prod(shape)) * 2j).astype(
|
||||
src_type.as_numpy_dtype)
|
||||
src = src.reshape(shape)
|
||||
|
||||
dst = src.astype(dst_type.as_numpy_dtype)
|
||||
|
@ -43,7 +43,7 @@ class VariableOpsTest(XLATestCase):
|
||||
# Regression test for a bug where computations with one non-constant
|
||||
# output and one variable update were mishandled.
|
||||
for dtype in self.numeric_types:
|
||||
init = np.array([[1, 2], [3, 4]], dtype=dtype)
|
||||
init = np.array([[1, 2j], [3, 4]]).astype(dtype)
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(init)
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
@ -51,82 +51,91 @@ class VariableOpsTest(XLATestCase):
|
||||
x = v.assign_add(p)
|
||||
with ops.control_dependencies([x]):
|
||||
y = v.read_value()
|
||||
self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype),
|
||||
sess.run(y, {p: 1}))
|
||||
self.assertAllClose(
|
||||
np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {
|
||||
p: 1
|
||||
}))
|
||||
|
||||
def testSparseRead0DIndices(self):
|
||||
for dtype in self.numeric_types:
|
||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
|
||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
|
||||
11]]).astype(dtype)
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(init)
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
x = v.sparse_read(2)
|
||||
self.assertAllClose(np.array([8, 9, 10, 11], dtype=dtype), sess.run(x))
|
||||
self.assertAllClose(
|
||||
np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
|
||||
|
||||
def testSparseRead1DIndices(self):
|
||||
for dtype in self.numeric_types:
|
||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
|
||||
init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
|
||||
11]]).astype(dtype)
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(init)
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
x = v.sparse_read([2, 1])
|
||||
self.assertAllClose(
|
||||
np.array([[8, 9, 10, 11], [4, 5, 6, 7]], dtype=dtype), sess.run(x))
|
||||
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
|
||||
sess.run(x))
|
||||
|
||||
def testSparseRead2DIndices(self):
|
||||
for dtype in self.numeric_types:
|
||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
|
||||
init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
|
||||
11]]).astype(dtype)
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(init)
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
x = v.sparse_read([[2, 1], [0, 2]])
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2, 3], [8, 9, 10,
|
||||
11]]],
|
||||
dtype=dtype), sess.run(x))
|
||||
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
|
||||
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
|
||||
sess.run(x))
|
||||
|
||||
def testSparseRead2DIndices3DTensor(self):
|
||||
for dtype in self.numeric_types:
|
||||
init = np.array(
|
||||
[[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
|
||||
[[20, 21, 22], [23, 24, 25]], [[30, 31, 32], [33, 34, 35]]],
|
||||
dtype=dtype)
|
||||
init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
|
||||
[[20, 21, 22], [23, 24j, 25]],
|
||||
[[30, 31, 32], [33, 34, 35]]]).astype(dtype)
|
||||
with self.test_session() as sess, self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(init)
|
||||
sess.run(variables.variables_initializer([v]))
|
||||
x = v.sparse_read([[2, 1], [3, 0]])
|
||||
self.assertAllClose(
|
||||
np.array(
|
||||
[[[[20, 21, 22], [23, 24, 25]], [[10, 11, 12], [13, 14, 15]]],
|
||||
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
|
||||
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]],
|
||||
dtype=dtype), sess.run(x))
|
||||
).astype(dtype), sess.run(x))
|
||||
|
||||
def testReadWrite(self):
|
||||
"""Tests initialization, reading, and writing a resource variable."""
|
||||
with self.test_session() as session:
|
||||
with self.test_scope():
|
||||
with variable_scope.variable_scope("ascope", use_resource=True):
|
||||
x = variable_scope.get_variable(
|
||||
"x",
|
||||
shape=[],
|
||||
dtype=dtypes.float32,
|
||||
initializer=init_ops.constant_initializer(2))
|
||||
a = x.read_value()
|
||||
with ops.control_dependencies([a]):
|
||||
b = state_ops.assign(x, 47)
|
||||
with ops.control_dependencies([b]):
|
||||
c = x.read_value()
|
||||
with ops.control_dependencies([c]):
|
||||
d = state_ops.assign_add(x, 3)
|
||||
with ops.control_dependencies([d]):
|
||||
e = x.read_value()
|
||||
for dtype in self.numeric_types:
|
||||
with self.test_session() as session:
|
||||
print(ops.get_default_graph())
|
||||
with self.test_scope():
|
||||
with variable_scope.variable_scope("ascope", use_resource=True):
|
||||
x = variable_scope.get_variable(
|
||||
"x",
|
||||
shape=[],
|
||||
dtype=dtype,
|
||||
initializer=init_ops.constant_initializer(2))
|
||||
a = x.read_value()
|
||||
with ops.control_dependencies([a]):
|
||||
b = state_ops.assign(x, dtype(47))
|
||||
with ops.control_dependencies([b]):
|
||||
c = x.read_value()
|
||||
with ops.control_dependencies([c]):
|
||||
d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype))
|
||||
with ops.control_dependencies([d]):
|
||||
e = state_ops.assign_sub(x, dtype(3))
|
||||
with ops.control_dependencies([e]):
|
||||
f = x.read_value()
|
||||
|
||||
session.run(variables.global_variables_initializer())
|
||||
v1, v2, v3 = session.run([a, c, e])
|
||||
self.assertAllClose(2.0, v1)
|
||||
self.assertAllClose(47.0, v2)
|
||||
self.assertAllClose(50.0, v3)
|
||||
session.run(variables.global_variables_initializer())
|
||||
v1, v2, v3 = session.run([a, c, f])
|
||||
self.assertAllClose(dtype(2), v1)
|
||||
self.assertAllClose(dtype(47), v2)
|
||||
self.assertAllClose(np.array(50 + 2j).astype(dtype), v3)
|
||||
|
||||
def testTraining(self):
|
||||
"""Tests a gradient descent step for a simple model."""
|
||||
|
@ -63,12 +63,19 @@ class XLATestCase(test.TestCase):
|
||||
self.float_tf_types = [
|
||||
dtype for dtype in self.all_tf_types if dtype.is_floating
|
||||
]
|
||||
self.numeric_tf_types = self.int_tf_types + self.float_tf_types
|
||||
self.complex_tf_types = [
|
||||
dtype for dtype in self.all_tf_types if dtype.is_complex
|
||||
]
|
||||
self.numeric_tf_types = (
|
||||
self.int_tf_types + self.float_tf_types + self.complex_tf_types)
|
||||
|
||||
self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
|
||||
self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
|
||||
self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
|
||||
self.numeric_types = self.int_types + self.float_types
|
||||
self.complex_types = [
|
||||
dtype.as_numpy_dtype for dtype in self.complex_tf_types
|
||||
]
|
||||
self.numeric_types = self.int_types + self.float_types + self.complex_types
|
||||
|
||||
# Parse the manifest file, if any, into a regex identifying tests to
|
||||
# disable
|
||||
|
@ -77,7 +77,13 @@ class BatchMatMulOp : public XlaOpKernel {
|
||||
xla::ComputationBuilder* builder = ctx->builder();
|
||||
|
||||
xla::ComputationDataHandle x_handle = ctx->Input(0);
|
||||
if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) {
|
||||
x_handle = builder->Conj(x_handle);
|
||||
}
|
||||
xla::ComputationDataHandle y_handle = ctx->Input(1);
|
||||
if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) {
|
||||
y_handle = builder->Conj(y_handle);
|
||||
}
|
||||
|
||||
// Reshape input tensors into 3D tensors by flattening the batch
|
||||
// dimensions. This makes it easier to unroll the batch dimension.
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Native XLA implementations of simple unary Ops
|
||||
// Native XLA implementations of simple binary Ops
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -50,6 +51,9 @@ XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
|
||||
|
||||
XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions));
|
||||
XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
|
||||
|
||||
// Implementation of FloorDiv. Pseudo-code:
|
||||
// if ((x < 0) != (y < 0)) {
|
||||
// T abs_x = std::abs(x);
|
||||
@ -171,8 +175,12 @@ class ApproximateEqualOp : public XlaOpKernel {
|
||||
// Computes the max of the scalar input x and 0.
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
|
||||
XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
|
||||
auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
|
||||
auto abs_shape = b->GetShape(abs);
|
||||
OP_REQUIRES_OK(ctx, abs_shape.status());
|
||||
auto abs_type = abs_shape.ValueOrDie()->element_type();
|
||||
auto result = b->Lt(
|
||||
abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type));
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -40,6 +41,11 @@ class CastOp : public XlaOpKernel {
|
||||
output = input;
|
||||
} else if (dst_dtype_ == DT_BOOL) {
|
||||
output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
|
||||
} else if (xla::primitive_util::IsComplexType(src_type_) &&
|
||||
!xla::primitive_util::IsComplexType(dst_type_)) {
|
||||
// As in cast_op.h, we replicate the numpy behavior of truncating the
|
||||
// imaginary part.
|
||||
output = builder->ConvertElementType(builder->Real(input), dst_type_);
|
||||
} else {
|
||||
output = builder->ConvertElementType(input, dst_type_);
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) {
|
||||
errors::InvalidArgument("indices must be int32 or int64"));
|
||||
|
||||
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
|
||||
context, input, input_shape, indices, indices_shape, axis, DT_FLOAT,
|
||||
context, input, input_shape, indices, indices_shape, axis, input_type(0),
|
||||
index_type, builder);
|
||||
context->SetOutput(0, gather);
|
||||
}
|
||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr std::array<DataType, 4> kMatmulTypes = {
|
||||
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
|
||||
|
||||
class MatMulOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false)
|
||||
@ -73,7 +76,7 @@ class MatMulOp : public XlaOpKernel {
|
||||
bool transpose_b_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kFloatTypes), MatMulOp);
|
||||
REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp);
|
||||
|
||||
class SparseMatMulOp : public MatMulOp {
|
||||
public:
|
||||
|
@ -37,8 +37,9 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyGradientDescent"),
|
||||
ResourceApplyGradientDescent);
|
||||
REGISTER_XLA_OP(
|
||||
Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyGradientDescent);
|
||||
|
||||
class ResourceApplyMomentum : public XlaOpKernel {
|
||||
public:
|
||||
@ -109,7 +110,8 @@ class ResourceApplyMomentum : public XlaOpKernel {
|
||||
private:
|
||||
bool use_nesterov_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyMomentum"), ResourceApplyMomentum);
|
||||
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyMomentum);
|
||||
|
||||
class ResourceApplyAdagrad : public XlaOpKernel {
|
||||
public:
|
||||
@ -163,7 +165,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad);
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyAdagrad);
|
||||
|
||||
class ResourceApplyAdam : public XlaOpKernel {
|
||||
public:
|
||||
@ -263,7 +266,8 @@ class ResourceApplyAdam : public XlaOpKernel {
|
||||
private:
|
||||
DataType dtype_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam);
|
||||
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyAdam);
|
||||
|
||||
class ResourceApplyRMSProp : public XlaOpKernel {
|
||||
public:
|
||||
@ -362,7 +366,8 @@ class ResourceApplyRMSProp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
|
||||
}
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp);
|
||||
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyRMSProp);
|
||||
|
||||
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
|
||||
bool has_l2_shrinkage) {
|
||||
@ -500,7 +505,8 @@ class ResourceApplyFtrl : public XlaOpKernel {
|
||||
private:
|
||||
DataType dtype_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl);
|
||||
REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyFtrl);
|
||||
|
||||
class ResourceApplyFtrlV2 : public XlaOpKernel {
|
||||
public:
|
||||
@ -515,7 +521,8 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
|
||||
private:
|
||||
DataType dtype_;
|
||||
};
|
||||
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2);
|
||||
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
|
||||
ResourceApplyFtrlV2);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -41,6 +41,12 @@ namespace {
|
||||
}; \
|
||||
REGISTER_XLA_OP(Name(#NAME), NAME##Op);
|
||||
|
||||
XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x));
|
||||
|
||||
XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x)));
|
||||
|
||||
XLAJIT_MAKE_UNARY(Conj, b->Complex(b->Real(x), b->Neg(b->Imag(x))));
|
||||
|
||||
// Return x if x>0, otherwise -x.
|
||||
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
|
||||
|
||||
@ -162,6 +168,9 @@ XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
|
||||
XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
|
||||
XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
|
||||
|
||||
XLAJIT_MAKE_UNARY(Real, b->Real(x));
|
||||
XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
|
||||
|
||||
#undef XLAJIT_MAKE_UNARY
|
||||
|
||||
} // namespace
|
||||
|
@ -97,6 +97,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
|
||||
case xla::F64:
|
||||
literal = *xla::Literal::CreateR0<double>(value);
|
||||
break;
|
||||
case xla::C64:
|
||||
literal = *xla::Literal::CreateR0<complex64>(value);
|
||||
break;
|
||||
case xla::PRED:
|
||||
LOG(FATAL) << "pred element type is not integral";
|
||||
case xla::S16:
|
||||
@ -132,6 +135,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
|
||||
case xla::F64:
|
||||
return b->ConstantR0<double>(value);
|
||||
break;
|
||||
case xla::C64:
|
||||
return b->ConstantR0<complex64>(value);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled element type " << type;
|
||||
}
|
||||
|
@ -47,14 +47,17 @@ extern const char* const DEVICE_XLA_GPU;
|
||||
|
||||
constexpr std::array<DataType, 3> kFloatTypes = {
|
||||
{DT_HALF, DT_FLOAT, DT_DOUBLE}};
|
||||
constexpr std::array<DataType, 7> kNumericTypes = {
|
||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}};
|
||||
constexpr std::array<DataType, 8> kNumericTypes = {
|
||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
|
||||
DT_COMPLEX64}};
|
||||
|
||||
constexpr std::array<DataType, 7> kCpuAllTypes = {
|
||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
||||
constexpr std::array<DataType, 8> kCpuAllTypes = {
|
||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
|
||||
DT_COMPLEX64, DT_BOOL}};
|
||||
|
||||
constexpr std::array<DataType, 7> kGpuAllTypes = {
|
||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
||||
constexpr std::array<DataType, 8> kGpuAllTypes = {
|
||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
|
||||
DT_COMPLEX64, DT_BOOL}};
|
||||
|
||||
// Class that manages registrations of operators and devices for the XLA JIT.
|
||||
// Not thread-safe.
|
||||
|
@ -913,6 +913,17 @@ ComputationDataHandle ComputationBuilder::CustomCall(
|
||||
return ParseOpResponse(s, &response);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Complex(
|
||||
const ComputationDataHandle& real, const ComputationDataHandle& imag,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Conj(
|
||||
const ComputationDataHandle& operand) {
|
||||
return Complex(Real(operand), Neg(Imag(operand)));
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Add(
|
||||
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
@ -995,6 +1006,12 @@ ComputationDataHandle ComputationBuilder::Abs(
|
||||
return UnaryOp(UNOP_ABS, operand);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Atan2(
|
||||
const ComputationDataHandle& y, const ComputationDataHandle& x,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||
return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Exp(
|
||||
const ComputationDataHandle& operand) {
|
||||
return UnaryOp(UNOP_EXP, operand);
|
||||
@ -1040,6 +1057,16 @@ ComputationDataHandle ComputationBuilder::Tanh(
|
||||
return UnaryOp(UNOP_TANH, operand);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Real(
|
||||
const ComputationDataHandle& operand) {
|
||||
return UnaryOp(UNOP_REAL, operand);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Imag(
|
||||
const ComputationDataHandle& operand) {
|
||||
return UnaryOp(UNOP_IMAG, operand);
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::IsFinite(
|
||||
const ComputationDataHandle& operand) {
|
||||
return UnaryOp(UNOP_IS_FINITE, operand);
|
||||
|
@ -431,6 +431,14 @@ class ComputationBuilder {
|
||||
// of the operands is a scalar, or an explicit broadcast dimension is given
|
||||
// (see g3doc for more details).
|
||||
|
||||
// Enqueues a complex compose instruction onto the computation.
|
||||
ComputationDataHandle Complex(
|
||||
const ComputationDataHandle& real, const ComputationDataHandle& imag,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a complex conjugate instruction onto the computation.
|
||||
ComputationDataHandle Conj(const ComputationDataHandle& operand);
|
||||
|
||||
// Enqueues an add instruction onto the computation.
|
||||
ComputationDataHandle Add(
|
||||
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
||||
@ -542,6 +550,11 @@ class ComputationBuilder {
|
||||
// Enqueues an abs instruction onto the computation.
|
||||
ComputationDataHandle Abs(const ComputationDataHandle& operand);
|
||||
|
||||
// Enqueues a atan2 instruction onto the computation.
|
||||
ComputationDataHandle Atan2(
|
||||
const ComputationDataHandle& y, const ComputationDataHandle& x,
|
||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues an exp instruction onto the computation.
|
||||
ComputationDataHandle Exp(const ComputationDataHandle& operand);
|
||||
|
||||
@ -570,6 +583,12 @@ class ComputationBuilder {
|
||||
// Enqueues a tanh instruction onto the computation.
|
||||
ComputationDataHandle Tanh(const ComputationDataHandle& operand);
|
||||
|
||||
// Enqueues a real-part instruction onto the computation.
|
||||
ComputationDataHandle Real(const ComputationDataHandle& operand);
|
||||
|
||||
// Enqueues an imaginary-part instruction onto the computation.
|
||||
ComputationDataHandle Imag(const ComputationDataHandle& operand);
|
||||
|
||||
// Enqueues a float32 sqrt instruction onto the computation.
|
||||
// (float32 is specified as there is an implicit float32 0.5f constant
|
||||
// exponent).
|
||||
|
@ -204,6 +204,8 @@ Status Literal::Copy(const Literal& src_literal,
|
||||
return *Literal::CreateR0<float>(0);
|
||||
case F64:
|
||||
return *Literal::CreateR0<double>(0);
|
||||
case C64:
|
||||
return *Literal::CreateR0<complex64>(0);
|
||||
case PRED:
|
||||
return *Literal::CreateR0<bool>(false);
|
||||
case S16:
|
||||
@ -236,6 +238,8 @@ Status Literal::Copy(const Literal& src_literal,
|
||||
return *Literal::CreateR0<float>(1);
|
||||
case F64:
|
||||
return *Literal::CreateR0<double>(1);
|
||||
case C64:
|
||||
return *Literal::CreateR0<complex64>(1);
|
||||
case PRED:
|
||||
return *Literal::CreateR0<bool>(true);
|
||||
case S16:
|
||||
@ -271,6 +275,8 @@ Status Literal::Copy(const Literal& src_literal,
|
||||
case F64:
|
||||
return *Literal::CreateR0<double>(
|
||||
-std::numeric_limits<double>::infinity());
|
||||
case C64:
|
||||
LOG(FATAL) << "C64 element type has no minimum value";
|
||||
case PRED:
|
||||
return *Literal::CreateR0<bool>(false);
|
||||
case S16:
|
||||
|
@ -141,6 +141,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleConvert(HloInstruction* convert) override;
|
||||
|
||||
Status HandleReal(HloInstruction* real, HloInstruction* operand) override;
|
||||
Status HandleImag(HloInstruction* imag, HloInstruction* operand) override;
|
||||
|
||||
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
|
||||
HloInstruction* rhs, const Window& window) override;
|
||||
|
||||
@ -967,6 +970,24 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Real(Complex(r, i)) -> r
|
||||
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real,
|
||||
HloInstruction* operand) {
|
||||
if (operand->opcode() == HloOpcode::kComplex) {
|
||||
return ReplaceInstruction(real, operand->mutable_operand(0));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Imag(Complex(r, i)) -> i
|
||||
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag,
|
||||
HloInstruction* operand) {
|
||||
if (operand->opcode() == HloOpcode::kComplex) {
|
||||
return ReplaceInstruction(imag, operand->mutable_operand(1));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||
// Eliminate nop pads (padding all zero), and replace a pad with negative
|
||||
// padding with a pad with non-negative padding followed by a slice.
|
||||
|
@ -433,6 +433,56 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
|
||||
EXPECT_EQ(root, param0);
|
||||
}
|
||||
|
||||
// Test that real(complex(r,i)) is simplified to r.
|
||||
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
|
||||
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, r2f32, "param1"));
|
||||
HloInstruction* cplx = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
|
||||
HloOpcode::kComplex, param0, param1));
|
||||
HloInstruction* real = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root, real);
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
non_bitcasting_callback());
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
root = computation->root_instruction();
|
||||
EXPECT_EQ(root, param0);
|
||||
}
|
||||
|
||||
// Test that imag(complex(r,i)) is simplified to i.
|
||||
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
|
||||
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||
HloInstruction* param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, r2f32, "param1"));
|
||||
HloInstruction* cplx = builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
|
||||
HloOpcode::kComplex, param0, param1));
|
||||
HloInstruction* imag = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
EXPECT_EQ(root, imag);
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
non_bitcasting_callback());
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
root = computation->root_instruction();
|
||||
EXPECT_EQ(root, param1);
|
||||
}
|
||||
|
||||
// Test that get_element(make_tuple({A,B}),1) is simplified to B
|
||||
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
@ -63,7 +63,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
|
||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
|
||||
const HloModuleConfig& hlo_module_config) {
|
||||
PrimitiveType type = target_array.GetShape().element_type();
|
||||
TF_RET_CHECK(F32 == type || F64 == type);
|
||||
TF_RET_CHECK(F32 == type || F64 == type || C64 == type);
|
||||
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
|
||||
lhs_array, rhs_array, executable_run_options_value,
|
||||
ir_builder, hlo_module_config);
|
||||
@ -176,7 +176,7 @@ tensorflow::Status DotOpEmitter::Emit() {
|
||||
llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
|
||||
ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
|
||||
|
||||
ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0),
|
||||
ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
|
||||
accum_address);
|
||||
|
||||
// Body basic block of reduction loop:
|
||||
@ -191,9 +191,29 @@ tensorflow::Status DotOpEmitter::Emit() {
|
||||
llvm::Value* rhs_element =
|
||||
rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
|
||||
|
||||
llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
|
||||
llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
|
||||
llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product);
|
||||
llvm::Value* updated_accum;
|
||||
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||
auto real = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {0});
|
||||
};
|
||||
auto imag = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {1});
|
||||
};
|
||||
llvm::Value* product_real = ir_builder_->CreateFSub(
|
||||
ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
|
||||
ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
|
||||
llvm::Value* product_imag = ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
|
||||
ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
|
||||
updated_accum = ir_builder_->CreateInsertValue(
|
||||
accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
|
||||
updated_accum = ir_builder_->CreateInsertValue(
|
||||
updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1});
|
||||
} else {
|
||||
llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
|
||||
updated_accum = ir_builder_->CreateFAdd(accum, product);
|
||||
}
|
||||
ir_builder_->CreateStore(updated_accum, accum_address);
|
||||
|
||||
// Exit basic block of reduction loop.
|
||||
@ -230,11 +250,28 @@ tensorflow::Status DotOpEmitter::Emit() {
|
||||
|
||||
tensorflow::Status DotOpEmitter::EmitScalarDot() {
|
||||
// A scalar dot is just a scalar multiply.
|
||||
llvm::Value* result;
|
||||
llvm::Value* lhs_value =
|
||||
lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
|
||||
llvm::Value* rhs_value =
|
||||
rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
|
||||
llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value);
|
||||
if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
|
||||
#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
|
||||
#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
|
||||
llvm::Value* real = ir_builder_->CreateFSub(
|
||||
ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
|
||||
ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
|
||||
llvm::Value* imag = ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
|
||||
ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
|
||||
#undef IMAG
|
||||
#undef REAL
|
||||
result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
|
||||
result = ir_builder_->CreateInsertValue(result, real, {0});
|
||||
result = ir_builder_->CreateInsertValue(result, imag, {1});
|
||||
} else {
|
||||
result = ir_builder_->CreateFMul(lhs_value, rhs_value);
|
||||
}
|
||||
target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
@ -46,8 +46,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
|
||||
}
|
||||
// Create function type for the function.
|
||||
llvm::FunctionType* function_type = llvm::FunctionType::get(
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||
/*isVarArg=*/false);
|
||||
// Create function declaration for 'tanhf'.
|
||||
llvm::Function* function =
|
||||
|
@ -41,6 +41,12 @@ bool PotentiallyImplementedAsEigenConvolution(
|
||||
ShapeUtil::HasZeroElements(kernel_shape)) {
|
||||
return false;
|
||||
}
|
||||
// TODO(b/65408531): Explore using Eigen dot for complex64 type.
|
||||
if (ShapeUtil::ElementIsComplex(input_shape) ||
|
||||
ShapeUtil::ElementIsComplex(kernel_shape)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ConvolutionDimensionNumbers& dnums =
|
||||
convolution.convolution_dimension_numbers();
|
||||
// Only 1D and 2D convolutions are supported at the moment.
|
||||
|
@ -288,7 +288,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant,
|
||||
MinimumAlignmentForShape(literal.shape()));
|
||||
} else {
|
||||
llvm::Constant* initializer =
|
||||
llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
|
||||
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
|
||||
global_for_const = new llvm::GlobalVariable(
|
||||
/*Module=*/*module_,
|
||||
/*Type=*/initializer->getType(),
|
||||
@ -401,7 +401,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
||||
const Shape& shape = get_tuple_element->shape();
|
||||
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
|
||||
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
|
||||
GetEmittedValueFor(operand), &ir_builder_);
|
||||
GetEmittedValueFor(operand), &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -412,9 +412,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
|
||||
|
||||
if (ShapeUtil::IsTuple(select->shape())) {
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
|
||||
llvm_ir::EmitTupleSelect(GetIrArrayFor(select), GetIrArrayFor(pred),
|
||||
GetEmittedValueFor(on_true),
|
||||
GetEmittedValueFor(on_false), &ir_builder_);
|
||||
llvm_ir::EmitTupleSelect(
|
||||
GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true),
|
||||
GetEmittedValueFor(on_false), &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -459,7 +459,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
|
||||
tuple_element_addresses.push_back(tuple_element_address);
|
||||
}
|
||||
|
||||
llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_);
|
||||
llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_,
|
||||
module_);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
|
||||
GetEmittedValueFor(infeed)));
|
||||
@ -562,7 +563,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
||||
ShapeUtil::GetTupleElementShape(operand_shape, i);
|
||||
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
|
||||
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
|
||||
value, &ir_builder_);
|
||||
value, &ir_builder_, module_);
|
||||
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
|
||||
tuple_element_shape, tuple_element));
|
||||
}
|
||||
@ -583,7 +584,7 @@ Status IrEmitter::HandleTuple(
|
||||
for (auto operand : operands) {
|
||||
base_ptrs.push_back(GetEmittedValueFor(operand));
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_);
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -644,7 +645,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window,
|
||||
// the initial value on the reduce_window.
|
||||
PrimitiveType operand_element_type = operand->shape().element_type();
|
||||
llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
|
||||
"reduce_window_accumulator_address", &ir_builder_,
|
||||
MinimumAlignmentForPrimitiveType(operand_element_type));
|
||||
ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
|
||||
@ -769,7 +770,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
||||
// Allocate space to keep the currently selected value, its index, and
|
||||
// the boolean initialized_flag, which is initially set to false.
|
||||
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
|
||||
"selected_value_address", &ir_builder_,
|
||||
MinimumAlignmentForPrimitiveType(operand_element_type));
|
||||
llvm::Value* selected_index_address =
|
||||
@ -851,8 +852,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
||||
// If the 'select' function returns false, update the selected value and the
|
||||
// index to the currently visiting operand.
|
||||
llvm::Value* cond = ir_builder_.CreateICmpNE(
|
||||
result, llvm::ConstantInt::get(
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
|
||||
result,
|
||||
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
|
||||
"boolean_predicate");
|
||||
llvm_ir::LlvmIfData if_select_lhs =
|
||||
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
|
||||
@ -895,7 +896,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
|
||||
/*supported_types=*/{F32, F64}));
|
||||
/*supported_types=*/{F32, F64, C64}));
|
||||
|
||||
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
|
||||
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
|
||||
@ -923,7 +924,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
|
||||
const Window& window) {
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
|
||||
/*supported_types=*/{F32}));
|
||||
/*supported_types=*/{F32, C64}));
|
||||
|
||||
const ConvolutionDimensionNumbers& dnums =
|
||||
convolution->convolution_dimension_numbers();
|
||||
@ -1079,7 +1080,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
|
||||
// the output entry at the given index.
|
||||
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
||||
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_),
|
||||
"convolution_sum_address", &ir_builder_,
|
||||
MinimumAlignmentForPrimitiveType(lhs_element_type));
|
||||
ir_builder_.CreateStore(
|
||||
@ -1295,14 +1296,14 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
|
||||
PrimitiveType element_type = operand->shape().element_type();
|
||||
// Used to calculate E(X).
|
||||
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||
"sum_address", &ir_builder_,
|
||||
MinimumAlignmentForPrimitiveType(element_type));
|
||||
|
||||
// Used to calculate E(X^2).
|
||||
llvm::Value* sum_square_address =
|
||||
llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||
"sum_square_address", &ir_builder_,
|
||||
MinimumAlignmentForPrimitiveType(element_type));
|
||||
|
||||
@ -1425,7 +1426,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
|
||||
.EmitLoop(IrName(batch_norm_training, "normalize")));
|
||||
|
||||
llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training),
|
||||
{normalized, mean, var}, &ir_builder_);
|
||||
{normalized, mean, var}, &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1488,6 +1489,14 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
|
||||
}
|
||||
|
||||
const Shape& root_shape = root_instruction->shape();
|
||||
if (ShapeUtil::ElementIsComplex(root_shape)) {
|
||||
// TODO(b/65408531): Complex add could by done via bitcast to <float x [2N]>
|
||||
// Complex multiply would be more challenging. We could perhaps use a
|
||||
// strided load to get all reals in a vector, all imags in a vector, or use
|
||||
// CreateShuffleVector on a bitcast to float x [2N].
|
||||
*failure_reason = "complex values not supported";
|
||||
return nullptr;
|
||||
}
|
||||
bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
|
||||
bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
|
||||
bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
|
||||
@ -1509,7 +1518,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
|
||||
// This is visually similar to ElementalIrEmitter, though conceptually we're
|
||||
// doing something different here. ElementalIrEmitter emits scalar operations
|
||||
// while these emit scalar or vector operations depending on the type of the
|
||||
// operands.
|
||||
// operands. See CreateShardedVectorType for the actual types in use here.
|
||||
switch (root_instruction->opcode()) {
|
||||
default:
|
||||
*failure_reason = "did not recognize root instruction opcode";
|
||||
@ -1586,7 +1595,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
|
||||
|
||||
ShardedVectorType sharded_vector_type;
|
||||
llvm::Type* element_ir_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_);
|
||||
llvm_ir::PrimitiveTypeToIrType(element_type, module_);
|
||||
|
||||
for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
|
||||
// For every power of two present in element_count, we generate one or more
|
||||
@ -1919,7 +1928,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
||||
// Initialize an accumulator with init_value.
|
||||
PrimitiveType accumulator_type = reduce->shape().element_type();
|
||||
llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
|
||||
"accumulator", &ir_builder_,
|
||||
MinimumAlignmentForPrimitiveType(accumulator_type));
|
||||
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
|
||||
@ -2248,6 +2257,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
return Status::OK();
|
||||
} else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
|
||||
assignment_)) {
|
||||
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
|
||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
|
||||
|
||||
@ -2257,6 +2267,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
fusion, operands, GetIrArrayFor(fusion), &elemental_emitter,
|
||||
&ir_builder_);
|
||||
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
VLOG(3) << "HandleFusion kLoop";
|
||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||
auto operands = GetIrArraysForOperandsOf(fusion);
|
||||
FusedIrEmitter fused_emitter(operands, &elemental_emitter);
|
||||
@ -2400,8 +2411,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
|
||||
{while_result}, IrName(xla_while, "cond"));
|
||||
llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
|
||||
while_condition,
|
||||
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
|
||||
0));
|
||||
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
|
||||
|
||||
// Branches to the body or to the while exit depending on the condition.
|
||||
llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
|
||||
@ -2542,7 +2552,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
|
||||
unsigned element_alignment = GCD(
|
||||
primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
|
||||
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
|
||||
llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_));
|
||||
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
|
||||
|
||||
if (element_count == 1) {
|
||||
auto* load_instruction = ir_builder_.CreateAlignedLoad(
|
||||
@ -2755,7 +2765,7 @@ llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
|
||||
}
|
||||
|
||||
llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
|
||||
return llvm_ir::ShapeToIrType(shape, &ir_builder_);
|
||||
return llvm_ir::ShapeToIrType(shape, module_);
|
||||
}
|
||||
|
||||
std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() {
|
||||
@ -2925,7 +2935,7 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
|
||||
PrimitiveType return_type = return_shape.element_type();
|
||||
llvm::Value* return_value_buffer =
|
||||
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
||||
llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements,
|
||||
llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
|
||||
tensorflow::strings::StrCat(name, "_return_value_address"),
|
||||
&ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
|
||||
EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
|
||||
@ -3100,7 +3110,7 @@ Status IrEmitter::EmitTargetElementLoop(
|
||||
for (int64 i = 0; i < output_arrays.size(); ++i) {
|
||||
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
||||
}
|
||||
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_);
|
||||
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_);
|
||||
|
||||
} else {
|
||||
if (ShouldEmitParallelLoopFor(*target_op)) {
|
||||
|
@ -85,6 +85,10 @@ class DfsHloVisitor {
|
||||
virtual Status HandleCopy(HloInstruction* copy) {
|
||||
return HandleElementwiseUnary(copy);
|
||||
}
|
||||
virtual Status HandleComplex(HloInstruction* complex, HloInstruction* real,
|
||||
HloInstruction* imag) {
|
||||
return HandleElementwiseBinary(complex);
|
||||
}
|
||||
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
return HandleElementwiseBinary(multiply);
|
||||
@ -122,6 +126,10 @@ class DfsHloVisitor {
|
||||
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(abs);
|
||||
}
|
||||
virtual Status HandleAtan2(HloInstruction* atan2, HloInstruction* y,
|
||||
HloInstruction* x) {
|
||||
return HandleElementwiseBinary(atan2);
|
||||
}
|
||||
virtual Status HandleRound(HloInstruction* round) {
|
||||
return HandleElementwiseUnary(round);
|
||||
}
|
||||
@ -152,6 +160,12 @@ class DfsHloVisitor {
|
||||
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(tanh);
|
||||
}
|
||||
virtual Status HandleReal(HloInstruction* real, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(real);
|
||||
}
|
||||
virtual Status HandleImag(HloInstruction* imag, HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(imag);
|
||||
}
|
||||
virtual Status HandleIsFinite(HloInstruction* is_finite,
|
||||
HloInstruction* operand) {
|
||||
return HandleElementwiseUnary(is_finite);
|
||||
|
@ -54,10 +54,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const {
|
||||
if (op->opcode() == HloOpcode::kCopy) {
|
||||
return operand_value;
|
||||
} else if (operand_value->getType()->isIntegerTy()) {
|
||||
return EmitIntegerUnaryOp(op, operand_value);
|
||||
} else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
|
||||
return EmitComplexUnaryOp(op, operand_value);
|
||||
} else {
|
||||
return operand_value->getType()->isIntegerTy()
|
||||
? EmitIntegerUnaryOp(op, operand_value)
|
||||
: EmitFloatUnaryOp(op, operand_value);
|
||||
return EmitFloatUnaryOp(op, operand_value);
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,20 +75,35 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
||||
}
|
||||
if (primitive_util::IsIntegralType(to_type)) {
|
||||
return ir_builder_->CreateIntCast(
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_),
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
|
||||
primitive_util::IsSignedIntegralType(to_type));
|
||||
}
|
||||
if (primitive_util::IsFloatingPointType(to_type)) {
|
||||
if (primitive_util::IsSignedIntegralType(from_type)) {
|
||||
return ir_builder_->CreateSIToFP(
|
||||
operand_value,
|
||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||
}
|
||||
if (primitive_util::IsUnsignedIntegralType(from_type) ||
|
||||
from_type == PRED) {
|
||||
return ir_builder_->CreateUIToFP(
|
||||
operand_value,
|
||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||
}
|
||||
}
|
||||
if (primitive_util::IsComplexType(to_type)) {
|
||||
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
|
||||
primitive_util::ComplexComponentType(to_type), module_);
|
||||
if (primitive_util::IsSignedIntegralType(from_type)) {
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
|
||||
nullptr);
|
||||
}
|
||||
if (primitive_util::IsUnsignedIntegralType(from_type) ||
|
||||
from_type == PRED) {
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
|
||||
nullptr);
|
||||
}
|
||||
}
|
||||
return Unimplemented("conversion from primitive type %s to %s",
|
||||
@ -97,8 +114,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
||||
bool is_signed =
|
||||
primitive_util::IsSignedIntegralType(op->shape().element_type());
|
||||
if (is_signed) {
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
|
||||
ir_builder_);
|
||||
auto type =
|
||||
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
|
||||
auto zero = llvm::ConstantInt::get(type, 0);
|
||||
auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
|
||||
return ir_builder_->CreateSelect(cmp, operand_value,
|
||||
@ -110,8 +127,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
||||
case HloOpcode::kSign: {
|
||||
bool is_signed =
|
||||
primitive_util::IsSignedIntegralType(op->shape().element_type());
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
|
||||
ir_builder_);
|
||||
auto type =
|
||||
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
|
||||
auto zero = llvm::ConstantInt::get(type, 0);
|
||||
auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
|
||||
if (is_signed) {
|
||||
@ -135,7 +152,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
||||
return ir_builder_->CreateZExt(
|
||||
ir_builder_->CreateNot(ir_builder_->CreateTrunc(
|
||||
operand_value, ir_builder_->getInt1Ty())),
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
|
||||
} else if (primitive_util::IsIntegralType(type)) {
|
||||
return ir_builder_->CreateNot(operand_value);
|
||||
}
|
||||
@ -157,20 +174,30 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
||||
if (from_type == to_type) {
|
||||
return operand_value;
|
||||
}
|
||||
if (primitive_util::IsComplexType(to_type)) {
|
||||
PrimitiveType to_component_type =
|
||||
primitive_util::ComplexComponentType(to_type);
|
||||
if (from_type == to_component_type) {
|
||||
return ComposeComplex(op, operand_value, nullptr);
|
||||
}
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFPCast(
|
||||
operand_value,
|
||||
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
|
||||
nullptr);
|
||||
}
|
||||
if (primitive_util::IsFloatingPointType(to_type)) {
|
||||
return ir_builder_->CreateFPCast(
|
||||
operand_value,
|
||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||
}
|
||||
if (primitive_util::IsSignedIntegralType(to_type)) {
|
||||
return ir_builder_->CreateFPToSI(
|
||||
operand_value,
|
||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||
}
|
||||
if (primitive_util::IsUnsignedIntegralType(to_type)) {
|
||||
return ir_builder_->CreateFPToUI(
|
||||
operand_value,
|
||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||
}
|
||||
return Unimplemented("unhandled conversion operation: %s => %s",
|
||||
PrimitiveType_Name(from_type).c_str(),
|
||||
@ -230,7 +257,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
||||
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
|
||||
auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
|
||||
return ir_builder_->CreateZExt(
|
||||
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
|
||||
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
|
||||
}
|
||||
case HloOpcode::kNegate:
|
||||
return ir_builder_->CreateFNeg(operand_value);
|
||||
@ -240,20 +267,164 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const {
|
||||
auto real = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {0});
|
||||
};
|
||||
auto imag = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {1});
|
||||
};
|
||||
switch (op->opcode()) {
|
||||
// TODO(b/65209142): Angle/Log require atan2.
|
||||
// case HloOpcode::kAngle:
|
||||
// case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
|
||||
case HloOpcode::kConvert: {
|
||||
PrimitiveType from_type = op->operand(0)->shape().element_type();
|
||||
TF_RET_CHECK(primitive_util::IsComplexType(from_type));
|
||||
PrimitiveType to_type = op->shape().element_type();
|
||||
TF_RET_CHECK(primitive_util::IsComplexType(to_type));
|
||||
if (from_type == to_type) {
|
||||
return operand_value;
|
||||
}
|
||||
PrimitiveType to_component_type =
|
||||
primitive_util::ComplexComponentType(to_type);
|
||||
auto to_ir_component_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type),
|
||||
ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type));
|
||||
}
|
||||
case HloOpcode::kExp: {
|
||||
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
|
||||
auto exp_a = llvm_ir::EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::exp, {real(operand_value)},
|
||||
{real(operand_value)->getType()}, ir_builder_);
|
||||
auto cos_b = llvm_ir::EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::cos, {imag(operand_value)},
|
||||
{imag(operand_value)->getType()}, ir_builder_);
|
||||
auto sin_b = llvm_ir::EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::sin, {imag(operand_value)},
|
||||
{imag(operand_value)->getType()}, ir_builder_);
|
||||
return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
|
||||
ir_builder_->CreateFMul(exp_a, sin_b));
|
||||
}
|
||||
case HloOpcode::kCos: {
|
||||
// cos(z) = .5(e^(iz) + e^(-iz))
|
||||
// cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
|
||||
// now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
|
||||
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
|
||||
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
|
||||
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
|
||||
// = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
|
||||
auto a = real(operand_value);
|
||||
auto b = imag(operand_value);
|
||||
auto type = a->getType();
|
||||
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
|
||||
{type}, ir_builder_);
|
||||
auto half_exp_b =
|
||||
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||
auto half_exp_neg_b =
|
||||
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||
auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
|
||||
{type}, ir_builder_);
|
||||
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
|
||||
{type}, ir_builder_);
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFMul(
|
||||
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
|
||||
ir_builder_->CreateFMul(
|
||||
sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
|
||||
}
|
||||
case HloOpcode::kSin: {
|
||||
// sin(z) = .5i(e^(-iz) - e^(iz))
|
||||
// sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
|
||||
// = .5i(e^(b-ai) - e^(-b+ai))
|
||||
// now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
|
||||
// sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
|
||||
// = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
|
||||
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
|
||||
// = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
|
||||
// = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
|
||||
auto a = real(operand_value);
|
||||
auto b = imag(operand_value);
|
||||
auto type = a->getType();
|
||||
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
|
||||
{type}, ir_builder_);
|
||||
auto half_exp_b =
|
||||
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||
auto half_exp_neg_b =
|
||||
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||
auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
|
||||
{type}, ir_builder_);
|
||||
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
|
||||
{type}, ir_builder_);
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFMul(
|
||||
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
|
||||
ir_builder_->CreateFMul(
|
||||
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
|
||||
}
|
||||
case HloOpcode::kAbs: {
|
||||
auto sum_sq = ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
|
||||
ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
|
||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
|
||||
{sum_sq->getType()}, ir_builder_);
|
||||
}
|
||||
case HloOpcode::kSign: { // Sign(c) = c / |c|
|
||||
auto sum_sq = ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
|
||||
ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
|
||||
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
|
||||
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
|
||||
auto type = cplx_abs->getType();
|
||||
auto zero = llvm::ConstantFP::get(type, 0.0);
|
||||
auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
|
||||
return ir_builder_->CreateSelect(
|
||||
oeq, ComposeComplex(op, zero, zero),
|
||||
ComposeComplex(
|
||||
op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs),
|
||||
ir_builder_->CreateFDiv(imag(operand_value), cplx_abs)));
|
||||
}
|
||||
case HloOpcode::kNegate:
|
||||
return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)),
|
||||
ir_builder_->CreateFNeg(imag(operand_value)));
|
||||
case HloOpcode::kReal:
|
||||
return real(operand_value);
|
||||
case HloOpcode::kImag:
|
||||
return imag(operand_value);
|
||||
default:
|
||||
return Unimplemented("unary complex op '%s'",
|
||||
HloOpcodeString(op->opcode()).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
|
||||
const HloInstruction* op, llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const {
|
||||
return lhs_value->getType()->isIntegerTy()
|
||||
? EmitIntegerBinaryOp(op, lhs_value, rhs_value,
|
||||
primitive_util::IsSignedIntegralType(
|
||||
op->operand(0)->shape().element_type()))
|
||||
: EmitFloatBinaryOp(op, lhs_value, rhs_value);
|
||||
PrimitiveType operand_type = op->operand(0)->shape().element_type();
|
||||
if (lhs_value->getType()->isIntegerTy()) {
|
||||
return EmitIntegerBinaryOp(
|
||||
op, lhs_value, rhs_value,
|
||||
primitive_util::IsSignedIntegralType(operand_type));
|
||||
} else if (primitive_util::IsComplexType(operand_type)) {
|
||||
return EmitComplexBinaryOp(op, lhs_value, rhs_value);
|
||||
} else {
|
||||
return EmitFloatBinaryOp(op, lhs_value, rhs_value);
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||
const HloInstruction* op, llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const {
|
||||
switch (op->opcode()) {
|
||||
// case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support
|
||||
case HloOpcode::kComplex:
|
||||
return ComposeComplex(op, lhs_value, rhs_value);
|
||||
case HloOpcode::kAdd:
|
||||
return ir_builder_->CreateFAdd(lhs_value, rhs_value);
|
||||
case HloOpcode::kSubtract:
|
||||
@ -305,6 +476,88 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
|
||||
const HloInstruction* op, llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const {
|
||||
auto real = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {0});
|
||||
};
|
||||
auto imag = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {1});
|
||||
};
|
||||
switch (op->opcode()) {
|
||||
case HloOpcode::kAdd:
|
||||
return ComposeComplex(
|
||||
op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value)));
|
||||
case HloOpcode::kSubtract:
|
||||
return ComposeComplex(
|
||||
op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value)));
|
||||
case HloOpcode::kMultiply:
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFSub(
|
||||
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))),
|
||||
ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
|
||||
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))));
|
||||
case HloOpcode::kDivide: {
|
||||
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
|
||||
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
|
||||
auto rhs_sum_sq = ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value)));
|
||||
auto type = rhs_sum_sq->getType();
|
||||
auto zero = llvm::ConstantFP::get(type, 0.0);
|
||||
auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
|
||||
return ir_builder_->CreateSelect(
|
||||
oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero),
|
||||
ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFDiv(
|
||||
ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFMul(imag(lhs_value),
|
||||
imag(rhs_value))),
|
||||
rhs_sum_sq),
|
||||
ir_builder_->CreateFDiv(
|
||||
ir_builder_->CreateFSub(
|
||||
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFMul(real(lhs_value),
|
||||
imag(rhs_value))),
|
||||
rhs_sum_sq)));
|
||||
}
|
||||
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
|
||||
// comparisons always return false when one of the operands is NaN, whereas
|
||||
// unordered comparisons return true.
|
||||
//
|
||||
// We use ordered comparisons for everything except kNe, where we use an
|
||||
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
||||
// matches C++'s semantics.
|
||||
case HloOpcode::kEq:
|
||||
return ir_builder_->CreateAnd(
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value),
|
||||
real(rhs_value), ir_builder_),
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value),
|
||||
imag(rhs_value), ir_builder_));
|
||||
case HloOpcode::kNe:
|
||||
return ir_builder_->CreateOr(
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value),
|
||||
real(rhs_value), ir_builder_),
|
||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value),
|
||||
imag(rhs_value), ir_builder_));
|
||||
|
||||
// TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic
|
||||
// case HloOpcode::kPower:
|
||||
// // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2)
|
||||
default:
|
||||
return Unimplemented("binary complex op '%s'",
|
||||
HloOpcodeString(op->opcode()).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const {
|
||||
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
|
||||
@ -396,7 +649,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
|
||||
PrimitiveType prim_type, llvm::Value* value) const {
|
||||
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_);
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
|
||||
auto one = llvm::ConstantFP::get(type, 1.0);
|
||||
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
|
||||
}
|
||||
@ -619,7 +872,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
|
||||
const {
|
||||
PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
|
||||
llvm::Type* param_ir_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_);
|
||||
llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);
|
||||
|
||||
// Same values as PCG library
|
||||
// https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
|
||||
@ -783,7 +1036,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
|
||||
return ir_builder_->CreateZExt(
|
||||
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
ir_builder_));
|
||||
module_));
|
||||
}
|
||||
default:
|
||||
return InvalidArgument(
|
||||
@ -806,9 +1059,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kTanh:
|
||||
@ -821,6 +1076,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
return EmitUnaryOp(hlo, operand_value);
|
||||
};
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
@ -913,10 +1170,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
}
|
||||
|
||||
llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
|
||||
llvm::PHINode* output = ir_builder_->CreatePHI(
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
ir_builder_),
|
||||
hlo->operands().size());
|
||||
llvm::PHINode* output =
|
||||
ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
|
||||
hlo->shape().element_type(), module_),
|
||||
hlo->operands().size());
|
||||
auto prior_insert_point = ir_builder_->GetInsertPoint();
|
||||
|
||||
ir_builder_->SetInsertPoint(init_block);
|
||||
@ -1075,7 +1332,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
// else -> return data from 'index'.
|
||||
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
ir_builder_),
|
||||
module_),
|
||||
"ret_value_addr", ir_builder_);
|
||||
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
|
||||
slice_intersection, "slice_intersection", ir_builder_);
|
||||
@ -1164,7 +1421,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
// }
|
||||
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
ir_builder_),
|
||||
module_),
|
||||
"pad_result_addr", ir_builder_);
|
||||
llvm_ir::LlvmIfData if_data =
|
||||
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
|
||||
@ -1206,7 +1463,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
ir_builder_);
|
||||
PrimitiveType primitive_type = hlo->shape().element_type();
|
||||
llvm::Type* primitive_type_llvm =
|
||||
llvm_ir::PrimitiveTypeToIrType(primitive_type, ir_builder_);
|
||||
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
|
||||
llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
primitive_type_llvm, "dot_acc", ir_builder_);
|
||||
ir_builder_->CreateStore(
|
||||
@ -1239,7 +1496,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
|
||||
llvm::Value* next_accumulator;
|
||||
if (primitive_util::IsFloatingPointType(primitive_type)) {
|
||||
if (primitive_util::IsComplexType(primitive_type)) {
|
||||
auto real = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {0});
|
||||
};
|
||||
auto imag = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {1});
|
||||
};
|
||||
llvm::Value* product_real = ir_builder_->CreateFSub(
|
||||
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
|
||||
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value)));
|
||||
llvm::Value* product_imag = ir_builder_->CreateFAdd(
|
||||
ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
|
||||
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)));
|
||||
next_accumulator = ir_builder_->CreateInsertValue(
|
||||
current_accumulator,
|
||||
ir_builder_->CreateFAdd(real(current_accumulator), product_real),
|
||||
{0});
|
||||
next_accumulator = ir_builder_->CreateInsertValue(
|
||||
next_accumulator,
|
||||
ir_builder_->CreateFAdd(imag(current_accumulator), product_imag),
|
||||
{1});
|
||||
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
|
||||
next_accumulator = ir_builder_->CreateFAdd(
|
||||
current_accumulator,
|
||||
ir_builder_->CreateFMul(lhs_value, rhs_value));
|
||||
@ -1261,4 +1539,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op,
|
||||
llvm::Value* real,
|
||||
llvm::Value* imag) const {
|
||||
auto cplx_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
|
||||
auto complex = ir_builder_->CreateInsertValue(
|
||||
llvm::ConstantAggregateZero::get(cplx_type), real, {0});
|
||||
if (imag != nullptr) {
|
||||
complex = ir_builder_->CreateInsertValue(complex, imag, {1});
|
||||
}
|
||||
return complex;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -55,6 +55,7 @@ class ElementalIrEmitter {
|
||||
const HloToElementGeneratorMap& operand_to_generator) const;
|
||||
|
||||
llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
|
||||
llvm::Module* module() const { return module_; }
|
||||
|
||||
protected:
|
||||
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
|
||||
@ -63,6 +64,9 @@ class ElementalIrEmitter {
|
||||
virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const;
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const;
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
|
||||
llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value,
|
||||
@ -72,6 +76,10 @@ class ElementalIrEmitter {
|
||||
const HloInstruction* op, llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const;
|
||||
|
||||
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
|
||||
const HloInstruction* op, llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const;
|
||||
|
||||
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const;
|
||||
|
||||
@ -109,6 +117,11 @@ class ElementalIrEmitter {
|
||||
// compiled executable outside of the HLO code itself.
|
||||
const HloModuleConfig& hlo_module_config_;
|
||||
|
||||
protected:
|
||||
// Composes a complex struct. imag may be nullptr for simple cast operations.
|
||||
llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real,
|
||||
llvm::Value* imag) const;
|
||||
|
||||
private:
|
||||
// Returns a ElementGenerator for a RNG HloInstruction.
|
||||
llvm_ir::ElementGenerator MakeRngElementGenerator(
|
||||
|
@ -135,6 +135,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
||||
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
|
||||
PrimitiveType output_type = op->shape().element_type();
|
||||
switch (op->opcode()) {
|
||||
case HloOpcode::kAtan2:
|
||||
return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value},
|
||||
{lhs_input_type, rhs_input_type},
|
||||
output_type);
|
||||
case HloOpcode::kRemainder: {
|
||||
return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
|
||||
{lhs_input_type, rhs_input_type},
|
||||
@ -226,6 +230,112 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const {
|
||||
PrimitiveType input_type = op->operand(0)->shape().element_type();
|
||||
PrimitiveType component_type =
|
||||
primitive_util::IsComplexType(input_type)
|
||||
? primitive_util::ComplexComponentType(input_type)
|
||||
: input_type;
|
||||
auto real = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {0});
|
||||
};
|
||||
auto imag = [&](llvm::Value* x) {
|
||||
return ir_builder_->CreateExtractValue(x, {1});
|
||||
};
|
||||
|
||||
switch (op->opcode()) {
|
||||
case HloOpcode::kLog: {
|
||||
// log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
|
||||
auto a = real(operand_value);
|
||||
auto b = imag(operand_value);
|
||||
llvm::Type* llvm_ty = a->getType();
|
||||
auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
|
||||
ir_builder_->CreateFMul(b, b));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto log_sum_sq,
|
||||
EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type},
|
||||
component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a},
|
||||
{component_type, component_type},
|
||||
component_type));
|
||||
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
|
||||
return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq),
|
||||
angle);
|
||||
}
|
||||
// TODO(b/65408531): Implement kPower on GPU, where atan2 is available.
|
||||
// case HloOpcode::kPower:
|
||||
// // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di))
|
||||
case HloOpcode::kExp: {
|
||||
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
|
||||
auto b = imag(operand_value);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)},
|
||||
{component_type}, component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
|
||||
component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
|
||||
component_type));
|
||||
return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
|
||||
ir_builder_->CreateFMul(exp_a, sin_b));
|
||||
}
|
||||
case HloOpcode::kCos: {
|
||||
// cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
|
||||
auto a = real(operand_value);
|
||||
auto llvm_ty = a->getType();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
|
||||
{component_type}, component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
|
||||
component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
|
||||
component_type));
|
||||
auto half_exp_b =
|
||||
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||
auto half_exp_neg_b =
|
||||
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFMul(
|
||||
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
|
||||
ir_builder_->CreateFMul(
|
||||
sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
|
||||
}
|
||||
|
||||
case HloOpcode::kSin: {
|
||||
// sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
|
||||
auto a = real(operand_value);
|
||||
auto llvm_ty = a->getType();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
|
||||
{component_type}, component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
|
||||
component_type));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
|
||||
component_type));
|
||||
auto half_exp_b =
|
||||
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||
auto half_exp_neg_b =
|
||||
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||
return ComposeComplex(
|
||||
op,
|
||||
ir_builder_->CreateFMul(
|
||||
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
|
||||
ir_builder_->CreateFMul(
|
||||
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
|
||||
}
|
||||
default:
|
||||
return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
|
||||
const string& callee_name,
|
||||
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
||||
@ -235,13 +345,12 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
|
||||
std::vector<llvm::Type*> ir_input_types;
|
||||
for (PrimitiveType input_type : input_types) {
|
||||
ir_input_types.push_back(
|
||||
llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_));
|
||||
llvm_ir::PrimitiveTypeToIrType(input_type, module_));
|
||||
}
|
||||
llvm::FunctionType* callee_type = llvm::FunctionType::get(
|
||||
llvm_ir::PrimitiveTypeToIrType(output_type,
|
||||
ir_builder_), // The return type.
|
||||
ir_input_types, // The parameter types.
|
||||
false); // No variadic arguments.
|
||||
llvm_ir::PrimitiveTypeToIrType(output_type, module_), // Return type.
|
||||
ir_input_types, // Parameter types.
|
||||
false); // No variadic arguments.
|
||||
|
||||
// Declares the callee if it is not declared already.
|
||||
llvm::Function* callee = llvm::cast<llvm::Function>(
|
||||
@ -315,7 +424,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
||||
|
||||
PrimitiveType operand_element_type = operand->shape().element_type();
|
||||
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
|
||||
"reduce_window_accum_ptr", ir_builder_);
|
||||
{
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
|
||||
@ -377,7 +486,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
||||
const HloInstruction* operand = hlo->operand(0);
|
||||
llvm::Value* accum_ptr =
|
||||
ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
|
||||
hlo->shape().element_type(), ir_builder()));
|
||||
hlo->shape().element_type(), module_));
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
|
||||
operand_to_generator.at(hlo->operand(1))({}));
|
||||
ir_builder()->CreateStore(init_value, accum_ptr);
|
||||
|
@ -54,6 +54,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
||||
StatusOr<llvm::Value*> EmitFloatUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitComplexUnaryOp(
|
||||
const HloInstruction* op, llvm::Value* operand_value) const override;
|
||||
|
||||
StatusOr<llvm::Value*> EmitFloatBinaryOp(
|
||||
const HloInstruction* op, llvm::Value* lhs_value,
|
||||
llvm::Value* rhs_value) const override;
|
||||
|
@ -102,7 +102,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
|
||||
slice_result.ConsumeValueOrDie();
|
||||
if (slice.allocation()->is_thread_local()) {
|
||||
llvm::Type* pointee_type =
|
||||
llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_);
|
||||
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
|
||||
BindHloToIrValue(*non_io_hlo,
|
||||
ir_builder_->CreateAlloca(pointee_type), index);
|
||||
} else {
|
||||
@ -124,18 +124,18 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
|
||||
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
|
||||
return llvm_ir::EmitGetTupleElement(
|
||||
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
||||
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_);
|
||||
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
|
||||
}
|
||||
return llvm_ir::EmitGetTupleElement(
|
||||
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
||||
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_);
|
||||
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
|
||||
}
|
||||
|
||||
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
|
||||
const ShapeIndex& shape_index,
|
||||
llvm::Value* ir_value) {
|
||||
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
|
||||
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_);
|
||||
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
|
||||
llvm::Type* dest_type = pointee_type->getPointerTo();
|
||||
|
||||
llvm::Value* typed_ir_value;
|
||||
|
@ -36,10 +36,12 @@ class HloToIrBindings {
|
||||
public:
|
||||
HloToIrBindings(const HloModule& module,
|
||||
const BufferAssignment* buffer_assignment,
|
||||
llvm::IRBuilder<>* ir_builder, bool is_nested)
|
||||
llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module,
|
||||
bool is_nested)
|
||||
: buffer_assignment_(buffer_assignment),
|
||||
is_nested_(is_nested),
|
||||
ir_builder_(ir_builder),
|
||||
module_(llvm_module),
|
||||
alias_analysis_(module, *buffer_assignment_,
|
||||
&ir_builder_->getContext()) {}
|
||||
|
||||
@ -93,6 +95,7 @@ class HloToIrBindings {
|
||||
const bool is_nested_;
|
||||
|
||||
llvm::IRBuilder<>* ir_builder_;
|
||||
llvm::Module* module_;
|
||||
|
||||
// Stores the underlying llvm::IrArray for each HloInstruction.
|
||||
// For an instruction that generates multiple outputs, the root will be a
|
||||
|
@ -53,9 +53,10 @@ namespace gpu {
|
||||
IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
|
||||
IrEmitterContext* ir_emitter_context, bool is_nested)
|
||||
: ir_emitter_context_(ir_emitter_context),
|
||||
ir_builder_(ir_emitter_context->llvm_module()->getContext()),
|
||||
module_(ir_emitter_context->llvm_module()),
|
||||
ir_builder_(module_->getContext()),
|
||||
bindings_(ir_emitter_context->hlo_module(),
|
||||
&ir_emitter_context->buffer_assignment(), &ir_builder_,
|
||||
&ir_emitter_context->buffer_assignment(), &ir_builder_, module_,
|
||||
is_nested),
|
||||
hlo_module_config_(hlo_module_config) {
|
||||
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
|
||||
@ -71,18 +72,17 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
|
||||
};
|
||||
}
|
||||
return EmitTargetElementLoop(
|
||||
*hlo, GpuElementalIrEmitter(hlo_module_config_,
|
||||
ir_emitter_context_->llvm_module(),
|
||||
&ir_builder_, GetNestedComputer())
|
||||
*hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
|
||||
GetNestedComputer())
|
||||
.MakeElementGenerator(hlo, operand_to_generator));
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleConstant(HloInstruction* constant,
|
||||
const Literal& literal) {
|
||||
llvm::Constant* initializer =
|
||||
llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
|
||||
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
|
||||
llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
|
||||
*ir_emitter_context_->llvm_module(), initializer->getType(),
|
||||
*module_, initializer->getType(),
|
||||
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
|
||||
/*Name=*/"");
|
||||
VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl
|
||||
@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
||||
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
||||
// TODO(b/26344050): tighten the alignment here
|
||||
// based on the real element type.
|
||||
/*alignment=*/1, GetBasePointer(*operand), &ir_builder_));
|
||||
/*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ Status IrEmitter::HandleTuple(
|
||||
for (const HloInstruction* operand : operands) {
|
||||
base_ptrs.push_back(GetBasePointer(*operand));
|
||||
}
|
||||
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_);
|
||||
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -329,7 +329,7 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
|
||||
if (ShapeUtil::IsTuple(select->shape())) {
|
||||
llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
|
||||
GetBasePointer(*on_true),
|
||||
GetBasePointer(*on_false), &ir_builder_);
|
||||
GetBasePointer(*on_false), &ir_builder_, module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -355,7 +355,26 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
|
||||
lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
|
||||
llvm::Value* rhs_value =
|
||||
rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
|
||||
llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value);
|
||||
llvm::Value* result;
|
||||
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||
auto real = [&](llvm::Value* x) {
|
||||
return ir_builder_.CreateExtractValue(x, {0});
|
||||
};
|
||||
auto imag = [&](llvm::Value* x) {
|
||||
return ir_builder_.CreateExtractValue(x, {1});
|
||||
};
|
||||
llvm::Value* real_result = ir_builder_.CreateFSub(
|
||||
ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)),
|
||||
ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value)));
|
||||
llvm::Value* imag_result = ir_builder_.CreateFAdd(
|
||||
ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)),
|
||||
ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value)));
|
||||
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
|
||||
result = ir_builder_.CreateInsertValue(result, real_result, {0});
|
||||
result = ir_builder_.CreateInsertValue(result, imag_result, {1});
|
||||
} else {
|
||||
result = ir_builder_.CreateFMul(lhs_value, rhs_value);
|
||||
}
|
||||
target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -411,8 +430,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
|
||||
|
||||
// Initialize the accumulator in the preheader to zero.
|
||||
new llvm::StoreInst(
|
||||
llvm::ConstantFP::get(accum_type, 0.0), // The value stored.
|
||||
accum_address, // The address.
|
||||
llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0
|
||||
accum_address, // The address.
|
||||
reduction_loop->GetPreheaderBasicBlock()
|
||||
->getTerminator()); // The instruction this store is inserted before.
|
||||
|
||||
@ -427,9 +446,27 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
|
||||
lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_);
|
||||
llvm::Value* rhs_element =
|
||||
rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_);
|
||||
llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
|
||||
llvm::Value* accum = ir_builder_.CreateLoad(accum_address);
|
||||
llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product);
|
||||
llvm::Value* updated_accum;
|
||||
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||
#define REAL(x) ir_builder_.CreateExtractValue(x, {0})
|
||||
#define IMAG(x) ir_builder_.CreateExtractValue(x, {1})
|
||||
llvm::Value* product_real = ir_builder_.CreateFSub(
|
||||
ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)),
|
||||
ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element)));
|
||||
llvm::Value* product_imag = ir_builder_.CreateFAdd(
|
||||
ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)),
|
||||
ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element)));
|
||||
updated_accum = ir_builder_.CreateInsertValue(
|
||||
accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0});
|
||||
updated_accum = ir_builder_.CreateInsertValue(
|
||||
updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1});
|
||||
#undef IMAG
|
||||
#undef REAL
|
||||
} else {
|
||||
llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
|
||||
updated_accum = ir_builder_.CreateFAdd(accum, product);
|
||||
}
|
||||
ir_builder_.CreateStore(updated_accum, accum_address);
|
||||
|
||||
// After the reduction loop exits, store the accumulator into the target
|
||||
@ -494,7 +531,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
||||
// Initialize an accumulator with init_value.
|
||||
llvm::AllocaInst* accumulator_addr =
|
||||
ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
|
||||
reduce->shape().element_type(), &ir_builder_));
|
||||
reduce->shape().element_type(), module_));
|
||||
ir_builder_.CreateStore(
|
||||
ir_builder_.CreateLoad(GetBasePointer(*init_value)),
|
||||
accumulator_addr);
|
||||
@ -547,8 +584,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
||||
for (HloInstruction* operand : fusion->operands()) {
|
||||
parameter_arrays.push_back(GetIrArray(*operand));
|
||||
}
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
|
||||
ir_emitter_context_->llvm_module(),
|
||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_,
|
||||
&ir_builder_, GetNestedComputer());
|
||||
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
|
||||
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
||||
@ -591,9 +627,8 @@ Status IrEmitter::HandleRng(HloInstruction* random,
|
||||
// Emits a single-threaded loop because the loop body generated by the element
|
||||
// generator for Rng can't be parallelized (b/32333178).
|
||||
return llvm_ir::LoopEmitter(
|
||||
GpuElementalIrEmitter(hlo_module_config_,
|
||||
ir_emitter_context_->llvm_module(),
|
||||
&ir_builder_, GetNestedComputer())
|
||||
GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
|
||||
GetNestedComputer())
|
||||
.MakeElementGenerator(random, operand_to_generator),
|
||||
GetIrArray(*random), &ir_builder_)
|
||||
.EmitLoop(IrName(random));
|
||||
@ -634,7 +669,7 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
|
||||
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
|
||||
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(
|
||||
computation.root_instruction()->shape().element_type(), &ir_builder_),
|
||||
computation.root_instruction()->shape().element_type(), module_),
|
||||
"return_buffer", &ir_builder_);
|
||||
std::vector<llvm::Value*> parameter_buffers;
|
||||
for (llvm::Value* parameter_element : parameter_elements) {
|
||||
|
@ -162,6 +162,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
IrEmitterContext* ir_emitter_context_;
|
||||
llvm::Module* module_;
|
||||
|
||||
// The following fields track the IR emission state. According to LLVM memory
|
||||
// management rules, their memory is owned by the module.
|
||||
|
@ -52,9 +52,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
|
||||
io_hlos->push_back(param);
|
||||
const Shape& param_shape = param->shape();
|
||||
argument_types.push_back(
|
||||
llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo());
|
||||
int64 param_size = llvm_ir::ByteSizeOf(
|
||||
param_shape, ir_emitter_context_->llvm_module()->getDataLayout());
|
||||
llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
|
||||
int64 param_size =
|
||||
llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
|
||||
argument_dereferenceable_bytes.push_back(param_size);
|
||||
}
|
||||
{
|
||||
@ -62,7 +62,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
|
||||
io_hlos->push_back(root);
|
||||
const Shape& root_shape = root->shape();
|
||||
argument_types.push_back(
|
||||
llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo());
|
||||
llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
|
||||
int64 root_size = llvm_ir::ByteSizeOf(
|
||||
root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
|
||||
argument_dereferenceable_bytes.push_back(root_size);
|
||||
|
@ -757,8 +757,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
|
||||
auto loop_body_emitter =
|
||||
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
|
||||
// Emit the loop body that reduces one tile.
|
||||
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
|
||||
input_shape.element_type(), &ir_builder_);
|
||||
llvm::Type* element_ir_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
|
||||
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
|
||||
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
|
||||
{
|
||||
@ -973,7 +973,7 @@ Status IrEmitterUnnested::EmitRowReduction(
|
||||
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
|
||||
// Emit the loop body that reduces one tile.
|
||||
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
|
||||
input_shape.element_type(), &ir_builder_);
|
||||
input_shape.element_type(), ir_emitter_context_->llvm_module());
|
||||
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
|
||||
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
|
||||
{
|
||||
@ -1360,7 +1360,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
// boolean flag if the value is initialized. The initialized_flag is set
|
||||
// false.
|
||||
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type,
|
||||
ir_emitter_context_->llvm_module()),
|
||||
"selected_value_address", &ir_builder_);
|
||||
llvm::Value* selected_index_address =
|
||||
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
||||
@ -1440,7 +1441,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
llvm::Value* operand_address =
|
||||
operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
|
||||
llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED,
|
||||
ir_emitter_context_->llvm_module()),
|
||||
"select_return_buffer", &ir_builder_);
|
||||
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
|
||||
*select_and_scatter->select(),
|
||||
@ -1450,8 +1452,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
||||
// If the 'select' function returns false, update the selected value and the
|
||||
// index to the currently visiting operand.
|
||||
llvm::Value* cond = ir_builder_.CreateICmpNE(
|
||||
result, llvm::ConstantInt::get(
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
|
||||
result,
|
||||
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
|
||||
PRED, ir_emitter_context_->llvm_module()),
|
||||
0),
|
||||
"boolean_predicate");
|
||||
llvm_ir::LlvmIfData if_select_lhs =
|
||||
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
|
||||
@ -1877,7 +1881,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
||||
}
|
||||
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_);
|
||||
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_,
|
||||
module_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -50,6 +50,12 @@ namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
struct is_complex_t : public std::false_type {};
|
||||
|
||||
template <>
|
||||
struct is_complex_t<complex64> : public std::true_type {};
|
||||
|
||||
template <typename OperandT>
|
||||
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
||||
const Literal& lhs_literal,
|
||||
@ -101,6 +107,37 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
|
||||
const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
|
||||
const Literal& rhs_literal) {
|
||||
std::function<bool(complex64, complex64)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||
return lhs_el == rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kNe:
|
||||
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||
return lhs_el != rhs_el;
|
||||
};
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
||||
<< HloOpcodeString(opcode);
|
||||
}
|
||||
|
||||
auto result = Literal::CreateFromShape(shape);
|
||||
TF_RETURN_IF_ERROR(result->Populate<bool>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||
return compare_op(lhs_literal.Get<complex64>(multi_index),
|
||||
rhs_literal.Get<complex64>(multi_index));
|
||||
}));
|
||||
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
template <typename ReturnT, typename NativeT>
|
||||
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
|
||||
HloInstruction* instruction,
|
||||
@ -138,7 +175,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
Status DefaultAction(HloInstruction* hlo_instruction) override {
|
||||
return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
|
||||
HloOpcodeString(hlo_instruction->opcode()).c_str());
|
||||
};
|
||||
}
|
||||
|
||||
// TODO(b/35950897): many of the stl functions used in the handlers are not
|
||||
// overloaded for every XLA primitive types.
|
||||
@ -156,7 +193,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
|
||||
typename std::enable_if<std::is_signed<NativeT>::value ||
|
||||
is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
|
||||
ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
|
||||
@ -169,7 +207,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleAbs<ReturnT>(abs, operand);
|
||||
}
|
||||
|
||||
Status HandleRound(HloInstruction* round) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleRound(HloInstruction* round) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[round],
|
||||
ElementWiseUnaryOp(round, [](ReturnT elem_operand) {
|
||||
return std::round(elem_operand);
|
||||
@ -177,6 +218,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleRound(HloInstruction* round) {
|
||||
return InvalidArgument("Unsupported type for Round");
|
||||
}
|
||||
|
||||
Status HandleRound(HloInstruction* round) override {
|
||||
return HandleRound<ReturnT>(round);
|
||||
}
|
||||
|
||||
Status HandleBroadcast(HloInstruction* broadcast) override {
|
||||
parent_->evaluated_[broadcast] =
|
||||
Literal::CreateFromShape(broadcast->shape());
|
||||
@ -205,15 +257,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleCeil(HloInstruction* ceil) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
|
||||
ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) {
|
||||
return std::ceil(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleCeil(HloInstruction* ceil) {
|
||||
return InvalidArgument("Unsupported type for Ceil");
|
||||
}
|
||||
|
||||
Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
|
||||
return HandleCeil<ReturnT>(ceil);
|
||||
}
|
||||
|
||||
Status HandleConvert(HloInstruction* convert) override {
|
||||
const HloInstruction* operand = convert->operand(0);
|
||||
@ -237,15 +303,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return std::exp(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleFloor(HloInstruction* floor) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor],
|
||||
ElementWiseUnaryOp(floor, [](ReturnT elem_operand) {
|
||||
return std::floor(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleFloor(HloInstruction* floor) {
|
||||
return InvalidArgument("Unsupported type for Floor");
|
||||
}
|
||||
|
||||
Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
|
||||
return HandleFloor<ReturnT>(floor);
|
||||
}
|
||||
|
||||
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
|
||||
@ -253,15 +333,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return std::log(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleNot(HloInstruction* not_, HloInstruction* operand) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleNot(HloInstruction* not_) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
|
||||
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
|
||||
return !elem_operand;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleNot(HloInstruction* not_) {
|
||||
return InvalidArgument("Unsupported type for Not");
|
||||
}
|
||||
|
||||
Status HandleNot(HloInstruction* not_, HloInstruction* operand) override {
|
||||
return HandleNot<ReturnT>(not_);
|
||||
}
|
||||
|
||||
Status HandleNegate(HloInstruction* negate,
|
||||
HloInstruction* operand) override {
|
||||
@ -270,16 +364,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return -elem_operand;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleSign(HloInstruction* sign) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
||||
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
|
||||
return (ReturnT(0) < elem_operand) -
|
||||
(elem_operand < ReturnT(0));
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleSign(HloInstruction* sign) {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
||||
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
|
||||
auto abs_val = std::abs(elem_operand);
|
||||
return 0 == abs_val ? ReturnT(0)
|
||||
: elem_operand / abs_val;
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
|
||||
return HandleSign<ReturnT>(sign);
|
||||
}
|
||||
|
||||
Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
|
||||
@ -287,7 +401,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return std::tanh(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
@ -297,7 +411,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return lhs_elem * rhs_elem;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
@ -307,7 +421,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return lhs_elem - rhs_elem;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
@ -317,7 +431,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return lhs_elem + rhs_elem;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
@ -327,25 +441,53 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return lhs_elem / rhs_elem;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleMaximum(HloInstruction* maximum) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleMaximum(HloInstruction* maximum) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[maximum],
|
||||
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
|
||||
return std::fmax(lhs, rhs);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleMinimum(HloInstruction* minimum) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleMaximum(HloInstruction* maximum) {
|
||||
return InvalidArgument("Unsupported type for Maximum");
|
||||
}
|
||||
|
||||
Status HandleMaximum(HloInstruction* maximum) override {
|
||||
return HandleMaximum<ReturnT>(maximum);
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleMinimum(HloInstruction* minimum) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[minimum],
|
||||
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||
return std::fmin(lhs_el, rhs_el);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleMinimum(HloInstruction* minimum) {
|
||||
return InvalidArgument("Unsupported type for Minimum");
|
||||
}
|
||||
|
||||
Status HandleMinimum(HloInstruction* minimum) override {
|
||||
return HandleMinimum<ReturnT>(minimum);
|
||||
}
|
||||
|
||||
Status HandlePower(HloInstruction* power, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
@ -355,37 +497,79 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return std::pow(lhs_el, rhs_el);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleRemainder(HloInstruction* remainder) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[remainder],
|
||||
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||
return std::fmod(lhs_el, rhs_el);
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleAnd(HloInstruction* and_, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleRemainder(HloInstruction* remainder) {
|
||||
return InvalidArgument("Unsupported type for Remainder");
|
||||
}
|
||||
|
||||
Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
return HandleRemainder<ReturnT>(remainder);
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleAnd(HloInstruction* and_) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[and_],
|
||||
ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||
return lhs_el && rhs_el;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleOr(HloInstruction* or_, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleAnd(HloInstruction* and_) {
|
||||
return InvalidArgument("Unsupported type for And");
|
||||
}
|
||||
|
||||
Status HandleAnd(HloInstruction* and_, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
return HandleAnd<ReturnT>(and_);
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleOr(HloInstruction* or_) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[or_],
|
||||
ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||
return lhs_el || rhs_el;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleOr(HloInstruction* or_) {
|
||||
return InvalidArgument("Unsupported type for Or");
|
||||
}
|
||||
|
||||
Status HandleOr(HloInstruction* or_, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
return HandleOr<ReturnT>(or_);
|
||||
}
|
||||
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
@ -474,8 +658,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs);
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||
HloInstruction* arg, HloInstruction* max) override {
|
||||
HloInstruction* arg, HloInstruction* max) {
|
||||
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
|
||||
[](ReturnT low, ReturnT high, ReturnT value) {
|
||||
return std::fmax(low, std::fmin(value, high));
|
||||
@ -483,7 +670,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
|
||||
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||
HloInstruction* arg, HloInstruction* max) {
|
||||
return InvalidArgument("Unsupported type for Clamp");
|
||||
}
|
||||
|
||||
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||
HloInstruction* arg, HloInstruction* max) override {
|
||||
return HandleClamp<ReturnT>(clamp, min, arg, max);
|
||||
}
|
||||
|
||||
Status HandleSelect(HloInstruction* select, HloInstruction* pred,
|
||||
HloInstruction* on_true,
|
||||
@ -499,7 +699,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
|
||||
ElementWiseTernaryOp(select, std::move(select_op)));
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleReverse(HloInstruction* reverse,
|
||||
HloInstruction* operand) override {
|
||||
@ -529,7 +729,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
parent_->evaluated_[reverse] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
|
||||
HloInstruction* rhs, const Window& window) override {
|
||||
@ -652,7 +852,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
parent_->evaluated_[conv] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override {
|
||||
@ -719,7 +919,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
parent_->evaluated_[dot] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandlePad(HloInstruction* pad) override {
|
||||
CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
|
||||
@ -788,7 +988,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
parent_->evaluated_[pad] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
|
||||
HloInstruction* operand,
|
||||
@ -841,7 +1041,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
||||
HloInstruction* operand,
|
||||
@ -897,7 +1097,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
||||
HloInstruction* init_value,
|
||||
@ -985,7 +1185,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
parent_->evaluated_[reduce] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleReduceWindow(HloInstruction* reduce_window,
|
||||
HloInstruction* operand, const Window& window,
|
||||
@ -1072,7 +1272,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
parent_->evaluated_[reduce_window] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override {
|
||||
const Shape& shape = slice->shape();
|
||||
@ -1101,7 +1301,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
|
||||
parent_->evaluated_[slice] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename IndexT>
|
||||
@ -1244,35 +1444,33 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||
}
|
||||
|
||||
HloEvaluator* parent_;
|
||||
}; // namespace xla
|
||||
}; // class HloEvaluator::TypedVisitor
|
||||
|
||||
HloEvaluator::HloEvaluator() {
|
||||
typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
|
||||
typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
|
||||
typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: U16.");
|
||||
return Unimplemented("HloEvaluator: unhandled primitive type: U16.");
|
||||
});
|
||||
typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
|
||||
typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
|
||||
typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
|
||||
typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: S16.");
|
||||
return Unimplemented("HloEvaluator: unhandled primitive type: S16.");
|
||||
});
|
||||
typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
|
||||
typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
|
||||
typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: F16.");
|
||||
return Unimplemented("HloEvaluator: unhandled primitive type: F16.");
|
||||
});
|
||||
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
|
||||
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
|
||||
typed_visitors_[C64] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: C64.");
|
||||
});
|
||||
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
|
||||
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: TUPLE.");
|
||||
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
|
||||
});
|
||||
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: OPAQUE.");
|
||||
return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE.");
|
||||
});
|
||||
}
|
||||
|
||||
@ -1573,6 +1771,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
|
||||
evaluated_[compare],
|
||||
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case C64: {
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<complex64>(compare->shape(), opcode,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
default:
|
||||
LOG(FATAL) << "HandleCompare: unknown primitive type: "
|
||||
<< PrimitiveType_Name(lhs->shape().element_type());
|
||||
|
@ -826,8 +826,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kAbs:
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kDivide:
|
||||
@ -836,6 +838,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIndex:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
@ -850,6 +853,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
|
@ -219,10 +219,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSort:
|
||||
@ -241,26 +243,28 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
||||
// Only certain opcodes are supported with CreateBinary: opcodes of binary
|
||||
// instructions with no auxiliary fields.
|
||||
switch (opcode) {
|
||||
case (HloOpcode::kAdd):
|
||||
case (HloOpcode::kDivide):
|
||||
case (HloOpcode::kDot):
|
||||
case (HloOpcode::kEq):
|
||||
case (HloOpcode::kGe):
|
||||
case (HloOpcode::kGt):
|
||||
case (HloOpcode::kLe):
|
||||
case (HloOpcode::kLt):
|
||||
case (HloOpcode::kMaximum):
|
||||
case (HloOpcode::kMinimum):
|
||||
case (HloOpcode::kMultiply):
|
||||
case (HloOpcode::kNe):
|
||||
case (HloOpcode::kPower):
|
||||
case (HloOpcode::kRemainder):
|
||||
case (HloOpcode::kSubtract):
|
||||
case (HloOpcode::kAnd):
|
||||
case (HloOpcode::kOr):
|
||||
case (HloOpcode::kShiftLeft):
|
||||
case (HloOpcode::kShiftRightArithmetic):
|
||||
case (HloOpcode::kShiftRightLogical):
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDot:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kAnd:
|
||||
case HloOpcode::kOr:
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
case HloOpcode::kShiftRightLogical:
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Invalid binary instruction opcode "
|
||||
@ -978,11 +982,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSort:
|
||||
@ -992,6 +998,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
break;
|
||||
// Binary ops.
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kSubtract:
|
||||
@ -1403,10 +1411,12 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
// The result of these instructions only depend upon their opcode and
|
||||
// operands.
|
||||
case HloOpcode::kAbs:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kCrossReplicaSum:
|
||||
@ -1417,6 +1427,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLog:
|
||||
@ -1430,6 +1441,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kSelect:
|
||||
case HloOpcode::kShiftLeft:
|
||||
@ -2117,6 +2129,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
||||
switch (opcode_) {
|
||||
case HloOpcode::kAbs:
|
||||
return visitor->HandleAbs(this, operands_[0]);
|
||||
case HloOpcode::kAtan2:
|
||||
return visitor->HandleAtan2(this, operands_[0], operands_[1]);
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
return visitor->HandleRound(this);
|
||||
case HloOpcode::kBatchNormTraining:
|
||||
@ -2140,6 +2154,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kNe:
|
||||
return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]);
|
||||
case HloOpcode::kComplex:
|
||||
return visitor->HandleComplex(this, operands_[0], operands_[1]);
|
||||
case HloOpcode::kAdd:
|
||||
return visitor->HandleAdd(this, operands_[0], operands_[1]);
|
||||
case HloOpcode::kDivide:
|
||||
@ -2214,6 +2230,10 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
||||
return visitor->HandleCos(this, operands_[0]);
|
||||
case HloOpcode::kSin:
|
||||
return visitor->HandleSin(this, operands_[0]);
|
||||
case HloOpcode::kReal:
|
||||
return visitor->HandleReal(this, operands_[0]);
|
||||
case HloOpcode::kImag:
|
||||
return visitor->HandleImag(this, operands_[0]);
|
||||
case HloOpcode::kIsFinite:
|
||||
return visitor->HandleIsFinite(this, operands_[0]);
|
||||
case HloOpcode::kNot:
|
||||
@ -2305,7 +2325,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
|
||||
//
|
||||
// We need to keep track of both the id and the instruction because
|
||||
// instructions can get deleted while they are on the stack, so we
|
||||
// can't always use the (potentiall dead) instruction object to grab
|
||||
// can't always use the (potentially dead) instruction object to grab
|
||||
// its id.
|
||||
DFSStack dfs_stack;
|
||||
dfs_stack.emplace_back(root->unique_id(), root);
|
||||
@ -2505,6 +2525,7 @@ bool HloInstruction::IsElementwiseBinary() const {
|
||||
// Binary elementwise operations. If you update this, please update
|
||||
// IsElementwise() accordingly.
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
@ -2537,6 +2558,7 @@ bool HloInstruction::IsElementwise() const {
|
||||
|
||||
// Unary elementwise operations.
|
||||
case HloOpcode::kAbs:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kConvert:
|
||||
@ -2544,10 +2566,12 @@ bool HloInstruction::IsElementwise() const {
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
@ -2557,6 +2581,7 @@ bool HloInstruction::IsElementwise() const {
|
||||
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
||||
// If you update this, please update IsElementwiseBinary() accordingly.
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
|
@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
||||
return "abs";
|
||||
case HloOpcode::kAdd:
|
||||
return "add";
|
||||
case HloOpcode::kAtan2:
|
||||
return "atan2";
|
||||
case HloOpcode::kBatchNormTraining:
|
||||
return "batch-norm-training";
|
||||
case HloOpcode::kBatchNormInference:
|
||||
@ -47,6 +49,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
||||
return "call";
|
||||
case HloOpcode::kClamp:
|
||||
return "clamp";
|
||||
case HloOpcode::kComplex:
|
||||
return "complex";
|
||||
case HloOpcode::kConcatenate:
|
||||
return "concatenate";
|
||||
case HloOpcode::kConstant:
|
||||
@ -87,6 +91,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
||||
return "get-tuple-element";
|
||||
case HloOpcode::kGt:
|
||||
return "greater-than";
|
||||
case HloOpcode::kImag:
|
||||
return "imag";
|
||||
case HloOpcode::kIndex:
|
||||
return "index";
|
||||
case HloOpcode::kInfeed:
|
||||
@ -125,6 +131,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
||||
return "parameter";
|
||||
case HloOpcode::kPower:
|
||||
return "power";
|
||||
case HloOpcode::kReal:
|
||||
return "real";
|
||||
case HloOpcode::kRecv:
|
||||
return "recv";
|
||||
case HloOpcode::kReduce:
|
||||
|
@ -31,6 +31,7 @@ namespace xla {
|
||||
enum class HloOpcode {
|
||||
kAbs,
|
||||
kAdd,
|
||||
kAtan2,
|
||||
kBatchNormGrad,
|
||||
kBatchNormInference,
|
||||
kBatchNormTraining,
|
||||
@ -39,6 +40,7 @@ enum class HloOpcode {
|
||||
kCall,
|
||||
kCeil,
|
||||
kClamp,
|
||||
kComplex,
|
||||
kConcatenate,
|
||||
kConstant,
|
||||
kConvert,
|
||||
@ -58,6 +60,7 @@ enum class HloOpcode {
|
||||
kGe,
|
||||
kGetTupleElement,
|
||||
kGt,
|
||||
kImag,
|
||||
kIndex,
|
||||
kInfeed,
|
||||
kIsFinite,
|
||||
@ -77,6 +80,7 @@ enum class HloOpcode {
|
||||
kPad,
|
||||
kParameter,
|
||||
kPower,
|
||||
kReal,
|
||||
kRecv,
|
||||
kReduce,
|
||||
kReducePrecision,
|
||||
|
@ -59,6 +59,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
for (auto& invariant_checker : invariant_checkers_) {
|
||||
VLOG(1) << " Invariant checker " << invariant_checker->name();
|
||||
StatusOr<bool> changed_status = invariant_checker->Run(module);
|
||||
VLOG(1) << " Invariant checker done " << invariant_checker->name();
|
||||
if (!changed_status.ok()) {
|
||||
VLOG(2) << "Module failed invariant check:";
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
|
@ -64,6 +64,10 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
}
|
||||
|
||||
Status HandleConvert(HloInstruction* convert) override {
|
||||
if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) {
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape()))
|
||||
<< "Unsupported complex->real kConvert";
|
||||
}
|
||||
return CheckShape(convert, ShapeInference::InferConvertShape(
|
||||
convert->operand(0)->shape(),
|
||||
convert->shape().element_type()));
|
||||
|
@ -32,17 +32,16 @@ namespace xla {
|
||||
const HloInstruction& instruction) {
|
||||
switch (instruction.opcode()) {
|
||||
// Cheap instructions.
|
||||
case HloOpcode::kAbs:
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kEq:
|
||||
@ -50,6 +49,7 @@ namespace xla {
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLe:
|
||||
@ -64,6 +64,7 @@ namespace xla {
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kReverse:
|
||||
@ -72,15 +73,21 @@ namespace xla {
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
case HloOpcode::kShiftRightLogical:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSlice:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kTuple:
|
||||
return false;
|
||||
|
||||
// Cheap instructions for reals, but expensive for complex.
|
||||
case HloOpcode::kAbs:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
return ShapeUtil::ElementIsComplex(instruction.shape());
|
||||
|
||||
// Expensive instructions.
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kBatchNormTraining:
|
||||
case HloOpcode::kBatchNormInference:
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
|
@ -75,7 +75,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
|
||||
Status FusedIrEmitter::HandleConstant(HloInstruction* constant,
|
||||
const Literal& literal) {
|
||||
llvm::Constant* initializer =
|
||||
llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_);
|
||||
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
|
||||
llvm::GlobalVariable* global = new llvm::GlobalVariable(
|
||||
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
|
||||
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
|
||||
@ -101,7 +101,7 @@ Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
||||
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
|
||||
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
|
||||
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
||||
/*alignment=*/1, it->second, ir_builder_);
|
||||
/*alignment=*/1, it->second, ir_builder_, module_);
|
||||
gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
|
||||
// Emit code to read base tuple element array (if non-tuple shaped).
|
||||
if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
|
||||
@ -134,7 +134,7 @@ Status FusedIrEmitter::HandleTuple(
|
||||
std::vector<llvm::Type*> operand_elemental_ir_types;
|
||||
for (HloInstruction* operand : operands) {
|
||||
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
|
||||
operand->shape().element_type(), ir_builder_));
|
||||
operand->shape().element_type(), module_));
|
||||
}
|
||||
generators_[tuple] =
|
||||
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
|
@ -42,7 +42,8 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
||||
ElementalIrEmitter* elemental_emitter)
|
||||
: parameter_arrays_(parameter_arrays),
|
||||
elemental_emitter_(elemental_emitter),
|
||||
ir_builder_(elemental_emitter->ir_builder()) {}
|
||||
ir_builder_(elemental_emitter->ir_builder()),
|
||||
module_(elemental_emitter->module()) {}
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
|
||||
@ -85,6 +86,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
||||
|
||||
// Borrowed
|
||||
llvm::IRBuilder<>* ir_builder_;
|
||||
llvm::Module* module_;
|
||||
|
||||
// Map from instruction pointers to functions to generate elements of their
|
||||
// outputs
|
||||
|
@ -229,9 +229,11 @@ llvm::Value* IrArray::EmitArrayElementAddress(
|
||||
}
|
||||
|
||||
if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
|
||||
llvm::Module* module =
|
||||
ir_builder->GetInsertBlock()->getParent()->getParent();
|
||||
return ir_builder->CreateInBoundsGEP(
|
||||
ir_builder->CreateBitCast(
|
||||
base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder)
|
||||
base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module)
|
||||
->getPointerTo()),
|
||||
{index.linear()}, llvm_ir::AsStringRef(name));
|
||||
}
|
||||
@ -281,7 +283,8 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
|
||||
|
||||
IrArray IrArray::CastToShape(const Shape& new_shape,
|
||||
llvm::IRBuilder<>* ir_builder) const {
|
||||
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder);
|
||||
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
|
||||
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
|
||||
return IrArray(
|
||||
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
|
||||
new_shape);
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/MDBuilder.h"
|
||||
#include "llvm/IR/Operator.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
@ -38,6 +39,19 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace llvm_ir {
|
||||
|
||||
namespace {
|
||||
|
||||
// Note, this function is only useful in an insertion context; in a global
|
||||
// (e.g. constants) context it will CHECK fail.
|
||||
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) {
|
||||
auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock());
|
||||
auto fn = CHECK_NOTNULL(block->getParent());
|
||||
auto module = CHECK_NOTNULL(fn->getParent());
|
||||
return module;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
string AsString(const std::string& str) {
|
||||
return string(str.data(), str.length());
|
||||
}
|
||||
@ -63,7 +77,7 @@ llvm::Value* EmitCallToIntrinsic(
|
||||
for (auto type : overloaded_types) {
|
||||
types.push_back(type);
|
||||
}
|
||||
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
|
||||
llvm::Module* module = ModuleFromIRBuilder(ir_builder);
|
||||
llvm::Function* intrinsic =
|
||||
llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
|
||||
std::vector<llvm::Value*> operands_vec;
|
||||
@ -119,38 +133,53 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
|
||||
}
|
||||
|
||||
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
||||
llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::Module* module) {
|
||||
switch (element_type) {
|
||||
case PRED:
|
||||
case S8:
|
||||
case U8:
|
||||
return ir_builder->getInt8Ty();
|
||||
return llvm::Type::getInt8Ty(module->getContext());
|
||||
case S16:
|
||||
case U16:
|
||||
return ir_builder->getInt16Ty();
|
||||
return llvm::Type::getInt16Ty(module->getContext());
|
||||
case S32:
|
||||
case U32:
|
||||
return ir_builder->getInt32Ty();
|
||||
return llvm::Type::getInt32Ty(module->getContext());
|
||||
case S64:
|
||||
case U64:
|
||||
return ir_builder->getInt64Ty();
|
||||
return llvm::Type::getInt64Ty(module->getContext());
|
||||
case F32:
|
||||
return ir_builder->getFloatTy();
|
||||
return llvm::Type::getFloatTy(module->getContext());
|
||||
case F64:
|
||||
return ir_builder->getDoubleTy();
|
||||
return llvm::Type::getDoubleTy(module->getContext());
|
||||
case C64: {
|
||||
auto cplx_t = module->getTypeByName("complex64");
|
||||
if (cplx_t == nullptr) {
|
||||
// C++ standard dictates the memory layout of std::complex is contiguous
|
||||
// real followed by imaginary. C++11 section 26.4 [complex.numbers]:
|
||||
// If z is an lvalue expression of type cv std::complex<T> then the
|
||||
// expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
|
||||
// reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
|
||||
// z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
|
||||
// imaginary part of z.
|
||||
return llvm::StructType::create(
|
||||
"complex64", llvm::Type::getFloatTy(module->getContext()),
|
||||
llvm::Type::getFloatTy(module->getContext()));
|
||||
}
|
||||
return cplx_t;
|
||||
}
|
||||
// A Tuple contains an array of pointers. Use i8*.
|
||||
case TUPLE:
|
||||
// An Opaque is like a void*, use i8*.
|
||||
case OPAQUE:
|
||||
return ir_builder->getInt8PtrTy();
|
||||
return llvm::Type::getInt8PtrTy(module->getContext());
|
||||
default:
|
||||
LOG(FATAL) << "unsupported type " << element_type;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::Type* result_type =
|
||||
PrimitiveTypeToIrType(shape.element_type(), ir_builder);
|
||||
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
|
||||
llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
|
||||
if (ShapeUtil::IsTuple(shape)) {
|
||||
// A tuple buffer is an array of pointers.
|
||||
result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
|
||||
@ -197,10 +226,10 @@ namespace {
|
||||
// value down to zero).
|
||||
llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
||||
std::vector<int64>* multi_index,
|
||||
llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::Module* module) {
|
||||
const Shape& shape = literal.shape();
|
||||
llvm::Type* ir_element_type =
|
||||
llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder);
|
||||
llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module);
|
||||
if (dimension_index == -1) {
|
||||
// Base case of the recursion. Index into the data field of the protobuf
|
||||
// with the multi index.
|
||||
@ -238,6 +267,16 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
||||
value = llvm::ConstantFP::get(ir_element_type,
|
||||
literal.Get<double>(*multi_index));
|
||||
break;
|
||||
case C64: {
|
||||
complex64 x = literal.Get<complex64>(*multi_index);
|
||||
value = llvm::ConstantStruct::get(
|
||||
static_cast<llvm::StructType*>(ir_element_type),
|
||||
llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
|
||||
x.real()),
|
||||
llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
|
||||
x.imag()));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "unsupported type " << shape.element_type();
|
||||
}
|
||||
@ -256,8 +295,8 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
||||
std::vector<llvm::Constant*> elements;
|
||||
for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
|
||||
(*multi_index)[dimension] = i;
|
||||
elements.push_back(LiteralToConstant(literal, dimension_index - 1,
|
||||
multi_index, ir_builder));
|
||||
elements.push_back(
|
||||
LiteralToConstant(literal, dimension_index - 1, multi_index, module));
|
||||
}
|
||||
|
||||
llvm::Type* element_type;
|
||||
@ -279,11 +318,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
||||
} // namespace
|
||||
|
||||
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
||||
llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::Module* module) {
|
||||
std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
|
||||
llvm::Constant* value = LiteralToConstant(
|
||||
literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
|
||||
&multi_index, ir_builder);
|
||||
&multi_index, module);
|
||||
return value;
|
||||
}
|
||||
|
||||
@ -380,7 +419,8 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
|
||||
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
|
||||
// arrays. So we extend it to i8 so that it's addressable.
|
||||
return ir_builder->CreateZExt(
|
||||
comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder));
|
||||
comparison_result,
|
||||
llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder)));
|
||||
}
|
||||
|
||||
// Internal helper that is called from emitted code to log an int64 value with a
|
||||
|
@ -127,11 +127,11 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
|
||||
|
||||
// Returns the LLVM type which represents the given XLA primitive type.
|
||||
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
||||
llvm::IRBuilder<>* ir_builder);
|
||||
llvm::Module* module);
|
||||
|
||||
// Returns the LLVM type which represents the given XLA shape. For example,
|
||||
// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
|
||||
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder);
|
||||
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
|
||||
|
||||
// Returns a value that represents a pointer to a global string constant that
|
||||
// encodes the shape as a serialized protobuf.
|
||||
@ -149,7 +149,7 @@ StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
|
||||
// Converts a given literal to an IR Constant. Literals have known constant
|
||||
// values at IR emission time.
|
||||
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
||||
llvm::IRBuilder<>* ir_builder);
|
||||
llvm::Module* module);
|
||||
|
||||
// Inserts an allocate of the requested type at the entry point of the
|
||||
// function that the builder is currently building. The insert point
|
||||
|
@ -31,14 +31,15 @@ namespace xla {
|
||||
namespace llvm_ir {
|
||||
|
||||
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
||||
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
|
||||
llvm::Module* module) {
|
||||
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
|
||||
|
||||
llvm::LoadInst* pred_value =
|
||||
ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
|
||||
llvm::Value* pred_cond = ir_builder->CreateICmpNE(
|
||||
pred_value,
|
||||
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0),
|
||||
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0),
|
||||
"boolean_predicate");
|
||||
|
||||
VLOG(2) << "HandleSelect for tuple:";
|
||||
@ -71,11 +72,11 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
||||
|
||||
void EmitTuple(IrArray tuple,
|
||||
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
||||
llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
|
||||
for (size_t i = 0; i < operands.size(); ++i) {
|
||||
auto* store = ir_builder->CreateStore(
|
||||
ir_builder->CreatePointerCast(operands[i],
|
||||
PrimitiveTypeToIrType(TUPLE, ir_builder)),
|
||||
PrimitiveTypeToIrType(TUPLE, module)),
|
||||
ir_builder->CreateInBoundsGEP(
|
||||
tuple.GetBasePointer(),
|
||||
{ir_builder->getInt64(0), ir_builder->getInt64(i)}));
|
||||
@ -85,7 +86,8 @@ void EmitTuple(IrArray tuple,
|
||||
|
||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||
int alignment, llvm::Value* operand,
|
||||
llvm::IRBuilder<>* ir_builder) {
|
||||
llvm::IRBuilder<>* ir_builder,
|
||||
llvm::Module* module) {
|
||||
llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
|
||||
operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
|
||||
llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
|
||||
@ -98,7 +100,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||
}
|
||||
SetAlignmentMetadataForLoad(src_buffer, alignment);
|
||||
|
||||
llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder);
|
||||
llvm::Type* element_type = ShapeToIrType(target_shape, module);
|
||||
llvm::Value* ret_val =
|
||||
ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
|
||||
return ret_val;
|
||||
|
@ -60,13 +60,14 @@ namespace llvm_ir {
|
||||
// tuple_on_true or tuple_on_false:
|
||||
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
|
||||
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
||||
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder);
|
||||
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
|
||||
llvm::Module* module);
|
||||
|
||||
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
||||
// the output buffer of its corresponding operand.
|
||||
void EmitTuple(IrArray tuple,
|
||||
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
||||
llvm::IRBuilder<>* ir_builder);
|
||||
llvm::IRBuilder<>* ir_builder, llvm::Module* module);
|
||||
|
||||
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
||||
// the output buffer of its corresponding operand. A GetTupleElement instruction
|
||||
@ -74,7 +75,8 @@ void EmitTuple(IrArray tuple,
|
||||
// Returns an llvm value representing a pointer to the tuple element buffer.
|
||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||
int alignment, llvm::Value* operand,
|
||||
llvm::IRBuilder<>* ir_builder);
|
||||
llvm::IRBuilder<>* ir_builder,
|
||||
llvm::Module* module);
|
||||
} // namespace llvm_ir
|
||||
} // namespace xla
|
||||
|
||||
|
@ -53,6 +53,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
|
||||
return UNOP_EXP;
|
||||
case HloOpcode::kFloor:
|
||||
return UNOP_FLOOR;
|
||||
case HloOpcode::kImag:
|
||||
return UNOP_IMAG;
|
||||
case HloOpcode::kIsFinite:
|
||||
return UNOP_IS_FINITE;
|
||||
case HloOpcode::kLog:
|
||||
@ -61,6 +63,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
|
||||
return UNOP_NOT;
|
||||
case HloOpcode::kNegate:
|
||||
return UNOP_NEGATE;
|
||||
case HloOpcode::kReal:
|
||||
return UNOP_REAL;
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
return UNOP_ROUND_NEAREST_AFZ;
|
||||
case HloOpcode::kSign:
|
||||
@ -81,6 +85,10 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
|
||||
// opcode.
|
||||
BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
|
||||
switch (opcode) {
|
||||
case HloOpcode::kAtan2:
|
||||
return BINOP_ATAN2;
|
||||
case HloOpcode::kComplex:
|
||||
return BINOP_COMPLEX;
|
||||
case HloOpcode::kDot:
|
||||
return BINOP_DOT;
|
||||
case HloOpcode::kMultiply:
|
||||
@ -307,19 +315,41 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
switch (operation) {
|
||||
case UNOP_FLOOR:
|
||||
case UNOP_CEIL:
|
||||
if (!ShapeUtil::ElementIsFloating(arg)) {
|
||||
return InvalidArgument(
|
||||
"expected element type in shape to be floating for floor/ceil "
|
||||
"operation; got %s",
|
||||
PrimitiveType_Name(arg.element_type()).c_str());
|
||||
}
|
||||
return arg;
|
||||
case UNOP_COS:
|
||||
case UNOP_SIN:
|
||||
case UNOP_EXP:
|
||||
case UNOP_LOG:
|
||||
case UNOP_TANH:
|
||||
if (!ShapeUtil::ElementIsFloating(arg)) {
|
||||
if (!ShapeUtil::ElementIsFloating(arg) &&
|
||||
!ShapeUtil::ElementIsComplex(arg)) {
|
||||
return InvalidArgument(
|
||||
"expected element type in shape to be floating for exp/log/tanh "
|
||||
"operation; got %s",
|
||||
"expected element type in shape to be floating or complex for "
|
||||
"sin/cos/exp/log/tanh operation; got %s",
|
||||
PrimitiveType_Name(arg.element_type()).c_str());
|
||||
}
|
||||
return arg;
|
||||
case UNOP_REAL:
|
||||
case UNOP_IMAG:
|
||||
if (!ShapeUtil::ElementIsComplex(arg)) {
|
||||
return InvalidArgument(
|
||||
"expected element type in shape to be complex for real/imag "
|
||||
"operation; got %s",
|
||||
PrimitiveType_Name(arg.element_type()).c_str());
|
||||
}
|
||||
return ShapeUtil::ChangeElementType(arg, F32);
|
||||
case UNOP_ABS:
|
||||
if (ShapeUtil::ElementIsComplex(arg)) {
|
||||
return ShapeUtil::ChangeElementType(
|
||||
arg, primitive_util::ComplexComponentType(arg.element_type()));
|
||||
}
|
||||
return arg;
|
||||
case UNOP_NEGATE:
|
||||
case UNOP_ROUND_NEAREST_AFZ:
|
||||
case UNOP_SIGN:
|
||||
@ -751,6 +781,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
case BINOP_MIN:
|
||||
case BINOP_SUB:
|
||||
case BINOP_ADD:
|
||||
case BINOP_ATAN2:
|
||||
case BINOP_POW:
|
||||
case BINOP_DIV:
|
||||
case BINOP_REM:
|
||||
@ -761,6 +792,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
|
||||
broadcast_dimensions);
|
||||
|
||||
case BINOP_COMPLEX: {
|
||||
if (!ShapeUtil::ElementIsFloating(lhs)) {
|
||||
return InvalidArgument(
|
||||
"expected element type in shape to be floating for complex compose "
|
||||
"operation; got %s",
|
||||
PrimitiveType_Name(lhs.element_type()).c_str());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
||||
InferElementwiseBinaryOpShape(operation, lhs, rhs,
|
||||
broadcast_dimensions));
|
||||
if (lhs.element_type() == F32) {
|
||||
return ShapeUtil::ChangeElementType(shape, C64);
|
||||
} else {
|
||||
return Unimplemented("complex component type not supported");
|
||||
}
|
||||
}
|
||||
case BINOP_AND:
|
||||
case BINOP_OR:
|
||||
if (lhs.element_type() != PRED &&
|
||||
|
@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test {
|
||||
// Some handy scalar shapes.
|
||||
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
|
||||
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
|
||||
const Shape f64_ = ShapeUtil::MakeShape(F64, {});
|
||||
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
|
||||
|
||||
// Some handy vector and matrix shapes of F32 type.
|
||||
@ -251,6 +252,44 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
|
||||
.ok());
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, Complex) {
|
||||
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
|
||||
const tensorflow::gtl::ArraySlice<int64>& bcast) {
|
||||
return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
|
||||
lhs, rhs, bcast);
|
||||
};
|
||||
// Inputs must be FP.
|
||||
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
|
||||
ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
|
||||
// Component types must match.
|
||||
ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
|
||||
// Only F32->C64 supported.
|
||||
ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
|
||||
// Validate correct uses.
|
||||
Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
|
||||
|
||||
Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
|
||||
TF_ASSERT_OK_AND_ASSIGN(result,
|
||||
complex_shape(vector_64_, matrix_32_64_, {1}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result,
|
||||
complex_shape(matrix_32_64_, vector_64_, {1}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result,
|
||||
complex_shape(matrix_32_64_, matrix_32_64_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
||||
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
|
||||
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
|
||||
|
@ -55,6 +55,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
|
||||
return HloOpcode::kExp;
|
||||
case UNOP_FLOOR:
|
||||
return HloOpcode::kFloor;
|
||||
case UNOP_IMAG:
|
||||
return HloOpcode::kImag;
|
||||
case UNOP_IS_FINITE:
|
||||
return HloOpcode::kIsFinite;
|
||||
case UNOP_LOG:
|
||||
@ -63,6 +65,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
|
||||
return HloOpcode::kNot;
|
||||
case UNOP_NEGATE:
|
||||
return HloOpcode::kNegate;
|
||||
case UNOP_REAL:
|
||||
return HloOpcode::kReal;
|
||||
case UNOP_ROUND_NEAREST_AFZ:
|
||||
return HloOpcode::kRoundNearestAfz;
|
||||
case UNOP_SIGN:
|
||||
@ -80,6 +84,10 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
|
||||
|
||||
HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
|
||||
switch (binop) {
|
||||
case BINOP_ATAN2:
|
||||
return HloOpcode::kAtan2;
|
||||
case BINOP_COMPLEX:
|
||||
return HloOpcode::kComplex;
|
||||
case BINOP_DOT:
|
||||
return HloOpcode::kDot;
|
||||
case BINOP_MUL:
|
||||
|
@ -272,6 +272,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
case U16:
|
||||
case U32:
|
||||
case U64:
|
||||
case C64:
|
||||
case TUPLE:
|
||||
case OPAQUE:
|
||||
return false;
|
||||
|
@ -361,8 +361,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
|
||||
ComputationBuilder* builder, const Array2D<NativeT>& expected,
|
||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
||||
static_assert(std::is_same<NativeT, float>::value ||
|
||||
std::is_same<NativeT, double>::value,
|
||||
"Floating point type required when specifying an ErrorSpec");
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
std::unique_ptr<Literal> expected_literal =
|
||||
Literal::CreateR2FromArray2D<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
||||
@ -384,8 +385,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
|
||||
ComputationBuilder* builder, const Array3D<NativeT>& expected,
|
||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
||||
static_assert(std::is_same<NativeT, float>::value ||
|
||||
std::is_same<NativeT, double>::value,
|
||||
"Floating point type required when specifying an ErrorSpec");
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
std::unique_ptr<Literal> expected_literal =
|
||||
Literal::CreateR3FromArray3D<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
||||
@ -407,8 +409,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
|
||||
ComputationBuilder* builder, const Array4D<NativeT>& expected,
|
||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
||||
static_assert(std::is_same<NativeT, float>::value ||
|
||||
std::is_same<NativeT, double>::value,
|
||||
"Floating point type required when specifying an ErrorSpec");
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
std::unique_ptr<Literal> expected_literal =
|
||||
Literal::CreateR4FromArray4D<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
||||
|
@ -347,7 +347,7 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
|
||||
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
|
||||
}
|
||||
|
||||
TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
|
||||
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
|
||||
constexpr bool kLhsRowMajor = true;
|
||||
constexpr bool kRhsRowMajor = true;
|
||||
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
|
||||
@ -357,7 +357,11 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
|
||||
TestNonsquareMatrixDot<double>();
|
||||
}
|
||||
|
||||
TEST_F(DotOperationTest, ConcurrentMatMul) {
|
||||
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
|
||||
TestNonsquareMatrixDot<complex64>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
|
||||
|
@ -41,7 +41,11 @@ class UnaryOpTest : public ClientLibraryTestBase {
|
||||
auto arg = builder.ConstantR1<T>({});
|
||||
auto abs = builder.Abs(arg);
|
||||
|
||||
ComputeAndCompareR1<T>(&builder, {}, {});
|
||||
if (primitive_util::NativeToPrimitiveType<T>() == C64) {
|
||||
ComputeAndCompareR1<float>(&builder, {}, {});
|
||||
} else {
|
||||
ComputeAndCompareR1<T>(&builder, {}, {});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -80,14 +84,58 @@ int UnaryOpTest::inf<int>() {
|
||||
return 2147483647;
|
||||
}
|
||||
|
||||
template <>
|
||||
void UnaryOpTest::AbsTestHelper<complex64>() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto arg = builder.ConstantR1<complex64>({{-2, 0},
|
||||
{0, 25},
|
||||
{0, 0},
|
||||
{-0.3f, 0.4f},
|
||||
{0, inf<float>()},
|
||||
{-inf<float>(), 0}});
|
||||
auto abs = builder.Abs(arg);
|
||||
|
||||
std::unique_ptr<Literal> expected =
|
||||
Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
|
||||
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||
}
|
||||
|
||||
template <>
|
||||
void UnaryOpTest::SignTestHelper<complex64>() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto arg = builder.ConstantR1<complex64>(
|
||||
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
|
||||
auto sign = builder.Sign(arg);
|
||||
|
||||
std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
|
||||
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
|
||||
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||
}
|
||||
|
||||
template <>
|
||||
void UnaryOpTest::SignAbsTestHelper<complex64>() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto arg =
|
||||
builder.ConstantR1<complex64>({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
|
||||
auto sign = builder.Sign(arg);
|
||||
auto abs = builder.Abs(arg);
|
||||
builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg);
|
||||
|
||||
std::unique_ptr<Literal> expected =
|
||||
Literal::CreateR1<complex64>({0, 0, 0, 0});
|
||||
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
|
||||
AbsSize0TestHelper<int>();
|
||||
AbsSize0TestHelper<float>();
|
||||
AbsSize0TestHelper<complex64>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, AbsTestR1) {
|
||||
AbsTestHelper<int>();
|
||||
AbsTestHelper<float>();
|
||||
AbsTestHelper<complex64>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
|
||||
@ -98,34 +146,44 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) {
|
||||
auto absf = builder.Abs(argf);
|
||||
auto argf0 = builder.ConstantR0<float>(-0.0f);
|
||||
auto absf0 = builder.Abs(argf0);
|
||||
builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
|
||||
absi, PrimitiveType::F32)));
|
||||
auto argc = builder.ConstantR0<complex64>({-0.3f, 0.4f});
|
||||
auto absc = builder.Abs(argc);
|
||||
builder.Add(builder.Add(absc, absf0),
|
||||
builder.Add(absf, builder.ConvertElementType(absi, F32)));
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, 8.0f, {});
|
||||
ComputeAndCompareR0<float>(&builder, 8.5f, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, SignTestR0) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto argi = builder.ConstantR0<int>(-5);
|
||||
auto absi = builder.Sign(argi);
|
||||
auto sgni = builder.Sign(argi); // -1
|
||||
auto argf = builder.ConstantR0<float>(-4.0f);
|
||||
auto absf = builder.Sign(argf);
|
||||
auto sgnf = builder.Sign(argf); // -1
|
||||
auto argf0 = builder.ConstantR0<float>(-0.0f);
|
||||
auto absf0 = builder.Sign(argf0);
|
||||
builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
|
||||
absi, PrimitiveType::F32)));
|
||||
auto sgnf0 = builder.Sign(argf0); // 0
|
||||
auto argc = builder.ConstantR0<complex64>({-.3, .4});
|
||||
auto sgnc = builder.Sign(argc); // (-.6, .8)
|
||||
builder.Add(sgnc, builder.ConvertElementType(
|
||||
builder.Add(builder.Add(sgnf0, sgnf),
|
||||
builder.ConvertElementType(sgni, F32)),
|
||||
C64));
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, -2.0f, {});
|
||||
std::unique_ptr<Literal> expected =
|
||||
Literal::CreateR0<complex64>({-2.6f, 0.8f});
|
||||
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, SignTestR1) {
|
||||
SignTestHelper<int>();
|
||||
SignTestHelper<float>();
|
||||
SignTestHelper<complex64>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
|
||||
SignAbsTestHelper<int>();
|
||||
SignAbsTestHelper<float>();
|
||||
SignAbsTestHelper<complex64>();
|
||||
}
|
||||
|
||||
XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
|
||||
|
@ -235,11 +235,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kImag:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSort:
|
||||
@ -256,6 +258,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kAtan2:
|
||||
case HloOpcode::kComplex:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TYPES_H_
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -35,7 +37,7 @@ using ::tensorflow::uint16;
|
||||
using ::tensorflow::uint32;
|
||||
using ::tensorflow::uint64;
|
||||
|
||||
typedef std::complex<float> complex64;
|
||||
using complex64 = std::complex<float>;
|
||||
|
||||
using ::Eigen::half;
|
||||
|
||||
|
@ -49,7 +49,7 @@ enum PrimitiveType {
|
||||
F64 = 12;
|
||||
|
||||
// Complex values of fixed width.
|
||||
C64 = 15;
|
||||
C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
|
||||
|
||||
// A tuple is a polymorphic sequence; e.g. a shape that holds different
|
||||
// sub-shapes. They are used for things like returning multiple values from a
|
||||
@ -667,6 +667,12 @@ enum UnaryOperation {
|
||||
// Elementwise, rounds x to nearest integral value, rounding half-way cases
|
||||
// away from zero.
|
||||
UNOP_ROUND_NEAREST_AFZ = 14;
|
||||
|
||||
// Elementwise, extract real component of complex x.
|
||||
UNOP_REAL = 15;
|
||||
|
||||
// Elementwise, extract real component of complex x.
|
||||
UNOP_IMAG = 16;
|
||||
}
|
||||
|
||||
message UnaryOpRequest {
|
||||
@ -721,6 +727,12 @@ enum BinaryOperation {
|
||||
BINOP_SHIFT_LEFT = 20;
|
||||
BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
|
||||
BINOP_SHIFT_RIGHT_LOGICAL = 22;
|
||||
|
||||
// Complex from real, imag.
|
||||
BINOP_COMPLEX = 23;
|
||||
|
||||
// Computes the 4-quadrant arctangent of the y, x input arguments.
|
||||
BINOP_ATAN2 = 24;
|
||||
}
|
||||
|
||||
message BinaryOpRequest {
|
||||
|
Loading…
Reference in New Issue
Block a user