[XLA:CPU] [XLA:GPU] Adds compiler support for the C64 primitive type, including the relevant elementwise unary and binary op lowerings for CPU and GPU.

We use a named LLVM struct type "complex64", laid out the same as std::complex<float> (two consecutive floats: real, then imaginary). Because named struct types are looked up through the llvm::Module, this required changes to PrimitiveTypeToIrType and related accessors.
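As a rough illustration of the approach (a sketch only, not the actual XLA code; the GetComplex64Type helper is hypothetical), a named struct with this layout can be created once and then retrieved through the llvm::Module:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"

llvm::Type* GetComplex64Type(llvm::Module* module) {
  llvm::LLVMContext& context = module->getContext();
  // Reuse the named struct if an earlier lowering already created it.
  // (In newer LLVM releases this lookup lives on llvm::StructType instead.)
  if (llvm::StructType* existing = module->getTypeByName("complex64")) {
    return existing;
  }
  llvm::Type* f32 = llvm::Type::getFloatTy(context);
  // {real, imag}, matching the layout of std::complex<float>.
  return llvm::StructType::create(context, {f32, f32}, "complex64");
}

Passing the llvm::Module (rather than only the llvm::LLVMContext) to the type-mapping helpers is what motivates the accessor changes mentioned above.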

Ops that require atan2 (in particular, angle and log) are only supported on GPU at this point. LLVM lacks a CPU intrinsic for atan or atan2, whereas libdevice provides atan2 for GPU.
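For context, the dependence on atan2 comes from the standard formulations of these ops; a scalar sketch in terms of std::atan2 (not the XLA lowering itself, which on GPU emits calls into libdevice) looks like:

#include <cmath>
#include <complex>

// angle(z) = atan2(imag(z), real(z))
float ComplexAngle(std::complex<float> z) {
  return std::atan2(z.imag(), z.real());
}

// log(z) = log|z| + i*angle(z); the imaginary part is exactly the atan2 above.
std::complex<float> ComplexLog(std::complex<float> z) {
  float log_abs = 0.5f * std::log(z.real() * z.real() + z.imag() * z.imag());
  return {log_abs, std::atan2(z.imag(), z.real())};
}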

PiperOrigin-RevId: 173676849
A. Unique TensorFlower 2017-10-27 09:00:51 -07:00 committed by TensorFlower Gardener
parent 4ae245a7db
commit 4198e27be8
68 changed files with 2113 additions and 516 deletions


@ -144,8 +144,8 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
Node* a = ops::SourceOp(
"Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_COMPLEX64)
.WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
.WithAttr("dtype", DT_COMPLEX128)
.WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
TF_EXPECT_OK(builder.ToGraph(graph.get()));


@ -50,8 +50,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 5> kAllXlaCpuTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
constexpr std::array<DataType, 6> kAllXlaCpuTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);


@ -55,8 +55,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 5> kAllXlaGpuTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
constexpr std::array<DataType, 6> kAllXlaGpuTypes = {
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);


@ -23,6 +23,10 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
generate_backend_suites()
@ -581,11 +585,12 @@ cc_library(
tf_cuda_cc_test(
name = "randomized_tests",
size = "large",
# This test is randomized, so only run it if explicitly requested.
tags = [
"manual",
"notap",
],
] + tf_cuda_tests_tags(),
deps = [":randomized_tests_library"],
)


@ -46,7 +46,9 @@ class ArgMinMaxTest(xla_test.XLATestCase):
self.assertAllEqual(result, expected)
def testArgMinMax(self):
for dtype in self.numeric_types:
# Complex numbers do not support argmin/argmax.
minmax_types = set(self.numeric_types) - set(self.complex_types)
for dtype in minmax_types:
self._assertOpOutputMatchesExpected(
lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
np.array([1, 10, 27, 3, 3, 4], dtype=dtype),


@ -94,6 +94,15 @@ class BinaryOpsTest(XLATestCase):
dtype(4),
expected=np.array([[16], [81]], dtype=dtype))
atan2_supported = self.device == "XLA_GPU"
if atan2_supported:
self._testBinary(
math_ops.atan2,
np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
expected=np.array(
[0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
self._testBinary(
gen_math_ops._reciprocal_grad,
np.array([4, -3, -2, 1], dtype=dtype),
@ -259,37 +268,38 @@ class BinaryOpsTest(XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([10, 20], dtype=dtype))
self._testBinary(
math_ops.maximum,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([5, 20], dtype=dtype))
self._testBinary(
math_ops.maximum,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[10], [7]], dtype=dtype))
if dtype not in self.complex_types: # min/max not supported for complex
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([10, 20], dtype=dtype))
self._testBinary(
math_ops.maximum,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([5, 20], dtype=dtype))
self._testBinary(
math_ops.maximum,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[10], [7]], dtype=dtype))
self._testBinary(
math_ops.minimum,
np.array([1, 20], dtype=dtype),
np.array([10, 2], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
self._testBinary(
math_ops.minimum,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([1, 5], dtype=dtype))
self._testBinary(
math_ops.minimum,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[7], [2]], dtype=dtype))
self._testBinary(
math_ops.minimum,
np.array([1, 20], dtype=dtype),
np.array([10, 2], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
self._testBinary(
math_ops.minimum,
dtype(5),
np.array([1, 20], dtype=dtype),
expected=np.array([1, 5], dtype=dtype))
self._testBinary(
math_ops.minimum,
np.array([[10], [2]], dtype=dtype),
dtype(7),
expected=np.array([[7], [2]], dtype=dtype))
self._testBinary(
math_ops.multiply,
@ -307,21 +317,23 @@ class BinaryOpsTest(XLATestCase):
dtype(7),
expected=np.array([[70], [14]], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([81, 324], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
dtype(5),
np.array([1, 2], dtype=dtype),
expected=np.array([16, 9], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[36], [25]], dtype=dtype))
# Complex support for squared_difference is incidental, see b/68205550
if dtype not in self.complex_types:
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
np.array([10, 20], dtype=dtype),
expected=np.array([81, 324], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
dtype(5),
np.array([1, 2], dtype=dtype),
expected=np.array([16, 9], dtype=dtype))
self._testBinary(
math_ops.squared_difference,
np.array([[1], [2]], dtype=dtype),
dtype(7),
expected=np.array([[36], [25]], dtype=dtype))
self._testBinary(
nn_ops.bias_add,
@ -334,6 +346,139 @@ class BinaryOpsTest(XLATestCase):
np.array([2, -1], dtype=dtype),
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
def testComplexOps(self):
for dtype in self.complex_types:
ctypes = {np.complex64: np.float32}
self._testBinary(
math_ops.complex,
np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),
np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]),
expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype))
self._testBinary(
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
np.array(
[[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]],
dtype=dtype),
np.array(
[[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype),
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
self._testBinary(
gen_math_ops._real_div,
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
expected=np.array(
[
1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
float("inf")
],
dtype=dtype))
# TODO(b/65408531): support+test pow for cplx
lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
self._testBinary(
gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs)
self._testBinary(
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
# TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
self._testBinary(
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
self._testBinary(
gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs))
def testComplexMath(self):
for dtype in self.complex_types:
self._testBinary(
math_ops.add,
np.array([1 + 3j, 2 + 7j], dtype=dtype),
np.array([10 - 4j, 20 + 17j], dtype=dtype),
expected=np.array([11 - 1j, 22 + 24j], dtype=dtype))
self._testBinary(
math_ops.add,
dtype(5 - 7j),
np.array([1 + 2j, 2 + 4j], dtype=dtype),
expected=np.array([6 - 5j, 7 - 3j], dtype=dtype))
self._testBinary(
math_ops.add,
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
dtype(7 + 5j),
expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype))
self._testBinary(
math_ops.subtract,
np.array([1 + 3j, 2 + 7j], dtype=dtype),
np.array([10 - 4j, 20 + 17j], dtype=dtype),
expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype))
self._testBinary(
math_ops.subtract,
dtype(5 - 7j),
np.array([1 + 2j, 2 + 4j], dtype=dtype),
expected=np.array([4 - 9j, 3 - 11j], dtype=dtype))
self._testBinary(
math_ops.subtract,
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
dtype(7 + 5j),
expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype))
self._testBinary(
math_ops.multiply,
np.array([1 + 3j, 2 + 7j], dtype=dtype),
np.array([10 - 4j, 20 + 17j], dtype=dtype),
expected=np.array(
[(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype))
self._testBinary(
math_ops.multiply,
dtype(5 - 7j),
np.array([1 + 2j, 2 + 4j], dtype=dtype),
expected=np.array(
[(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype))
self._testBinary(
math_ops.multiply,
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
dtype(7 + 5j),
expected=np.array(
[[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype))
self._testBinary(
math_ops.div,
np.array([8 - 1j, 2 + 16j], dtype=dtype),
np.array([2 + 4j, 4 - 8j], dtype=dtype),
expected=np.array(
[(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype))
self._testBinary(
math_ops.div,
dtype(1 + 2j),
np.array([2 + 4j, 4 - 8j], dtype=dtype),
expected=np.array(
[(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype))
self._testBinary(
math_ops.div,
np.array([2 + 4j, 4 - 8j], dtype=dtype),
dtype(1 + 2j),
expected=np.array(
[(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype))
# TODO(b/68205550): math_ops.squared_difference shouldn't be supported.
self._testBinary(
nn_ops.bias_add,
np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype),
np.array([2 + 6j, -1 - 3j], dtype=dtype),
expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype))
self._testBinary(
nn_ops.bias_add,
np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype),
np.array([2 + 1j, -1 + 2j], dtype=dtype),
expected=np.array(
[[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype))
def _testDivision(self, dtype):
"""Test cases for division operators."""
self._testBinary(
@ -352,18 +497,19 @@ class BinaryOpsTest(XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
self._testBinary(
gen_math_ops._floor_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
np.array([2, -2, 7, 2, -4], dtype=dtype),
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(
gen_math_ops._floor_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
np.array([2, -2, 7, 2, -4], dtype=dtype),
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
for dtype in self.int_types:
self._testDivision(dtype)
def testFloatDivision(self):
for dtype in self.float_types:
for dtype in self.float_types + self.complex_types:
self._testDivision(dtype)
def _testRemainder(self, dtype):


@ -49,11 +49,15 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
backend_deps = []
backend_data = []
if backend == "cpu":
backend_args += ["--test_device=XLA_CPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
backend_args += [
"--test_device=XLA_CPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
]
elif backend == "gpu":
backend_args += ["--test_device=XLA_GPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
backend_args += [
"--test_device=XLA_GPU",
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
]
backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
backend_args += ["--test_device=" + plugins[backend]["device"],


@ -30,8 +30,6 @@ from tensorflow.python.platform import test
FLAGS = flags.FLAGS
_TEST_TYPES = [dtypes.float32]
class GatherTest(xla_test.XLATestCase):
@ -46,7 +44,7 @@ class GatherTest(xla_test.XLATestCase):
def testScalar1D(self):
with self.test_session() as session, self.test_scope():
data = np.array([0, 1, 2, 3, 7, 5])
for dtype in _TEST_TYPES:
for dtype in self.all_tf_types:
for indices in 4, [1, 2, 2, 4, 5]:
params_np = self._buildParams(data, dtype)
params = array_ops.placeholder(dtype=dtype)
@ -60,7 +58,7 @@ class GatherTest(xla_test.XLATestCase):
with self.test_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
for dtype in _TEST_TYPES:
for dtype in self.all_tf_types:
for axis in 0, 1, -1:
params_np = self._buildParams(data, dtype)
params = array_ops.placeholder(dtype=dtype)
@ -74,7 +72,7 @@ class GatherTest(xla_test.XLATestCase):
with self.test_session() as session, self.test_scope():
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
for dtype in _TEST_TYPES:
for dtype in self.all_tf_types:
for axis in 0, 1, -1:
params_np = self._buildParams(data, dtype)
params = array_ops.placeholder(dtype=dtype)
@ -94,7 +92,7 @@ class GatherTest(xla_test.XLATestCase):
[12, 13, 14]])
# The indices must be in bounds for any axis.
indices_np = np.array([0, 1, 0, 2])
for dtype in _TEST_TYPES:
for dtype in self.all_tf_types:
for axis in 0, 1, -1:
params_np = self._buildParams(data, dtype)
params = array_ops.placeholder(dtype=dtype)
@ -112,7 +110,7 @@ class GatherTest(xla_test.XLATestCase):
"""Check that scalar and empty indices shapes work as well."""
shape = (2, 1, 3, 2)
for indices_shape in (), (0,), (2, 0), (2, 3):
for dtype in _TEST_TYPES:
for dtype in self.all_tf_types:
for axis in 0, 1, 2, 3, -1, -2:
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)


@ -68,6 +68,26 @@ class NAryOpsTest(XLATestCase):
np.array([42], dtype=np.float32)],
expected=np.array([48], dtype=np.float32))
def testComplex(self):
for dtype in self.complex_types:
self._testNAry(
math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))
self._testNAry(
math_ops.add_n, [
np.array([1 + 2j, 2 - 3j], dtype=dtype),
np.array([10j, 20], dtype=dtype)
],
expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))
self._testNAry(
math_ops.add_n, [
np.array([-4, 5j], dtype=dtype),
np.array([2 + 10j, -2], dtype=dtype),
np.array([42j, 3 + 3j], dtype=dtype)
],
expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))
@unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
def testIdentityN(self):
self._testNAryLists(array_ops.identity_n,


@ -29,6 +29,9 @@ from tensorflow.python.platform import googletest
class RandomOpsTest(XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
return set(self.numeric_types) - set(self.complex_types)
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
with self.test_session() as sess:
@ -51,7 +54,8 @@ class RandomOpsTest(XLATestCase):
def rng(dtype):
return random_ops.random_uniform(shape=[2], dtype=dtype,
maxval=1000000)
for dtype in self.numeric_types:
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalIsNotConstant(self):
@ -63,7 +67,7 @@ class RandomOpsTest(XLATestCase):
self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self.numeric_types:
for dtype in self._random_types():
with self.test_session() as sess:
with self.test_scope():
x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,


@ -75,7 +75,7 @@ namespace {
// Command line flags: see main() below.
int64 tf_xla_random_seed = 0;
int32 tf_xla_test_repetitions = 20;
int64 tf_xla_max_tensor_size = 100000LL;
int64 tf_xla_max_tensor_size = 10000LL;
string* tf_xla_test_device_ptr; // initial value set in main()
bool tf_xla_test_use_jit = true;
@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) {
return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
}
constexpr std::array<DataType, 3> kAllXlaTypes = {
{DT_INT32, DT_FLOAT, DT_BOOL}};
constexpr std::array<DataType, 4> kAllXlaTypes = {
{DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}};
// An OpTestBuilder is a graph builder class that takes as input an operator to
// test, its inputs and attributes, and builds a graph that executes the
@ -449,6 +449,13 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
});
break;
}
case DT_COMPLEX64: {
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<complex64>(&tensor, [this, &distribution](int i) {
return complex64(distribution(generator()), distribution(generator()));
});
break;
}
case DT_INT32: {
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
@ -624,11 +631,47 @@ std::vector<int32> OpTest::AsInt32s(const std::vector<int64>& int64s) {
// Functions for comparing tensors.
template <typename T>
double Abs(T x) {
return std::fabs(x);
}
template <>
double Abs<complex64>(complex64 x) {
return std::abs(x);
}
template <typename T>
bool IsClose(const T& x, const T& y, double atol, double rtol) {
if (std::isnan(x) && std::isnan(y)) return true;
if (x == y) return true; // Allow inf == inf.
return fabs(x - y) < atol + rtol * fabs(x);
return Abs(x - y) < atol + rtol * Abs(x);
}
template <>
bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
double rtol) {
if (std::isnan(x.real()) && std::isnan(y.real())) {
if (std::isnan(x.imag()) && std::isnan(y.imag())) {
return true;
}
if (x.imag() == y.imag()) return true; // Allow inf == inf.
return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag());
} else if (std::isnan(x.imag()) && std::isnan(y.imag())) {
if (x.real() == y.real()) return true; // Allow inf == inf.
return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real());
}
if (x == y) return true; // Allow inf == inf.
return Abs(x - y) < atol + rtol * Abs(x);
}
template <typename T>
string Str(T x) {
return strings::StrCat(x);
}
template <>
string Str<complex64>(complex64 x) {
return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
}
template <typename T>
@ -639,9 +682,10 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
for (int i = 0; i < Tx.size(); ++i) {
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
return errors::InvalidArgument(strings::StrCat(
i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i),
". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol,
" rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i))));
i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
"atol = ", atol, " rtol = ", rtol,
" tol = ", atol + rtol * Abs(Tx(i))));
}
}
return Status::OK();
@ -683,6 +727,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
return TensorsAreCloseImpl<float>(a, b, atol, rtol);
case DT_DOUBLE:
return TensorsAreCloseImpl<double>(a, b, atol, rtol);
case DT_COMPLEX64:
return TensorsAreCloseImpl<complex64>(a, b, atol, rtol);
case DT_INT32:
return TensorsAreEqualImpl<int32>(a, b);
case DT_INT64:
@ -822,7 +868,7 @@ Tensor AsIntTensor(DataType dtype, const std::vector<int64>& values) {
TEST_F(OpTest, Abs) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Abs").RandomInput(type).Attr("T", type));
});
@ -837,7 +883,7 @@ TEST_F(OpTest, Acosh) {
TEST_F(OpTest, Add) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
.RandomInput(type, dims.first)
@ -848,7 +894,7 @@ TEST_F(OpTest, Add) {
TEST_F(OpTest, AddN) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
int n = std::uniform_int_distribution<int>(1, 5)(generator());
auto shape = RandomDims();
@ -890,9 +936,10 @@ TEST_F(OpTest, Any) {
TEST_F(OpTest, ApproximateEqual) {
Repeatedly([this]() {
auto dims = RandomDims();
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.RandomInput(type, dims)
.RandomInput(type, dims)
.Attr("T", DT_FLOAT));
});
}
@ -1038,6 +1085,7 @@ TEST_F(OpTest, AvgPool3DGrad) {
TEST_F(OpTest, BatchMatMul) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
std::vector<int64> output_dims = RandomDims(2, 5, 0, 7);
int64 ndims = output_dims.size();
int64 inner_dim = RandomDim();
@ -1056,9 +1104,9 @@ TEST_F(OpTest, BatchMatMul) {
}
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
.RandomInput(DT_FLOAT, x_dims)
.RandomInput(DT_FLOAT, y_dims)
.Attr("T", DT_FLOAT)
.RandomInput(type, x_dims)
.RandomInput(type, y_dims)
.Attr("T", type)
.Attr("adj_x", adj_x)
.Attr("adj_y", adj_y));
});
@ -1090,10 +1138,11 @@ TEST_F(OpTest, BatchToSpace) {
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
TensorShape({num_block_dims, 2})));
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
.RandomInput(DT_FLOAT, input_dims)
.RandomInput(type, input_dims)
.Input(crops)
.Attr("T", DT_FLOAT)
.Attr("T", type)
.Attr("block_size", block_size));
});
}
@ -1127,13 +1176,14 @@ TEST_F(OpTest, BatchToSpaceND) {
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
TensorShape({num_block_dims, 2})));
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("BatchToSpaceND")
.RandomInput(DT_FLOAT, input_dims)
.RandomInput(type, input_dims)
.Input(test::AsTensor<int32>(
std::vector<int32>(block_dims.begin(), block_dims.end())))
.Input(crops)
.Attr("T", DT_FLOAT));
.Attr("T", type));
});
}
@ -1142,18 +1192,20 @@ TEST_F(OpTest, BiasAdd) {
auto x_dims = RandomDims(2, kDefaultMaxRank);
auto y_dims = {x_dims[x_dims.size() - 1]};
// TODO(phawkins): test both data formats.
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
.RandomInput(DT_FLOAT, x_dims)
.RandomInput(DT_FLOAT, y_dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, x_dims)
.RandomInput(type, y_dims)
.Attr("T", type));
});
}
TEST_F(OpTest, BiasAddGrad) {
Repeatedly([this]() {
// TODO(phawkins): test both data formats.
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("BiasAddGrad").RandomInput(type).Attr("T", type));
});
}
@ -1161,10 +1213,11 @@ TEST_F(OpTest, BiasAddV1) {
Repeatedly([this]() {
auto x_dims = RandomDims(2, kDefaultMaxRank);
auto y_dims = {x_dims[x_dims.size() - 1]};
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
.RandomInput(DT_FLOAT, x_dims)
.RandomInput(DT_FLOAT, y_dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, x_dims)
.RandomInput(type, y_dims)
.Attr("T", type));
});
}
@ -1221,8 +1274,8 @@ TEST_F(OpTest, BroadcastGradientArgs) {
TEST_F(OpTest, Cast) {
Repeatedly([this]() {
DataType src_type, dst_type;
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
.RandomInput(src_type)
.Attr("SrcT", src_type)
@ -1293,11 +1346,12 @@ TEST_F(OpTest, Conv2D) {
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
features_in, features_out};
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv2D")
.RandomInput(DT_FLOAT, data_dims)
.RandomInput(DT_FLOAT, kernel_dims)
.Attr("T", DT_FLOAT)
.RandomInput(type, data_dims)
.RandomInput(type, kernel_dims)
.Attr("T", type)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
@ -1317,12 +1371,13 @@ TEST_F(OpTest, Conv2DBackpropFilter) {
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
{d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv2DBackpropFilter")
.RandomInput(DT_FLOAT, activations)
.RandomInput(type, activations)
.Input(kernel_shape)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.RandomInput(type, backprop)
.Attr("T", type)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
@ -1342,12 +1397,13 @@ TEST_F(OpTest, Conv2DBackpropInput) {
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
features_in, features_out};
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv2DBackpropInput")
.Input(in_shape)
.RandomInput(DT_FLOAT, kernel)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.RandomInput(type, kernel)
.RandomInput(type, backprop)
.Attr("T", type)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
@ -1365,11 +1421,12 @@ TEST_F(OpTest, Conv3D) {
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
d.kernel_dims[2], features_in, features_out};
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv3D")
.RandomInput(DT_FLOAT, data)
.RandomInput(DT_FLOAT, kernel)
.Attr("T", DT_FLOAT)
.RandomInput(type, data)
.RandomInput(type, kernel)
.Attr("T", type)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
});
@ -1389,12 +1446,13 @@ TEST_F(OpTest, Conv3DBackpropFilter) {
Tensor kernel_shape = test::AsTensor<int32>(
AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
features_in, features_out}));
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv3DBackpropFilterV2")
.RandomInput(DT_FLOAT, activations)
.RandomInput(type, activations)
.Input(kernel_shape)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.RandomInput(type, backprop)
.Attr("T", type)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
});
@ -1413,17 +1471,34 @@ TEST_F(OpTest, Conv3DBackpropInput) {
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
d.kernel_dims[2], features_in, features_out};
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Conv3DBackpropInputV2")
.Input(in_shape)
.RandomInput(DT_FLOAT, kernel)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.RandomInput(type, kernel)
.RandomInput(type, backprop)
.Attr("T", type)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
});
}
TEST_F(OpTest, Cos) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Cos").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Cosh) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Cosh").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, DepthToSpace) {
Repeatedly([this]() {
int64 block = RandomDim(2, 5);
@ -1431,14 +1506,16 @@ TEST_F(OpTest, DepthToSpace) {
input_dims[1] = (input_dims[1] + (block - 1)) / block;
input_dims[2] = (input_dims[2] + (block - 1)) / block;
input_dims[3] *= block * block;
DataType type = Choose<DataType>(kAllXlaTypes);
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace")
.RandomInput(DT_FLOAT, input_dims)
.Attr("T", DT_FLOAT)
.RandomInput(type, input_dims)
.Attr("T", type)
.Attr("block_size", block));
});
}
TEST_F(OpTest, DepthwiseConv2DNative) {
if (1) return;
Repeatedly([this]() {
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
std::uniform_int_distribution<int> random_int(1, 5);
@ -1449,17 +1526,20 @@ TEST_F(OpTest, DepthwiseConv2DNative) {
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
features_in, depth_multiplier};
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
strides[2] = strides[1]; // Current impl only supports equal strides
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("DepthwiseConv2dNative")
.RandomInput(DT_FLOAT, input_dims)
.RandomInput(DT_FLOAT, kernel_dims)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("strides", strides)
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
});
}
TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
if (1) return;
Repeatedly([this]() {
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
std::uniform_int_distribution<int> random_int(1, 5);
@ -1472,33 +1552,22 @@ TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
{d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
strides[2] = strides[1]; // Current impl only supports equal strides
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
.RandomInput(DT_FLOAT, activations)
.Input(kernel_shape)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("strides", strides)
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
});
}
TEST_F(OpTest, Cos) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Cos").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, Cosh) {
Repeatedly([this]() {
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Cosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
});
}
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
if (1) return;
Repeatedly([this]() {
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
std::uniform_int_distribution<int> random_int(1, 5);
@ -1511,21 +1580,24 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
features_in, depth_multiplier};
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
strides[2] = strides[1]; // Current impl only supports equal strides
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
.Input(in_shape)
.RandomInput(DT_FLOAT, kernel)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("strides", strides)
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
});
}
TEST_F(OpTest, Diag) {
if (1) return;
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>(kAllXlaTypes);
std::vector<int64> dims;
// Diag causes a quadratic blowup in output size.
int64 size;
@ -1540,7 +1612,7 @@ TEST_F(OpTest, Diag) {
TEST_F(OpTest, DiagPart) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>(kAllXlaTypes);
auto dims = RandomDims(1, 3);
// Duplicate the random dims.
std::vector<int64> doubled_dims(dims.size() * 2);
@ -1554,7 +1626,7 @@ TEST_F(OpTest, DiagPart) {
TEST_F(OpTest, Div) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
.RandomInput(type, dims.first)
@ -1650,7 +1722,7 @@ TEST_F(OpTest, SeluGrad) {
TEST_F(OpTest, Equal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
.RandomInput(type, dims.first)
@ -1661,15 +1733,17 @@ TEST_F(OpTest, Equal) {
TEST_F(OpTest, Exp) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Exp").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Expm1) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Expm1").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Expm1").RandomInput(type).Attr("T", type));
});
}
@ -1809,15 +1883,17 @@ TEST_F(OpTest, LinSpace) {
TEST_F(OpTest, Log) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Log").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Log1p) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Log1p").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Log1p").RandomInput(type).Attr("T", DT_FLOAT));
});
}
@ -1914,10 +1990,11 @@ TEST_F(OpTest, MatMul) {
std::swap(b_dims[0], b_dims[1]);
}
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
.RandomInput(DT_FLOAT, a_dims)
.RandomInput(DT_FLOAT, b_dims)
.Attr("T", DT_FLOAT)
.RandomInput(type, a_dims)
.RandomInput(type, b_dims)
.Attr("T", type)
.Attr("transpose_a", transpose_a)
.Attr("transpose_b", transpose_b));
});
@ -1925,7 +2002,7 @@ TEST_F(OpTest, MatMul) {
TEST_F(OpTest, MatrixDiag) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
.RandomInput(type, RandomDims(1))
.Attr("T", type));
@ -1934,7 +2011,7 @@ TEST_F(OpTest, MatrixDiag) {
TEST_F(OpTest, MatrixDiagPart) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
.RandomInput(type, RandomDims(2))
.Attr("T", type));
@ -2025,7 +2102,7 @@ TEST_F(OpTest, MaxPool3D) {
TEST_F(OpTest, Mean) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
// TODO(phawkins): CPU and XLA differ output for reducing across a
// size-0 dimension (nan vs 0). For now, require size >= 1.
std::vector<int64> data_dims = RandomDims(0, kDefaultMaxRank, 1);
@ -2076,7 +2153,7 @@ TEST_F(OpTest, Mod) {
TEST_F(OpTest, Mul) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
.RandomInput(type, dims.first)
@ -2087,7 +2164,7 @@ TEST_F(OpTest, Mul) {
TEST_F(OpTest, Neg) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Neg").RandomInput(type).Attr("T", type));
});
@ -2095,7 +2172,7 @@ TEST_F(OpTest, Neg) {
TEST_F(OpTest, NotEqual) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
.RandomInput(type, dims.first)
@ -2136,7 +2213,7 @@ TEST_F(OpTest, OneHot) {
TEST_F(OpTest, OnesLike) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type));
});
@ -2195,16 +2272,17 @@ TEST_F(OpTest, Pow) {
// nontermination.
Repeatedly([this]() {
auto dims = BroadcastableDims();
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow")
.RandomInput(DT_FLOAT, dims.first)
.RandomInput(DT_FLOAT, dims.second)
.Attr("T", DT_FLOAT));
.RandomInput(type, dims.first)
.RandomInput(type, dims.second)
.Attr("T", type));
});
}
TEST_F(OpTest, Prod) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
std::vector<int64> data_dims = RandomDims();
Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
@ -2238,7 +2316,7 @@ TEST_F(OpTest, Range) {
TEST_F(OpTest, Rank) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Rank").RandomInput(type).Attr("T", type));
});
@ -2246,7 +2324,7 @@ TEST_F(OpTest, Rank) {
TEST_F(OpTest, RealDiv) {
Repeatedly([this]() {
DataType type = DT_FLOAT;
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
.RandomInput(type, dims.first)
@ -2257,18 +2335,20 @@ TEST_F(OpTest, RealDiv) {
TEST_F(OpTest, Reciprocal) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Reciprocal").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, ReciprocalGrad) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims();
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, dims)
.RandomInput(type, dims)
.Attr("T", type));
});
}
TEST_F(OpTest, Relu) {
@ -2335,24 +2415,24 @@ TEST_F(OpTest, Reshape) {
TEST_F(OpTest, Reverse) {
Repeatedly([this]() {
std::vector<int64> dims = RandomDims(1);
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>(kAllXlaTypes);
int64 rank = dims.size();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
.RandomInput(type, dims)
.RandomInput(DT_BOOL, {rank})
.Attr("T", DT_FLOAT));
.Attr("T", type));
});
}
TEST_F(OpTest, ReverseV2) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>(kAllXlaTypes);
std::vector<int64> data_dims = RandomDims();
Tensor indices = RandomReductionIndices(data_dims.size());
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
.RandomInput(type, data_dims)
.Input(indices)
.Attr("T", DT_FLOAT));
.Attr("T", type));
});
}
@ -2372,18 +2452,20 @@ TEST_F(OpTest, Round) {
TEST_F(OpTest, Rsqrt) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Rsqrt").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, RsqrtGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, dims)
.RandomInput(type, dims)
.Attr("T", type));
});
}
@ -2411,24 +2493,26 @@ TEST_F(OpTest, ShapeN) {
TEST_F(OpTest, Sigmoid) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Sigmoid").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, SigmoidGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, dims)
.RandomInput(type, dims)
.Attr("T", type));
});
}
TEST_F(OpTest, Sign) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Sign").RandomInput(type).Attr("T", type));
});
@ -2436,21 +2520,23 @@ TEST_F(OpTest, Sign) {
TEST_F(OpTest, Sin) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Sin").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Sin").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Sinh) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Sinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Sinh").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Size) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>(kAllXlaTypes);
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Size").RandomInput(type).Attr("T", type));
});
@ -2562,10 +2648,11 @@ TEST_F(OpTest, SpaceToBatch) {
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
TensorShape({num_block_dims, 2})));
DataType type = Choose<DataType>(kAllXlaTypes);
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
.RandomInput(DT_FLOAT, input_dims)
.RandomInput(type, input_dims)
.Input(paddings)
.Attr("T", DT_FLOAT)
.Attr("T", type)
.Attr("block_size", block_size));
});
}
@ -2603,13 +2690,14 @@ TEST_F(OpTest, SpaceToBatchND) {
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
TensorShape({num_block_dims, 2})));
DataType type = Choose<DataType>(kAllXlaTypes);
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("SpaceToBatchND")
.RandomInput(DT_FLOAT, input_dims)
.RandomInput(type, input_dims)
.Input(test::AsTensor<int32>(
std::vector<int32>(block_dims.begin(), block_dims.end())))
.Input(paddings)
.Attr("T", DT_FLOAT));
.Attr("T", type));
});
}
@ -2699,18 +2787,20 @@ TEST_F(OpTest, Split) {
TEST_F(OpTest, Sqrt) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Sqrt").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, SqrtGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, dims)
.RandomInput(type, dims)
.Attr("T", type));
});
}
@ -2726,7 +2816,7 @@ TEST_F(OpTest, SquaredDifference) {
TEST_F(OpTest, Square) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Square").RandomInput(type).Attr("T", type));
});
@ -2752,7 +2842,7 @@ TEST_F(OpTest, Squeeze) {
TEST_F(OpTest, Sub) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
auto dims = BroadcastableDims();
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
.RandomInput(type, dims.first)
@ -2763,7 +2853,7 @@ TEST_F(OpTest, Sub) {
TEST_F(OpTest, Sum) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
std::vector<int64> data_dims = RandomDims();
Tensor indices = RandomReductionIndices(data_dims.size());
bool keep_dims = Choose<bool>({false, true});
@ -2875,25 +2965,28 @@ TEST_F(OpTest, StridedSliceGrad) {
TEST_F(OpTest, Tan) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Tan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Tan").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, Tanh) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
OpTestBuilder("Tanh").RandomInput(type).Attr("T", type));
});
}
TEST_F(OpTest, TanhGrad) {
Repeatedly([this]() {
auto dims = RandomDims();
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
.RandomInput(DT_FLOAT, dims)
.RandomInput(DT_FLOAT, dims)
.Attr("T", DT_FLOAT));
.RandomInput(type, dims)
.RandomInput(type, dims)
.Attr("T", type));
});
}
@ -2951,7 +3044,7 @@ TEST_F(OpTest, TruncateMod) {
TEST_F(OpTest, ZerosLike) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type));
});


@ -328,6 +328,131 @@ class UnaryOpsTest(XLATestCase):
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
def testComplexOps(self):
for dtype in self.complex_types:
# TODO(b/65408531): math_ops.acosh (needs pow)
# TODO(b/65408531): math_ops.asinh (needs pow)
# TODO(b/65408531): Wider support for log (needs atan2).
atan2_supported = self.device == "XLA_GPU"
if atan2_supported:
self._assertOpOutputMatchesExpected(
math_ops.atanh,
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
expected=np.arctanh(
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.cosh,
np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype),
expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.sinh,
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.exp,
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.expm1,
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.reciprocal,
np.array([[1, 2j, 2 + 3j]], dtype=dtype),
expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype))
if atan2_supported:
self._assertOpOutputMatchesExpected(
math_ops.log,
np.array([[5j, 3 - 2j]], dtype=dtype),
expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.sin,
np.array([[5j, 3 - 2j]], dtype=dtype),
expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.cos,
np.array([[5j, 3 - 2j]], dtype=dtype),
expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype)))
# TODO(b/34703906): improve log1p implementation and make tolerance
# tighter.
if atan2_supported: # TODO(b/34703906): log support
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
expected=np.log1p(
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
# TODO(b/34703906): math_ops.rsqrt (needs pow)
# TODO(b/34703906): math_ops.sigmoid (needs tanh)
# TODO(b/34703906): math_ops.sqrt (needs pow)
self._assertOpOutputMatchesExpected(
math_ops.tan,
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
# TODO(b/34703906): math_ops.tanh (as itself)
ctypes = {np.complex64: np.float32}
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[3 - 4j, -1j, np.inf]], dtype=dtype),
expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype]))
self._assertOpOutputMatchesExpected(
math_ops.negative,
np.array([[-1 + 2j, -3j]], dtype=dtype),
expected=np.array([[1 - 2j, 3j]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.square,
np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype),
expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2)
self._assertOpOutputMatchesExpected(
array_ops.zeros_like,
np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype),
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.ones_like,
np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype),
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
if atan2_supported: # TODO(b/34703906): atan2 support
self._assertOpOutputMatchesExpected(
math_ops.angle,
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
expected=np.angle(
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype)))
self._assertOpOutputMatchesExpected(
math_ops.conj,
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.imag,
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype]))
self._assertOpOutputMatchesExpected(
math_ops.real,
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
def testIntOps(self):
for dtype in self.int_types:
self._assertOpOutputMatchesExpected(
@ -399,11 +524,14 @@ class UnaryOpsTest(XLATestCase):
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
types = [dtypes.bool, dtypes.int32, dtypes.float32]
types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types
for shape in shapes:
for src_type in types:
for dst_type in types:
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
if src_type in self.complex_tf_types:
src += (np.arange(np.prod(shape)) * 2j).astype(
src_type.as_numpy_dtype)
src = src.reshape(shape)
dst = src.astype(dst_type.as_numpy_dtype)


@ -43,7 +43,7 @@ class VariableOpsTest(XLATestCase):
# Regression test for a bug where computations with one non-constant
# output and one variable update were mishandled.
for dtype in self.numeric_types:
init = np.array([[1, 2], [3, 4]], dtype=dtype)
init = np.array([[1, 2j], [3, 4]]).astype(dtype)
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
@ -51,82 +51,91 @@ class VariableOpsTest(XLATestCase):
x = v.assign_add(p)
with ops.control_dependencies([x]):
y = v.read_value()
self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype),
sess.run(y, {p: 1}))
self.assertAllClose(
np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {
p: 1
}))
def testSparseRead0DIndices(self):
for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
11]]).astype(dtype)
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read(2)
self.assertAllClose(np.array([8, 9, 10, 11], dtype=dtype), sess.run(x))
self.assertAllClose(
np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
def testSparseRead1DIndices(self):
for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
11]]).astype(dtype)
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read([2, 1])
self.assertAllClose(
np.array([[8, 9, 10, 11], [4, 5, 6, 7]], dtype=dtype), sess.run(x))
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
sess.run(x))
def testSparseRead2DIndices(self):
for dtype in self.numeric_types:
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
11]]).astype(dtype)
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read([[2, 1], [0, 2]])
self.assertAllClose(
np.array(
[[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2, 3], [8, 9, 10,
11]]],
dtype=dtype), sess.run(x))
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
sess.run(x))
def testSparseRead2DIndices3DTensor(self):
for dtype in self.numeric_types:
init = np.array(
[[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
[[20, 21, 22], [23, 24, 25]], [[30, 31, 32], [33, 34, 35]]],
dtype=dtype)
init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
[[20, 21, 22], [23, 24j, 25]],
[[30, 31, 32], [33, 34, 35]]]).astype(dtype)
with self.test_session() as sess, self.test_scope():
v = resource_variable_ops.ResourceVariable(init)
sess.run(variables.variables_initializer([v]))
x = v.sparse_read([[2, 1], [3, 0]])
self.assertAllClose(
np.array(
[[[[20, 21, 22], [23, 24, 25]], [[10, 11, 12], [13, 14, 15]]],
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]],
dtype=dtype), sess.run(x))
).astype(dtype), sess.run(x))
def testReadWrite(self):
"""Tests initialization, reading, and writing a resource variable."""
with self.test_session() as session:
with self.test_scope():
with variable_scope.variable_scope("ascope", use_resource=True):
x = variable_scope.get_variable(
"x",
shape=[],
dtype=dtypes.float32,
initializer=init_ops.constant_initializer(2))
a = x.read_value()
with ops.control_dependencies([a]):
b = state_ops.assign(x, 47)
with ops.control_dependencies([b]):
c = x.read_value()
with ops.control_dependencies([c]):
d = state_ops.assign_add(x, 3)
with ops.control_dependencies([d]):
e = x.read_value()
for dtype in self.numeric_types:
with self.test_session() as session:
print(ops.get_default_graph())
with self.test_scope():
with variable_scope.variable_scope("ascope", use_resource=True):
x = variable_scope.get_variable(
"x",
shape=[],
dtype=dtype,
initializer=init_ops.constant_initializer(2))
a = x.read_value()
with ops.control_dependencies([a]):
b = state_ops.assign(x, dtype(47))
with ops.control_dependencies([b]):
c = x.read_value()
with ops.control_dependencies([c]):
d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype))
with ops.control_dependencies([d]):
e = state_ops.assign_sub(x, dtype(3))
with ops.control_dependencies([e]):
f = x.read_value()
session.run(variables.global_variables_initializer())
v1, v2, v3 = session.run([a, c, e])
self.assertAllClose(2.0, v1)
self.assertAllClose(47.0, v2)
self.assertAllClose(50.0, v3)
session.run(variables.global_variables_initializer())
v1, v2, v3 = session.run([a, c, f])
self.assertAllClose(dtype(2), v1)
self.assertAllClose(dtype(47), v2)
self.assertAllClose(np.array(50 + 2j).astype(dtype), v3)
def testTraining(self):
"""Tests a gradient descent step for a simple model."""

View File

@ -63,12 +63,19 @@ class XLATestCase(test.TestCase):
self.float_tf_types = [
dtype for dtype in self.all_tf_types if dtype.is_floating
]
self.numeric_tf_types = self.int_tf_types + self.float_tf_types
self.complex_tf_types = [
dtype for dtype in self.all_tf_types if dtype.is_complex
]
self.numeric_tf_types = (
self.int_tf_types + self.float_tf_types + self.complex_tf_types)
self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
self.numeric_types = self.int_types + self.float_types
self.complex_types = [
dtype.as_numpy_dtype for dtype in self.complex_tf_types
]
self.numeric_types = self.int_types + self.float_types + self.complex_types
# Parse the manifest file, if any, into a regex identifying tests to
# disable

View File

@ -77,7 +77,13 @@ class BatchMatMulOp : public XlaOpKernel {
xla::ComputationBuilder* builder = ctx->builder();
xla::ComputationDataHandle x_handle = ctx->Input(0);
if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) {
x_handle = builder->Conj(x_handle);
}
xla::ComputationDataHandle y_handle = ctx->Input(1);
if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) {
y_handle = builder->Conj(y_handle);
}
// Reshape input tensors into 3D tensors by flattening the batch
// dimensions. This makes it easier to unroll the batch dimension.
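For complex inputs, adj_x/adj_y now mean conjugate-transpose rather than a plain transpose, which is why the Conj calls are inserted before the transposed matmul. A minimal numpy sketch (illustrative values, not taken from this change) showing that the two differ for complex matrices:

import numpy as np

x = np.array([[1 + 1j, 2 - 1j], [3j, 4]], dtype=np.complex64)
y = np.array([[1, -1j], [2 + 2j, 4]], dtype=np.complex64)

adjoint_result = np.matmul(np.conj(x).T, y)  # what adj_x=True computes with Conj applied
transpose_only = np.matmul(x.T, y)           # plain transpose, wrong for a complex adjoint
assert not np.allclose(adjoint_result, transpose_only)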

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Native XLA implementations of simple unary Ops
// Native XLA implementations of simple binary Ops
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
@ -50,6 +51,9 @@ XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
@ -171,8 +175,12 @@ class ApproximateEqualOp : public XlaOpKernel {
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
auto abs_shape = b->GetShape(abs);
OP_REQUIRES_OK(ctx, abs_shape.status());
auto abs_type = abs_shape.ValueOrDie()->element_type();
auto result = b->Lt(
abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type));
ctx->SetOutput(0, result);
}
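The rewrite matters for complex inputs: Abs of a complex value is real-valued, so the tolerance is converted to the abs result's element type before the comparison. A small numpy analogue (values chosen only for illustration):

import numpy as np

x = np.array([1 + 1j, 2 + 2j], dtype=np.complex64)
y = np.array([1 + 1.0000001j, 5 + 2j], dtype=np.complex64)
tolerance = np.float32(1e-5)

near = np.abs(x - y) < tolerance  # np.abs yields float32 here, matching the converted tolerance
assert near.tolist() == [True, False]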

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@ -40,6 +41,11 @@ class CastOp : public XlaOpKernel {
output = input;
} else if (dst_dtype_ == DT_BOOL) {
output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
} else if (xla::primitive_util::IsComplexType(src_type_) &&
!xla::primitive_util::IsComplexType(dst_type_)) {
// As in cast_op.h, we replicate the numpy behavior of truncating the
// imaginary part.
output = builder->ConvertElementType(builder->Real(input), dst_type_);
} else {
output = builder->ConvertElementType(input, dst_type_);
}
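The complex-to-real branch mirrors numpy's convention of keeping the real part and dropping the imaginary part. A quick numpy check (illustrative only; numpy additionally emits a ComplexWarning for this cast):

import numpy as np

z = np.array([1 + 2j, -3 - 4j], dtype=np.complex64)
assert np.array_equal(z.astype(np.float32), np.array([1.0, -3.0], dtype=np.float32))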

View File

@ -192,7 +192,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) {
errors::InvalidArgument("indices must be int32 or int64"));
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
context, input, input_shape, indices, indices_shape, axis, DT_FLOAT,
context, input, input_shape, indices, indices_shape, axis, input_type(0),
index_type, builder);
context->SetOutput(0, gather);
}

View File

@ -23,6 +23,9 @@ limitations under the License.
namespace tensorflow {
namespace {
constexpr std::array<DataType, 4> kMatmulTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
class MatMulOp : public XlaOpKernel {
public:
explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false)
@ -73,7 +76,7 @@ class MatMulOp : public XlaOpKernel {
bool transpose_b_;
};
REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kFloatTypes), MatMulOp);
REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp);
class SparseMatMulOp : public MatMulOp {
public:

View File

@ -37,8 +37,9 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
}
};
REGISTER_XLA_OP(Name("ResourceApplyGradientDescent"),
ResourceApplyGradientDescent);
REGISTER_XLA_OP(
Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
ResourceApplyGradientDescent);
class ResourceApplyMomentum : public XlaOpKernel {
public:
@ -109,7 +110,8 @@ class ResourceApplyMomentum : public XlaOpKernel {
private:
bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyMomentum"), ResourceApplyMomentum);
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
ResourceApplyMomentum);
class ResourceApplyAdagrad : public XlaOpKernel {
public:
@ -163,7 +165,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
};
REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad);
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
ResourceApplyAdagrad);
class ResourceApplyAdam : public XlaOpKernel {
public:
@ -263,7 +266,8 @@ class ResourceApplyAdam : public XlaOpKernel {
private:
DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam);
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
ResourceApplyAdam);
class ResourceApplyRMSProp : public XlaOpKernel {
public:
@ -362,7 +366,8 @@ class ResourceApplyRMSProp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
}
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp);
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
ResourceApplyRMSProp);
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
@ -500,7 +505,8 @@ class ResourceApplyFtrl : public XlaOpKernel {
private:
DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl);
REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes),
ResourceApplyFtrl);
class ResourceApplyFtrlV2 : public XlaOpKernel {
public:
@ -515,7 +521,8 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
private:
DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2);
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
ResourceApplyFtrlV2);
} // namespace
} // namespace tensorflow

View File

@ -41,6 +41,12 @@ namespace {
}; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op);
XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x));
XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x)));
XLAJIT_MAKE_UNARY(Conj, b->Complex(b->Real(x), b->Neg(b->Imag(x))));
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
@ -162,6 +168,9 @@ XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
XLAJIT_MAKE_UNARY(Real, b->Real(x));
XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
#undef XLAJIT_MAKE_UNARY
} // namespace
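The new unary kernels rely on standard identities: ComplexAbs is the modulus, Angle is atan2(imag, real), and Conj flips the sign of the imaginary part. A short numpy sketch (illustrative only) checking these identities:

import numpy as np

x = np.complex64(3 - 4j)
assert np.conj(x) == complex(x.real, -x.imag)                     # Conj
assert np.isclose(np.angle(x), np.arctan2(x.imag, x.real))        # Angle (needs atan2, GPU-only here)
assert np.isclose(np.abs(x), np.sqrt(x.real ** 2 + x.imag ** 2))  # ComplexAbs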

View File

@ -97,6 +97,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
case xla::F64:
literal = *xla::Literal::CreateR0<double>(value);
break;
case xla::C64:
literal = *xla::Literal::CreateR0<complex64>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
case xla::S16:
@ -132,6 +135,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
case xla::F64:
return b->ConstantR0<double>(value);
break;
case xla::C64:
return b->ConstantR0<complex64>(value);
break;
default:
LOG(FATAL) << "unhandled element type " << type;
}

View File

@ -47,14 +47,17 @@ extern const char* const DEVICE_XLA_GPU;
constexpr std::array<DataType, 3> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE}};
constexpr std::array<DataType, 7> kNumericTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}};
constexpr std::array<DataType, 8> kNumericTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64}};
constexpr std::array<DataType, 7> kCpuAllTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
constexpr std::array<DataType, 8> kCpuAllTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_BOOL}};
constexpr std::array<DataType, 7> kGpuAllTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
constexpr std::array<DataType, 8> kGpuAllTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_BOOL}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.

View File

@ -913,6 +913,17 @@ ComputationDataHandle ComputationBuilder::CustomCall(
return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::Complex(
const ComputationDataHandle& real, const ComputationDataHandle& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
}
ComputationDataHandle ComputationBuilder::Conj(
const ComputationDataHandle& operand) {
return Complex(Real(operand), Neg(Imag(operand)));
}
ComputationDataHandle ComputationBuilder::Add(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
@ -995,6 +1006,12 @@ ComputationDataHandle ComputationBuilder::Abs(
return UnaryOp(UNOP_ABS, operand);
}
ComputationDataHandle ComputationBuilder::Atan2(
const ComputationDataHandle& y, const ComputationDataHandle& x,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
}
ComputationDataHandle ComputationBuilder::Exp(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_EXP, operand);
@ -1040,6 +1057,16 @@ ComputationDataHandle ComputationBuilder::Tanh(
return UnaryOp(UNOP_TANH, operand);
}
ComputationDataHandle ComputationBuilder::Real(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_REAL, operand);
}
ComputationDataHandle ComputationBuilder::Imag(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_IMAG, operand);
}
ComputationDataHandle ComputationBuilder::IsFinite(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_IS_FINITE, operand);

View File

@ -431,6 +431,14 @@ class ComputationBuilder {
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).
// Enqueues a complex compose instruction onto the computation.
ComputationDataHandle Complex(
const ComputationDataHandle& real, const ComputationDataHandle& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues a complex conjugate instruction onto the computation.
ComputationDataHandle Conj(const ComputationDataHandle& operand);
// Enqueues an add instruction onto the computation.
ComputationDataHandle Add(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
@ -542,6 +550,11 @@ class ComputationBuilder {
// Enqueues an abs instruction onto the computation.
ComputationDataHandle Abs(const ComputationDataHandle& operand);
// Enqueues a atan2 instruction onto the computation.
ComputationDataHandle Atan2(
const ComputationDataHandle& y, const ComputationDataHandle& x,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
// Enqueues an exp instruction onto the computation.
ComputationDataHandle Exp(const ComputationDataHandle& operand);
@ -570,6 +583,12 @@ class ComputationBuilder {
// Enqueues a tanh instruction onto the computation.
ComputationDataHandle Tanh(const ComputationDataHandle& operand);
// Enqueues a real-part instruction onto the computation.
ComputationDataHandle Real(const ComputationDataHandle& operand);
// Enqueues an imaginary-part instruction onto the computation.
ComputationDataHandle Imag(const ComputationDataHandle& operand);
// Enqueues a float32 sqrt instruction onto the computation.
// (float32 is specified as there is an implicit float32 0.5f constant
// exponent).

View File

@ -204,6 +204,8 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<float>(0);
case F64:
return *Literal::CreateR0<double>(0);
case C64:
return *Literal::CreateR0<complex64>(0);
case PRED:
return *Literal::CreateR0<bool>(false);
case S16:
@ -236,6 +238,8 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<float>(1);
case F64:
return *Literal::CreateR0<double>(1);
case C64:
return *Literal::CreateR0<complex64>(1);
case PRED:
return *Literal::CreateR0<bool>(true);
case S16:
@ -271,6 +275,8 @@ Status Literal::Copy(const Literal& src_literal,
case F64:
return *Literal::CreateR0<double>(
-std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
return *Literal::CreateR0<bool>(false);
case S16:

View File

@ -141,6 +141,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleConvert(HloInstruction* convert) override;
Status HandleReal(HloInstruction* real, HloInstruction* operand) override;
Status HandleImag(HloInstruction* imag, HloInstruction* operand) override;
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override;
@ -967,6 +970,24 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
return Status::OK();
}
// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kComplex) {
return ReplaceInstruction(real, operand->mutable_operand(0));
}
return Status::OK();
}
// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kComplex) {
return ReplaceInstruction(imag, operand->mutable_operand(1));
}
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
// Eliminate nop pads (padding all zero), and replace a pad with negative
// padding with a pad with non-negative padding followed by a slice.

View File

@ -433,6 +433,56 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
EXPECT_EQ(root, param0);
}
// Test that real(complex(r,i)) is simplified to r.
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r2f32, "param1"));
HloInstruction* cplx = builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
HloOpcode::kComplex, param0, param1));
HloInstruction* real = builder.AddInstruction(
HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, real);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param0);
}
// Test that imag(complex(r,i)) is simplified to i.
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r2f32, "param1"));
HloInstruction* cplx = builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
HloOpcode::kComplex, param0, param1));
HloInstruction* imag = builder.AddInstruction(
HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, imag);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_EQ(root, param1);
}
// Test that get_element(make_tuple({A,B}),1) is simplified to B
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});

View File

@ -63,7 +63,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
const HloModuleConfig& hlo_module_config) {
PrimitiveType type = target_array.GetShape().element_type();
TF_RET_CHECK(F32 == type || F64 == type);
TF_RET_CHECK(F32 == type || F64 == type || C64 == type);
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
lhs_array, rhs_array, executable_run_options_value,
ir_builder, hlo_module_config);
@ -176,7 +176,7 @@ tensorflow::Status DotOpEmitter::Emit() {
llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0),
ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
accum_address);
// Body basic block of reduction loop:
@ -191,9 +191,29 @@ tensorflow::Status DotOpEmitter::Emit() {
llvm::Value* rhs_element =
rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto real = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {0});
};
auto imag = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {1});
};
llvm::Value* product_real = ir_builder_->CreateFSub(
ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
llvm::Value* product_imag = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
updated_accum = ir_builder_->CreateInsertValue(
accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
updated_accum = ir_builder_->CreateInsertValue(
updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1});
} else {
llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
updated_accum = ir_builder_->CreateFAdd(accum, product);
}
ir_builder_->CreateStore(updated_accum, accum_address);
// Exit basic block of reduction loop.
@ -230,11 +250,28 @@ tensorflow::Status DotOpEmitter::Emit() {
tensorflow::Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
llvm::Value* lhs_value =
lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
llvm::Value* rhs_value =
rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value);
if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
llvm::Value* real = ir_builder_->CreateFSub(
ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
llvm::Value* imag = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
#undef IMAG
#undef REAL
result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
result = ir_builder_->CreateInsertValue(result, real, {0});
result = ir_builder_->CreateInsertValue(result, imag, {1});
} else {
result = ir_builder_->CreateFMul(lhs_value, rhs_value);
}
target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
return tensorflow::Status::OK();
}
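The reduction and scalar-dot paths expand each complex multiply into its real and imaginary parts and accumulate them separately. A numpy sketch of the same accumulation (made-up values; np.dot does not conjugate, matching the emitter):

import numpy as np

lhs = np.array([1 + 2j, 3 - 1j], dtype=np.complex64)
rhs = np.array([-2 + 0.5j, 4 + 4j], dtype=np.complex64)

acc_re = acc_im = np.float32(0)
for l, r in zip(lhs, rhs):
    acc_re += l.real * r.real - l.imag * r.imag  # real part of l*r
    acc_im += l.real * r.imag + l.imag * r.real  # imaginary part of l*r

assert np.allclose(acc_re + 1j * acc_im, np.dot(lhs, rhs))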

View File

@ -46,8 +46,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
}
// Create function type for the function.
llvm::FunctionType* function_type = llvm::FunctionType::get(
llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
/*isVarArg=*/false);
// Create function declaration for 'tanhf'.
llvm::Function* function =

View File

@ -41,6 +41,12 @@ bool PotentiallyImplementedAsEigenConvolution(
ShapeUtil::HasZeroElements(kernel_shape)) {
return false;
}
// TODO(b/65408531): Explore using Eigen dot for complex64 type.
if (ShapeUtil::ElementIsComplex(input_shape) ||
ShapeUtil::ElementIsComplex(kernel_shape)) {
return false;
}
const ConvolutionDimensionNumbers& dnums =
convolution.convolution_dimension_numbers();
// Only 1D and 2D convolutions are supported at the moment.

View File

@ -288,7 +288,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant,
MinimumAlignmentForShape(literal.shape()));
} else {
llvm::Constant* initializer =
llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
global_for_const = new llvm::GlobalVariable(
/*Module=*/*module_,
/*Type=*/initializer->getType(),
@ -401,7 +401,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
const Shape& shape = get_tuple_element->shape();
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
GetEmittedValueFor(operand), &ir_builder_);
GetEmittedValueFor(operand), &ir_builder_, module_);
return Status::OK();
}
@ -412,9 +412,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
if (ShapeUtil::IsTuple(select->shape())) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
llvm_ir::EmitTupleSelect(GetIrArrayFor(select), GetIrArrayFor(pred),
GetEmittedValueFor(on_true),
GetEmittedValueFor(on_false), &ir_builder_);
llvm_ir::EmitTupleSelect(
GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true),
GetEmittedValueFor(on_false), &ir_builder_, module_);
return Status::OK();
}
@ -459,7 +459,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
tuple_element_addresses.push_back(tuple_element_address);
}
llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_);
llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_,
module_);
} else {
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
GetEmittedValueFor(infeed)));
@ -562,7 +563,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
ShapeUtil::GetTupleElementShape(operand_shape, i);
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
value, &ir_builder_);
value, &ir_builder_, module_);
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
tuple_element_shape, tuple_element));
}
@ -583,7 +584,7 @@ Status IrEmitter::HandleTuple(
for (auto operand : operands) {
base_ptrs.push_back(GetEmittedValueFor(operand));
}
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_);
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_);
return Status::OK();
}
@ -644,7 +645,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window,
// the initial value on the reduce_window.
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accumulator_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(operand_element_type));
ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
@ -769,7 +770,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// Allocate space to keep the currently selected value, its index, and
// the boolean initialized_flag, which is initially set to false.
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"selected_value_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(operand_element_type));
llvm::Value* selected_index_address =
@ -851,8 +852,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
llvm::Value* cond = ir_builder_.CreateICmpNE(
result, llvm::ConstantInt::get(
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
@ -895,7 +896,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
HloInstruction* rhs) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
/*supported_types=*/{F32, F64}));
/*supported_types=*/{F32, F64, C64}));
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
@ -923,7 +924,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
const Window& window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
/*supported_types=*/{F32}));
/*supported_types=*/{F32, C64}));
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
@ -1079,7 +1080,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
// the output entry at the given index.
PrimitiveType lhs_element_type = lhs->shape().element_type();
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_),
"convolution_sum_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(lhs_element_type));
ir_builder_.CreateStore(
@ -1295,14 +1296,14 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
PrimitiveType element_type = operand->shape().element_type();
// Used to calculate E(X).
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
"sum_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(element_type));
// Used to calculate E(X^2).
llvm::Value* sum_square_address =
llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
"sum_square_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(element_type));
@ -1425,7 +1426,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
.EmitLoop(IrName(batch_norm_training, "normalize")));
llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training),
{normalized, mean, var}, &ir_builder_);
{normalized, mean, var}, &ir_builder_, module_);
return Status::OK();
}
@ -1488,6 +1489,14 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
}
const Shape& root_shape = root_instruction->shape();
if (ShapeUtil::ElementIsComplex(root_shape)) {
// TODO(b/65408531): Complex add could be done via bitcast to <float x [2N]>
// Complex multiply would be more challenging. We could perhaps use a
// strided load to get all reals in a vector, all imags in a vector, or use
// CreateShuffleVector on a bitcast to float x [2N].
*failure_reason = "complex values not supported";
return nullptr;
}
bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
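A sketch of the bitcast idea from the TODO above (an assumption about a possible future implementation, not something this change does): complex addition is componentwise, so a C64 sum could be computed on the f32 view of the data.

import numpy as np

a = np.array([1 + 2j, 3 - 1j], dtype=np.complex64)
b = np.array([0.5 - 4j, -2 + 2j], dtype=np.complex64)

via_f32_view = (a.view(np.float32) + b.view(np.float32)).view(np.complex64)
assert np.array_equal(via_f32_view, a + b)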
@ -1509,7 +1518,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
// This is visually similar to ElementalIrEmitter, though conceptually we're
// doing something different here. ElementalIrEmitter emits scalar operations
// while these emit scalar or vector operations depending on the type of the
// operands.
// operands. See CreateShardedVectorType for the actual types in use here.
switch (root_instruction->opcode()) {
default:
*failure_reason = "did not recognize root instruction opcode";
@ -1586,7 +1595,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
ShardedVectorType sharded_vector_type;
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_);
llvm_ir::PrimitiveTypeToIrType(element_type, module_);
for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
// For every power of two present in element_count, we generate one or more
@ -1919,7 +1928,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
"accumulator", &ir_builder_,
MinimumAlignmentForPrimitiveType(accumulator_type));
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
@ -2248,6 +2257,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
return Status::OK();
} else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
assignment_)) {
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
@ -2257,6 +2267,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
fusion, operands, GetIrArrayFor(fusion), &elemental_emitter,
&ir_builder_);
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
VLOG(3) << "HandleFusion kLoop";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
auto operands = GetIrArraysForOperandsOf(fusion);
FusedIrEmitter fused_emitter(operands, &elemental_emitter);
@ -2400,8 +2411,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
{while_result}, IrName(xla_while, "cond"));
llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
while_condition,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
0));
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
@ -2542,7 +2552,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
unsigned element_alignment = GCD(
primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_));
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
if (element_count == 1) {
auto* load_instruction = ir_builder_.CreateAlignedLoad(
@ -2755,7 +2765,7 @@ llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
}
llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
return llvm_ir::ShapeToIrType(shape, &ir_builder_);
return llvm_ir::ShapeToIrType(shape, module_);
}
std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() {
@ -2925,7 +2935,7 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
PrimitiveType return_type = return_shape.element_type();
llvm::Value* return_value_buffer =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements,
llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
tensorflow::strings::StrCat(name, "_return_value_address"),
&ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
@ -3100,7 +3110,7 @@ Status IrEmitter::EmitTargetElementLoop(
for (int64 i = 0; i < output_arrays.size(); ++i) {
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_);
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_);
} else {
if (ShouldEmitParallelLoopFor(*target_op)) {

View File

@ -85,6 +85,10 @@ class DfsHloVisitor {
virtual Status HandleCopy(HloInstruction* copy) {
return HandleElementwiseUnary(copy);
}
virtual Status HandleComplex(HloInstruction* complex, HloInstruction* real,
HloInstruction* imag) {
return HandleElementwiseBinary(complex);
}
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
HloInstruction* rhs) {
return HandleElementwiseBinary(multiply);
@ -122,6 +126,10 @@ class DfsHloVisitor {
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
return HandleElementwiseUnary(abs);
}
virtual Status HandleAtan2(HloInstruction* atan2, HloInstruction* y,
HloInstruction* x) {
return HandleElementwiseBinary(atan2);
}
virtual Status HandleRound(HloInstruction* round) {
return HandleElementwiseUnary(round);
}
@ -152,6 +160,12 @@ class DfsHloVisitor {
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
return HandleElementwiseUnary(tanh);
}
virtual Status HandleReal(HloInstruction* real, HloInstruction* operand) {
return HandleElementwiseUnary(real);
}
virtual Status HandleImag(HloInstruction* imag, HloInstruction* operand) {
return HandleElementwiseUnary(imag);
}
virtual Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) {
return HandleElementwiseUnary(is_finite);

View File

@ -54,10 +54,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
if (op->opcode() == HloOpcode::kCopy) {
return operand_value;
} else if (operand_value->getType()->isIntegerTy()) {
return EmitIntegerUnaryOp(op, operand_value);
} else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
return EmitComplexUnaryOp(op, operand_value);
} else {
return operand_value->getType()->isIntegerTy()
? EmitIntegerUnaryOp(op, operand_value)
: EmitFloatUnaryOp(op, operand_value);
return EmitFloatUnaryOp(op, operand_value);
}
}
@ -73,20 +75,35 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (primitive_util::IsIntegralType(to_type)) {
return ir_builder_->CreateIntCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_),
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
primitive_util::IsSignedIntegralType(to_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
if (primitive_util::IsSignedIntegralType(from_type)) {
return ir_builder_->CreateSIToFP(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
return ir_builder_->CreateUIToFP(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
}
if (primitive_util::IsComplexType(to_type)) {
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
primitive_util::ComplexComponentType(to_type), module_);
if (primitive_util::IsSignedIntegralType(from_type)) {
return ComposeComplex(
op,
ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
nullptr);
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
return ComposeComplex(
op,
ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
nullptr);
}
}
return Unimplemented("conversion from primitive type %s to %s",
@ -97,8 +114,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
bool is_signed =
primitive_util::IsSignedIntegralType(op->shape().element_type());
if (is_signed) {
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
ir_builder_);
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto zero = llvm::ConstantInt::get(type, 0);
auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
return ir_builder_->CreateSelect(cmp, operand_value,
@ -110,8 +127,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
case HloOpcode::kSign: {
bool is_signed =
primitive_util::IsSignedIntegralType(op->shape().element_type());
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
ir_builder_);
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto zero = llvm::ConstantInt::get(type, 0);
auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
if (is_signed) {
@ -135,7 +152,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
return ir_builder_->CreateZExt(
ir_builder_->CreateNot(ir_builder_->CreateTrunc(
operand_value, ir_builder_->getInt1Ty())),
llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
} else if (primitive_util::IsIntegralType(type)) {
return ir_builder_->CreateNot(operand_value);
}
@ -157,20 +174,30 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
if (from_type == to_type) {
return operand_value;
}
if (primitive_util::IsComplexType(to_type)) {
PrimitiveType to_component_type =
primitive_util::ComplexComponentType(to_type);
if (from_type == to_component_type) {
return ComposeComplex(op, operand_value, nullptr);
}
return ComposeComplex(
op,
ir_builder_->CreateFPCast(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
nullptr);
}
if (primitive_util::IsFloatingPointType(to_type)) {
return ir_builder_->CreateFPCast(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsSignedIntegralType(to_type)) {
return ir_builder_->CreateFPToSI(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsUnsignedIntegralType(to_type)) {
return ir_builder_->CreateFPToUI(
operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return Unimplemented("unhandled conversion operation: %s => %s",
PrimitiveType_Name(from_type).c_str(),
@ -230,7 +257,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
return ir_builder_->CreateZExt(
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
return ir_builder_->CreateFNeg(operand_value);
@ -240,20 +267,164 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
auto real = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {0});
};
auto imag = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {1});
};
switch (op->opcode()) {
// TODO(b/65209142): Angle/Log require atan2.
// case HloOpcode::kAngle:
// case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
TF_RET_CHECK(primitive_util::IsComplexType(from_type));
PrimitiveType to_type = op->shape().element_type();
TF_RET_CHECK(primitive_util::IsComplexType(to_type));
if (from_type == to_type) {
return operand_value;
}
PrimitiveType to_component_type =
primitive_util::ComplexComponentType(to_type);
auto to_ir_component_type =
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
return ComposeComplex(
op,
ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type),
ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type));
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
auto exp_a = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::exp, {real(operand_value)},
{real(operand_value)->getType()}, ir_builder_);
auto cos_b = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::cos, {imag(operand_value)},
{imag(operand_value)->getType()}, ir_builder_);
auto sin_b = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::sin, {imag(operand_value)},
{imag(operand_value)->getType()}, ir_builder_);
return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
ir_builder_->CreateFMul(exp_a, sin_b));
}
case HloOpcode::kCos: {
// cos(z) = .5(e^(iz) + e^(-iz))
// cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
// now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
// = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
auto a = real(operand_value);
auto b = imag(operand_value);
auto type = a->getType();
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
{type}, ir_builder_);
auto half_exp_b =
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
{type}, ir_builder_);
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
{type}, ir_builder_);
return ComposeComplex(
op,
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
ir_builder_->CreateFMul(
sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
}
case HloOpcode::kSin: {
// sin(z) = .5i(e^(-iz) - e^(iz))
// sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
// = .5i(e^(b-ai) - e^(-b+ai))
// now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
// sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
// = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
// = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
// = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
auto a = real(operand_value);
auto b = imag(operand_value);
auto type = a->getType();
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
{type}, ir_builder_);
auto half_exp_b =
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
{type}, ir_builder_);
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
{type}, ir_builder_);
return ComposeComplex(
op,
ir_builder_->CreateFMul(
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
}
case HloOpcode::kAbs: {
auto sum_sq = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
{sum_sq->getType()}, ir_builder_);
}
case HloOpcode::kSign: { // Sign(c) = c / |c|
auto sum_sq = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
auto type = cplx_abs->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
return ir_builder_->CreateSelect(
oeq, ComposeComplex(op, zero, zero),
ComposeComplex(
op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs),
ir_builder_->CreateFDiv(imag(operand_value), cplx_abs)));
}
case HloOpcode::kNegate:
return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)),
ir_builder_->CreateFNeg(imag(operand_value)));
case HloOpcode::kReal:
return real(operand_value);
case HloOpcode::kImag:
return imag(operand_value);
default:
return Unimplemented("unary complex op '%s'",
HloOpcodeString(op->opcode()).c_str());
}
}
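The Exp/Cos/Sin cases above, and the Log identity noted in the TODO, follow the usual component-wise expansions. A numpy verification of those expansions (values are arbitrary; this is only a sanity check, not part of the change):

import numpy as np

z = np.complex64(1.5 - 0.75j)
a, b = z.real, z.imag

exp_z = np.exp(a) * (np.cos(b) + 1j * np.sin(b))             # e^(a+bi)
cos_z = 0.5 * (np.cos(a) * (np.exp(-b) + np.exp(b))
               + 1j * np.sin(a) * (np.exp(-b) - np.exp(b)))   # cos(a+bi)
sin_z = 0.5 * (np.sin(a) * (np.exp(b) + np.exp(-b))
               + 1j * np.cos(a) * (np.exp(b) - np.exp(-b)))   # sin(a+bi)
log_z = 0.5 * np.log(a * a + b * b) + 1j * np.arctan2(b, a)   # log(a+bi), needs atan2

assert np.allclose([exp_z, cos_z, sin_z, log_z],
                   [np.exp(z), np.cos(z), np.sin(z), np.log(z)])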
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
return lhs_value->getType()->isIntegerTy()
? EmitIntegerBinaryOp(op, lhs_value, rhs_value,
primitive_util::IsSignedIntegralType(
op->operand(0)->shape().element_type()))
: EmitFloatBinaryOp(op, lhs_value, rhs_value);
PrimitiveType operand_type = op->operand(0)->shape().element_type();
if (lhs_value->getType()->isIntegerTy()) {
return EmitIntegerBinaryOp(
op, lhs_value, rhs_value,
primitive_util::IsSignedIntegralType(operand_type));
} else if (primitive_util::IsComplexType(operand_type)) {
return EmitComplexBinaryOp(op, lhs_value, rhs_value);
} else {
return EmitFloatBinaryOp(op, lhs_value, rhs_value);
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
switch (op->opcode()) {
// case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support
case HloOpcode::kComplex:
return ComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
return ir_builder_->CreateFAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
@ -305,6 +476,88 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
}
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
auto real = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {0});
};
auto imag = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {1});
};
switch (op->opcode()) {
case HloOpcode::kAdd:
return ComposeComplex(
op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)),
ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value)));
case HloOpcode::kSubtract:
return ComposeComplex(
op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)),
ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value)));
case HloOpcode::kMultiply:
return ComposeComplex(
op,
ir_builder_->CreateFSub(
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))),
ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))));
case HloOpcode::kDivide: {
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
auto rhs_sum_sq = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)),
ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value)));
auto type = rhs_sum_sq->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
return ir_builder_->CreateSelect(
oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero),
ComposeComplex(
op,
ir_builder_->CreateFDiv(
ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
ir_builder_->CreateFMul(imag(lhs_value),
imag(rhs_value))),
rhs_sum_sq),
ir_builder_->CreateFDiv(
ir_builder_->CreateFSub(
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)),
ir_builder_->CreateFMul(real(lhs_value),
imag(rhs_value))),
rhs_sum_sq)));
}
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
//
// We use ordered comparisons for everything except kNe, where we use an
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
return ir_builder_->CreateAnd(
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value),
real(rhs_value), ir_builder_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value),
imag(rhs_value), ir_builder_));
case HloOpcode::kNe:
return ir_builder_->CreateOr(
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value),
real(rhs_value), ir_builder_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value),
imag(rhs_value), ir_builder_));
// TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic
// case HloOpcode::kPower:
// // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2)
default:
return Unimplemented("binary complex op '%s'",
HloOpcodeString(op->opcode()).c_str());
}
}
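A numpy sketch (illustrative values only) of the multiply and divide expansions used above, plus the ordered/unordered comparison note: with a NaN component, == is false and != is true, matching C++'s semantics.

import numpy as np

x, y = np.complex64(2 - 3j), np.complex64(-1 + 0.5j)
a, b, c, d = x.real, x.imag, y.real, y.imag

prod = (a * c - b * d) + 1j * (a * d + b * c)                  # (a+bi)(c+di)
denom = c * c + d * d
quot = (a * c + b * d) / denom + 1j * (b * c - a * d) / denom  # (a+bi)/(c+di)
assert np.allclose([prod, quot], [x * y, x / y])

nan_z = complex(float("nan"), 0.0)
assert not (nan_z == nan_z) and (nan_z != nan_z)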
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
@ -396,7 +649,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
PrimitiveType prim_type, llvm::Value* value) const {
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_);
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
}
@ -619,7 +872,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
const {
PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
llvm::Type* param_ir_type =
llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_);
llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);
// Same values as PCG library
// https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
@ -783,7 +1036,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
return ir_builder_->CreateZExt(
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_));
module_));
}
default:
return InvalidArgument(
@ -806,9 +1059,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kTanh:
@ -821,6 +1076,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return EmitUnaryOp(hlo, operand_value);
};
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
@ -913,10 +1170,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
}
llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
llvm::PHINode* output = ir_builder_->CreatePHI(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
hlo->operands().size());
llvm::PHINode* output =
ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
hlo->shape().element_type(), module_),
hlo->operands().size());
auto prior_insert_point = ir_builder_->GetInsertPoint();
ir_builder_->SetInsertPoint(init_block);
@ -1075,7 +1332,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
// else -> return data from 'index'.
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
module_),
"ret_value_addr", ir_builder_);
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
slice_intersection, "slice_intersection", ir_builder_);
@ -1164,7 +1421,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
// }
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
module_),
"pad_result_addr", ir_builder_);
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
@ -1206,7 +1463,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
ir_builder_);
PrimitiveType primitive_type = hlo->shape().element_type();
llvm::Type* primitive_type_llvm =
llvm_ir::PrimitiveTypeToIrType(primitive_type, ir_builder_);
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
primitive_type_llvm, "dot_acc", ir_builder_);
ir_builder_->CreateStore(
@ -1239,7 +1496,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
llvm::Value* next_accumulator;
if (primitive_util::IsFloatingPointType(primitive_type)) {
if (primitive_util::IsComplexType(primitive_type)) {
auto real = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {0});
};
auto imag = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {1});
};
llvm::Value* product_real = ir_builder_->CreateFSub(
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value)));
llvm::Value* product_imag = ir_builder_->CreateFAdd(
ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)));
next_accumulator = ir_builder_->CreateInsertValue(
current_accumulator,
ir_builder_->CreateFAdd(real(current_accumulator), product_real),
{0});
next_accumulator = ir_builder_->CreateInsertValue(
next_accumulator,
ir_builder_->CreateFAdd(imag(current_accumulator), product_imag),
{1});
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
next_accumulator = ir_builder_->CreateFAdd(
current_accumulator,
ir_builder_->CreateFMul(lhs_value, rhs_value));
@ -1261,4 +1539,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
}
}
llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op,
llvm::Value* real,
llvm::Value* imag) const {
auto cplx_type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto complex = ir_builder_->CreateInsertValue(
llvm::ConstantAggregateZero::get(cplx_type), real, {0});
if (imag != nullptr) {
complex = ir_builder_->CreateInsertValue(complex, imag, {1});
}
return complex;
}
} // namespace xla

View File

@ -55,6 +55,7 @@ class ElementalIrEmitter {
const HloToElementGeneratorMap& operand_to_generator) const;
llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
llvm::Module* module() const { return module_; }
protected:
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
@ -63,6 +64,9 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const;
virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const;
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value,
@ -72,6 +76,10 @@ class ElementalIrEmitter {
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const;
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const;
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) const;
@ -109,6 +117,11 @@ class ElementalIrEmitter {
// compiled executable outside of the HLO code itself.
const HloModuleConfig& hlo_module_config_;
protected:
// Composes a complex struct. imag may be nullptr for simple cast operations.
llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real,
llvm::Value* imag) const;
private:
// Returns a ElementGenerator for a RNG HloInstruction.
llvm_ir::ElementGenerator MakeRngElementGenerator(

View File

@ -135,6 +135,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
case HloOpcode::kAtan2:
return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value},
{lhs_input_type, rhs_input_type},
output_type);
case HloOpcode::kRemainder: {
return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
{lhs_input_type, rhs_input_type},
@ -226,6 +230,112 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const {
PrimitiveType input_type = op->operand(0)->shape().element_type();
PrimitiveType component_type =
primitive_util::IsComplexType(input_type)
? primitive_util::ComplexComponentType(input_type)
: input_type;
auto real = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {0});
};
auto imag = [&](llvm::Value* x) {
return ir_builder_->CreateExtractValue(x, {1});
};
switch (op->opcode()) {
case HloOpcode::kLog: {
// log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
auto a = real(operand_value);
auto b = imag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
ir_builder_->CreateFMul(b, b));
TF_ASSIGN_OR_RETURN(
auto log_sum_sq,
EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type},
component_type));
TF_ASSIGN_OR_RETURN(
auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a},
{component_type, component_type},
component_type));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq),
angle);
}
// TODO(b/65408531): Implement kPower on GPU, where atan2 is available.
// case HloOpcode::kPower:
// // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di))
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
auto b = imag(operand_value);
TF_ASSIGN_OR_RETURN(
auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)},
{component_type}, component_type));
TF_ASSIGN_OR_RETURN(
auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
component_type));
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
component_type));
return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
ir_builder_->CreateFMul(exp_a, sin_b));
}
case HloOpcode::kCos: {
// cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
auto a = real(operand_value);
auto llvm_ty = a->getType();
TF_ASSIGN_OR_RETURN(
auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
{component_type}, component_type));
TF_ASSIGN_OR_RETURN(
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
component_type));
TF_ASSIGN_OR_RETURN(
auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
component_type));
auto half_exp_b =
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
return ComposeComplex(
op,
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
ir_builder_->CreateFMul(
sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
}
case HloOpcode::kSin: {
// sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b))
auto a = real(operand_value);
auto llvm_ty = a->getType();
TF_ASSIGN_OR_RETURN(
auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
{component_type}, component_type));
TF_ASSIGN_OR_RETURN(
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
component_type));
TF_ASSIGN_OR_RETURN(
auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
component_type));
auto half_exp_b =
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
auto half_exp_neg_b =
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
return ComposeComplex(
op,
ir_builder_->CreateFMul(
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
ir_builder_->CreateFMul(
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
}
default:
return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value);
}
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
const string& callee_name,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
@ -235,13 +345,12 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
std::vector<llvm::Type*> ir_input_types;
for (PrimitiveType input_type : input_types) {
ir_input_types.push_back(
llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_));
llvm_ir::PrimitiveTypeToIrType(input_type, module_));
}
llvm::FunctionType* callee_type = llvm::FunctionType::get(
llvm_ir::PrimitiveTypeToIrType(output_type,
ir_builder_), // The return type.
ir_input_types, // The parameter types.
false); // No variadic arguments.
llvm_ir::PrimitiveTypeToIrType(output_type, module_), // Return type.
ir_input_types, // Parameter types.
false); // No variadic arguments.
// Declares the callee if it is not declared already.
llvm::Function* callee = llvm::cast<llvm::Function>(
@ -315,7 +424,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_),
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
"reduce_window_accum_ptr", ir_builder_);
{
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
@ -377,7 +486,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
llvm::Value* accum_ptr =
ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
hlo->shape().element_type(), ir_builder()));
hlo->shape().element_type(), module_));
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
operand_to_generator.at(hlo->operand(1))({}));
ir_builder()->CreateStore(init_value, accum_ptr);
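
The EmitComplexUnaryOp hunk above lowers kLog, kExp, kCos, and kSin via the identities noted in its comments, calling libdevice for the real-valued pieces. A host-side sketch of the same formulas written with std::complex<float>, purely for checking the math; the function names and structure here are illustrative and not part of the commit:

#include <cmath>
#include <complex>

std::complex<float> Log(std::complex<float> z) {   // log(a+bi)
  float a = z.real(), b = z.imag();
  return {0.5f * std::log(a * a + b * b), std::atan2(b, a)};
}
std::complex<float> Exp(std::complex<float> z) {   // e^(a+bi)
  float e = std::exp(z.real());
  return {e * std::cos(z.imag()), e * std::sin(z.imag())};
}
std::complex<float> Cos(std::complex<float> z) {   // cos(a+bi)
  // The emitter computes e^b once and reuses 0.5*e^b and 0.5/e^b.
  float half_exp_b = 0.5f * std::exp(z.imag());
  float half_exp_neg_b = 0.5f / std::exp(z.imag());
  return {std::cos(z.real()) * (half_exp_neg_b + half_exp_b),
          std::sin(z.real()) * (half_exp_neg_b - half_exp_b)};
}
std::complex<float> Sin(std::complex<float> z) {   // sin(a+bi)
  float half_exp_b = 0.5f * std::exp(z.imag());
  float half_exp_neg_b = 0.5f / std::exp(z.imag());
  return {std::sin(z.real()) * (half_exp_b + half_exp_neg_b),
          std::cos(z.real()) * (half_exp_b - half_exp_neg_b)};
}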

View File

@ -54,6 +54,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
StatusOr<llvm::Value*> EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
StatusOr<llvm::Value*> EmitFloatBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value,
llvm::Value* rhs_value) const override;

View File

@ -102,7 +102,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
slice_result.ConsumeValueOrDie();
if (slice.allocation()->is_thread_local()) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_);
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
BindHloToIrValue(*non_io_hlo,
ir_builder_->CreateAlloca(pointee_type), index);
} else {
@ -124,18 +124,18 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_);
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
}
return llvm_ir::EmitGetTupleElement(
gte->shape(), gte->tuple_index(), /*alignment=*/1,
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_);
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
}
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
const ShapeIndex& shape_index,
llvm::Value* ir_value) {
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_);
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
llvm::Type* dest_type = pointee_type->getPointerTo();
llvm::Value* typed_ir_value;

View File

@ -36,10 +36,12 @@ class HloToIrBindings {
public:
HloToIrBindings(const HloModule& module,
const BufferAssignment* buffer_assignment,
llvm::IRBuilder<>* ir_builder, bool is_nested)
llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module,
bool is_nested)
: buffer_assignment_(buffer_assignment),
is_nested_(is_nested),
ir_builder_(ir_builder),
module_(llvm_module),
alias_analysis_(module, *buffer_assignment_,
&ir_builder_->getContext()) {}
@ -93,6 +95,7 @@ class HloToIrBindings {
const bool is_nested_;
llvm::IRBuilder<>* ir_builder_;
llvm::Module* module_;
// Stores the underlying llvm::IrArray for each HloInstruction.
// For an instruction that generates multiple outputs, the root will be a

View File

@ -53,9 +53,10 @@ namespace gpu {
IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
IrEmitterContext* ir_emitter_context, bool is_nested)
: ir_emitter_context_(ir_emitter_context),
ir_builder_(ir_emitter_context->llvm_module()->getContext()),
module_(ir_emitter_context->llvm_module()),
ir_builder_(module_->getContext()),
bindings_(ir_emitter_context->hlo_module(),
&ir_emitter_context->buffer_assignment(), &ir_builder_,
&ir_emitter_context->buffer_assignment(), &ir_builder_, module_,
is_nested),
hlo_module_config_(hlo_module_config) {
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
@ -71,18 +72,17 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
};
}
return EmitTargetElementLoop(
*hlo, GpuElementalIrEmitter(hlo_module_config_,
ir_emitter_context_->llvm_module(),
&ir_builder_, GetNestedComputer())
*hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
GetNestedComputer())
.MakeElementGenerator(hlo, operand_to_generator));
}
Status IrEmitter::HandleConstant(HloInstruction* constant,
const Literal& literal) {
llvm::Constant* initializer =
llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
*ir_emitter_context_->llvm_module(), initializer->getType(),
*module_, initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
/*Name=*/"");
VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl
@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
get_tuple_element->shape(), get_tuple_element->tuple_index(),
// TODO(b/26344050): tighten the alignment here
// based on the real element type.
/*alignment=*/1, GetBasePointer(*operand), &ir_builder_));
/*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_));
return Status::OK();
}
@ -140,7 +140,7 @@ Status IrEmitter::HandleTuple(
for (const HloInstruction* operand : operands) {
base_ptrs.push_back(GetBasePointer(*operand));
}
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_);
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_);
return Status::OK();
}
@ -329,7 +329,7 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
if (ShapeUtil::IsTuple(select->shape())) {
llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
GetBasePointer(*on_true),
GetBasePointer(*on_false), &ir_builder_);
GetBasePointer(*on_false), &ir_builder_, module_);
return Status::OK();
}
@ -355,7 +355,26 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
llvm::Value* rhs_value =
rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value);
llvm::Value* result;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto real = [&](llvm::Value* x) {
return ir_builder_.CreateExtractValue(x, {0});
};
auto imag = [&](llvm::Value* x) {
return ir_builder_.CreateExtractValue(x, {1});
};
llvm::Value* real_result = ir_builder_.CreateFSub(
ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)),
ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value)));
llvm::Value* imag_result = ir_builder_.CreateFAdd(
ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)),
ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value)));
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
result = ir_builder_.CreateInsertValue(result, real_result, {0});
result = ir_builder_.CreateInsertValue(result, imag_result, {1});
} else {
result = ir_builder_.CreateFMul(lhs_value, rhs_value);
}
target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_);
return Status::OK();
}
@ -411,8 +430,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
// Initialize the accumulator in the preheader to zero.
new llvm::StoreInst(
llvm::ConstantFP::get(accum_type, 0.0), // The value stored.
accum_address, // The address.
llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0
accum_address, // The address.
reduction_loop->GetPreheaderBasicBlock()
->getTerminator()); // The instruction this store is inserted before.
@ -427,9 +446,27 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_);
llvm::Value* rhs_element =
rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_);
llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
llvm::Value* accum = ir_builder_.CreateLoad(accum_address);
llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product);
llvm::Value* updated_accum;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
#define REAL(x) ir_builder_.CreateExtractValue(x, {0})
#define IMAG(x) ir_builder_.CreateExtractValue(x, {1})
llvm::Value* product_real = ir_builder_.CreateFSub(
ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)),
ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element)));
llvm::Value* product_imag = ir_builder_.CreateFAdd(
ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)),
ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element)));
updated_accum = ir_builder_.CreateInsertValue(
accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0});
updated_accum = ir_builder_.CreateInsertValue(
updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1});
#undef IMAG
#undef REAL
} else {
llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
updated_accum = ir_builder_.CreateFAdd(accum, product);
}
ir_builder_.CreateStore(updated_accum, accum_address);
// After the reduction loop exits, store the accumulator into the target
@ -494,7 +531,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
// Initialize an accumulator with init_value.
llvm::AllocaInst* accumulator_addr =
ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
reduce->shape().element_type(), &ir_builder_));
reduce->shape().element_type(), module_));
ir_builder_.CreateStore(
ir_builder_.CreateLoad(GetBasePointer(*init_value)),
accumulator_addr);
@ -547,8 +584,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand));
}
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
ir_emitter_context_->llvm_module(),
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_,
&ir_builder_, GetNestedComputer());
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
@ -591,9 +627,8 @@ Status IrEmitter::HandleRng(HloInstruction* random,
// Emits a single-threaded loop because the loop body generated by the element
// generator for Rng can't be parallelized (b/32333178).
return llvm_ir::LoopEmitter(
GpuElementalIrEmitter(hlo_module_config_,
ir_emitter_context_->llvm_module(),
&ir_builder_, GetNestedComputer())
GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
GetNestedComputer())
.MakeElementGenerator(random, operand_to_generator),
GetIrArray(*random), &ir_builder_)
.EmitLoop(IrName(random));
@ -634,7 +669,7 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(
computation.root_instruction()->shape().element_type(), &ir_builder_),
computation.root_instruction()->shape().element_type(), module_),
"return_buffer", &ir_builder_);
std::vector<llvm::Value*> parameter_buffers;
for (llvm::Value* parameter_element : parameter_elements) {

View File

@ -162,6 +162,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
}
IrEmitterContext* ir_emitter_context_;
llvm::Module* module_;
// The following fields track the IR emission state. According to LLVM memory
// management rules, their memory is owned by the module.

View File

@ -52,9 +52,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
io_hlos->push_back(param);
const Shape& param_shape = param->shape();
argument_types.push_back(
llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo());
int64 param_size = llvm_ir::ByteSizeOf(
param_shape, ir_emitter_context_->llvm_module()->getDataLayout());
llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
int64 param_size =
llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
argument_dereferenceable_bytes.push_back(param_size);
}
{
@ -62,7 +62,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
io_hlos->push_back(root);
const Shape& root_shape = root->shape();
argument_types.push_back(
llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo());
llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
int64 root_size = llvm_ir::ByteSizeOf(
root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
argument_dereferenceable_bytes.push_back(root_size);

View File

@ -757,8 +757,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
auto loop_body_emitter =
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), &ir_builder_);
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
{
@ -973,7 +973,7 @@ Status IrEmitterUnnested::EmitRowReduction(
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), &ir_builder_);
input_shape.element_type(), ir_emitter_context_->llvm_module());
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
{
@ -1360,7 +1360,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// boolean flag if the value is initialized. The initialized_flag is set
// false.
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(operand_element_type,
ir_emitter_context_->llvm_module()),
"selected_value_address", &ir_builder_);
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
@ -1440,7 +1441,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
llvm_ir::PrimitiveTypeToIrType(PRED,
ir_emitter_context_->llvm_module()),
"select_return_buffer", &ir_builder_);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*select_and_scatter->select(),
@ -1450,8 +1452,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
llvm::Value* cond = ir_builder_.CreateICmpNE(
result, llvm::ConstantInt::get(
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
result,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
PRED, ir_emitter_context_->llvm_module()),
0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
@ -1877,7 +1881,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
}
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_);
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_,
module_);
return Status::OK();
}

View File

@ -50,6 +50,12 @@ namespace xla {
namespace {
template <typename T>
struct is_complex_t : public std::false_type {};
template <>
struct is_complex_t<complex64> : public std::true_type {};
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
const Literal& lhs_literal,
@ -101,6 +107,37 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
return std::move(result);
}
template <>
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
const Literal& rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
return lhs_el != rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
}
auto result = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(result->Populate<bool>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return compare_op(lhs_literal.Get<complex64>(multi_index),
rhs_literal.Get<complex64>(multi_index));
}));
return std::move(result);
}
template <typename ReturnT, typename NativeT>
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
@ -138,7 +175,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
Status DefaultAction(HloInstruction* hlo_instruction) override {
return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
HloOpcodeString(hlo_instruction->opcode()).c_str());
};
}
// TODO(b/35950897): many of the stl functions used in the handlers are not
// overloaded for every XLA primitive types.
@ -156,7 +193,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
template <
typename NativeT,
typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
typename std::enable_if<std::is_signed<NativeT>::value ||
is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
@ -169,7 +207,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return HandleAbs<ReturnT>(abs, operand);
}
Status HandleRound(HloInstruction* round) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleRound(HloInstruction* round) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[round],
ElementWiseUnaryOp(round, [](ReturnT elem_operand) {
return std::round(elem_operand);
@ -177,6 +218,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleRound(HloInstruction* round) {
return InvalidArgument("Unsupported type for Round");
}
Status HandleRound(HloInstruction* round) override {
return HandleRound<ReturnT>(round);
}
Status HandleBroadcast(HloInstruction* broadcast) override {
parent_->evaluated_[broadcast] =
Literal::CreateFromShape(broadcast->shape());
@ -205,15 +257,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
}
return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
});
};
}
Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleCeil(HloInstruction* ceil) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) {
return std::ceil(elem_operand);
}));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleCeil(HloInstruction* ceil) {
return InvalidArgument("Unsupported type for Ceil");
}
Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
return HandleCeil<ReturnT>(ceil);
}
Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0);
@ -237,15 +303,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::exp(elem_operand);
}));
return Status::OK();
};
}
Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleFloor(HloInstruction* floor) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor],
ElementWiseUnaryOp(floor, [](ReturnT elem_operand) {
return std::floor(elem_operand);
}));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleFloor(HloInstruction* floor) {
return InvalidArgument("Unsupported type for Floor");
}
Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
return HandleFloor<ReturnT>(floor);
}
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
@ -253,15 +333,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::log(elem_operand);
}));
return Status::OK();
};
}
Status HandleNot(HloInstruction* not_, HloInstruction* operand) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleNot(HloInstruction* not_) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
return !elem_operand;
}));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleNot(HloInstruction* not_) {
return InvalidArgument("Unsupported type for Not");
}
Status HandleNot(HloInstruction* not_, HloInstruction* operand) override {
return HandleNot<ReturnT>(not_);
}
Status HandleNegate(HloInstruction* negate,
HloInstruction* operand) override {
@ -270,16 +364,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return -elem_operand;
}));
return Status::OK();
};
}
Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleSign(HloInstruction* sign) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
return (ReturnT(0) < elem_operand) -
(elem_operand < ReturnT(0));
}));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleSign(HloInstruction* sign) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
auto abs_val = std::abs(elem_operand);
return 0 == abs_val ? ReturnT(0)
: elem_operand / abs_val;
}));
return Status::OK();
}
Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
return HandleSign<ReturnT>(sign);
}
Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
@ -287,7 +401,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::tanh(elem_operand);
}));
return Status::OK();
};
}
Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
HloInstruction* rhs) override {
@ -297,7 +411,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return lhs_elem * rhs_elem;
}));
return Status::OK();
};
}
Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
HloInstruction* rhs) override {
@ -307,7 +421,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return lhs_elem - rhs_elem;
}));
return Status::OK();
};
}
Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
HloInstruction* rhs) override {
@ -317,7 +431,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return lhs_elem + rhs_elem;
}));
return Status::OK();
};
}
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
HloInstruction* rhs) override {
@ -327,25 +441,53 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return lhs_elem / rhs_elem;
}));
return Status::OK();
};
}
Status HandleMaximum(HloInstruction* maximum) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleMaximum(HloInstruction* maximum) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[maximum],
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
return std::fmax(lhs, rhs);
}));
return Status::OK();
};
}
Status HandleMinimum(HloInstruction* minimum) override {
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleMaximum(HloInstruction* maximum) {
return InvalidArgument("Unsupported type for Maximum");
}
Status HandleMaximum(HloInstruction* maximum) override {
return HandleMaximum<ReturnT>(maximum);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleMinimum(HloInstruction* minimum) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[minimum],
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
return std::fmin(lhs_el, rhs_el);
}));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleMinimum(HloInstruction* minimum) {
return InvalidArgument("Unsupported type for Minimum");
}
Status HandleMinimum(HloInstruction* minimum) override {
return HandleMinimum<ReturnT>(minimum);
}
Status HandlePower(HloInstruction* power, HloInstruction* lhs,
HloInstruction* rhs) override {
@ -355,37 +497,79 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::pow(lhs_el, rhs_el);
}));
return Status::OK();
};
}
Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
HloInstruction* rhs) override {
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[remainder],
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
return std::fmod(lhs_el, rhs_el);
}));
return Status::OK();
};
}
Status HandleAnd(HloInstruction* and_, HloInstruction* lhs,
HloInstruction* rhs) override {
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
return InvalidArgument("Unsupported type for Remainder");
}
Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
HloInstruction* rhs) override {
return HandleRemainder<ReturnT>(remainder);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleAnd(HloInstruction* and_) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[and_],
ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el && rhs_el;
}));
return Status::OK();
};
}
Status HandleOr(HloInstruction* or_, HloInstruction* lhs,
HloInstruction* rhs) override {
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleAnd(HloInstruction* and_) {
return InvalidArgument("Unsupported type for And");
}
Status HandleAnd(HloInstruction* and_, HloInstruction* lhs,
HloInstruction* rhs) override {
return HandleAnd<ReturnT>(and_);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleOr(HloInstruction* or_) {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[or_],
ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el || rhs_el;
}));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleOr(HloInstruction* or_) {
return InvalidArgument("Unsupported type for Or");
}
Status HandleOr(HloInstruction* or_, HloInstruction* lhs,
HloInstruction* rhs) override {
return HandleOr<ReturnT>(or_);
}
template <typename NativeT,
typename std::enable_if<
@ -474,8 +658,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs);
}
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) override {
HloInstruction* arg, HloInstruction* max) {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
[](ReturnT low, ReturnT high, ReturnT value) {
return std::fmax(low, std::fmin(value, high));
@ -483,7 +670,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
return Status::OK();
};
}
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) {
return InvalidArgument("Unsupported type for Clamp");
}
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) override {
return HandleClamp<ReturnT>(clamp, min, arg, max);
}
Status HandleSelect(HloInstruction* select, HloInstruction* pred,
HloInstruction* on_true,
@ -499,7 +699,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
ElementWiseTernaryOp(select, std::move(select_op)));
return Status::OK();
};
}
Status HandleReverse(HloInstruction* reverse,
HloInstruction* operand) override {
@ -529,7 +729,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
parent_->evaluated_[reverse] = std::move(result);
return Status::OK();
};
}
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override {
@ -652,7 +852,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
};
}
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
HloInstruction* rhs) override {
@ -719,7 +919,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
parent_->evaluated_[dot] = std::move(result);
return Status::OK();
};
}
Status HandlePad(HloInstruction* pad) override {
CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
@ -788,7 +988,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
parent_->evaluated_[pad] = std::move(result);
return Status::OK();
};
}
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* operand,
@ -841,7 +1041,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
}
return Status::OK();
};
}
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* operand,
@ -897,7 +1097,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
}
return Status::OK();
};
}
Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
HloInstruction* init_value,
@ -985,7 +1185,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
parent_->evaluated_[reduce] = std::move(result);
return Status::OK();
};
}
Status HandleReduceWindow(HloInstruction* reduce_window,
HloInstruction* operand, const Window& window,
@ -1072,7 +1272,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
parent_->evaluated_[reduce_window] = std::move(result);
return Status::OK();
};
}
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override {
const Shape& shape = slice->shape();
@ -1101,7 +1301,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
};
}
private:
template <typename IndexT>
@ -1244,35 +1444,33 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
}
HloEvaluator* parent_;
}; // namespace xla
}; // class HloEvaluator::TypedVisitor
HloEvaluator::HloEvaluator() {
typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: U16.");
return Unimplemented("HloEvaluator: unhandled primitive type: U16.");
});
typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: S16.");
return Unimplemented("HloEvaluator: unhandled primitive type: S16.");
});
typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: F16.");
return Unimplemented("HloEvaluator: unhandled primitive type: F16.");
});
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: C64.");
});
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: TUPLE.");
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
});
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: OPAQUE.");
return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE.");
});
}
@ -1573,6 +1771,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
evaluated_[compare],
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case C64: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<complex64>(compare->shape(), opcode,
lhs_literal, rhs_literal));
} break;
default:
LOG(FATAL) << "HandleCompare: unknown primitive type: "
<< PrimitiveType_Name(lhs->shape().element_type());
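
The TypedVisitor hunks above repeat one pattern: each handler that has no complex overload becomes a pair of SFINAE-constrained member templates plus a thin override that dispatches on ReturnT. A condensed, self-contained sketch of that dispatch, assuming simplified names (Status replaced by std::string, the template pair renamed HandleFloorImpl, whereas the real change overloads the same HandleFloor name):

#include <complex>
#include <string>
#include <type_traits>

template <typename T>
struct is_complex_t : std::false_type {};
template <>
struct is_complex_t<std::complex<float>> : std::true_type {};

template <typename ReturnT>
class TypedVisitorSketch {
 public:
  // Mirrors the virtual override: forwards to whichever template is enabled.
  std::string HandleFloor() { return HandleFloorImpl<ReturnT>(); }

 private:
  // Enabled only for non-complex element types.
  template <typename NativeT,
            typename std::enable_if<!is_complex_t<NativeT>::value>::type* =
                nullptr>
  std::string HandleFloorImpl() {
    return "floor ok";
  }
  // Enabled only for complex element types.
  template <typename NativeT,
            typename std::enable_if<is_complex_t<NativeT>::value>::type* =
                nullptr>
  std::string HandleFloorImpl() {
    return "Unsupported type for Floor";
  }
};

With this, TypedVisitorSketch<float>().HandleFloor() selects the real-typed overload, while TypedVisitorSketch<std::complex<float>>().HandleFloor() selects the complex one, which is how the evaluator keeps a single TypedVisitor<complex64> instantiation while rejecting ops like Floor, Ceil, Round, and Clamp for C64.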

View File

@ -826,8 +826,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCos:
case HloOpcode::kDivide:
@ -836,6 +838,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kIndex:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
@ -850,6 +853,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:

View File

@ -219,10 +219,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSort:
@ -241,26 +243,28 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
// Only certain opcodes are supported with CreateBinary: opcodes of binary
// instructions with no auxiliary fields.
switch (opcode) {
case (HloOpcode::kAdd):
case (HloOpcode::kDivide):
case (HloOpcode::kDot):
case (HloOpcode::kEq):
case (HloOpcode::kGe):
case (HloOpcode::kGt):
case (HloOpcode::kLe):
case (HloOpcode::kLt):
case (HloOpcode::kMaximum):
case (HloOpcode::kMinimum):
case (HloOpcode::kMultiply):
case (HloOpcode::kNe):
case (HloOpcode::kPower):
case (HloOpcode::kRemainder):
case (HloOpcode::kSubtract):
case (HloOpcode::kAnd):
case (HloOpcode::kOr):
case (HloOpcode::kShiftLeft):
case (HloOpcode::kShiftRightArithmetic):
case (HloOpcode::kShiftRightLogical):
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
case HloOpcode::kDot:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
break;
default:
LOG(FATAL) << "Invalid binary instruction opcode "
@ -978,11 +982,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSort:
@ -992,6 +998,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break;
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
@ -1403,10 +1411,12 @@ bool HloInstruction::IdenticalSlowPath(
// The result of these instructions only depend upon their opcode and
// operands.
case HloOpcode::kAbs:
case HloOpcode::kAtan2:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kAdd:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kCrossReplicaSum:
@ -1417,6 +1427,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
@ -1430,6 +1441,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
@ -2117,6 +2129,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
switch (opcode_) {
case HloOpcode::kAbs:
return visitor->HandleAbs(this, operands_[0]);
case HloOpcode::kAtan2:
return visitor->HandleAtan2(this, operands_[0], operands_[1]);
case HloOpcode::kRoundNearestAfz:
return visitor->HandleRound(this);
case HloOpcode::kBatchNormTraining:
@ -2140,6 +2154,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kLt:
case HloOpcode::kNe:
return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]);
case HloOpcode::kComplex:
return visitor->HandleComplex(this, operands_[0], operands_[1]);
case HloOpcode::kAdd:
return visitor->HandleAdd(this, operands_[0], operands_[1]);
case HloOpcode::kDivide:
@ -2214,6 +2230,10 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleCos(this, operands_[0]);
case HloOpcode::kSin:
return visitor->HandleSin(this, operands_[0]);
case HloOpcode::kReal:
return visitor->HandleReal(this, operands_[0]);
case HloOpcode::kImag:
return visitor->HandleImag(this, operands_[0]);
case HloOpcode::kIsFinite:
return visitor->HandleIsFinite(this, operands_[0]);
case HloOpcode::kNot:
@ -2305,7 +2325,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
//
// We need to keep track of both the id and the instruction because
// instructions can get deleted while they are on the stack, so we
// can't always use the (potentiall dead) instruction object to grab
// can't always use the (potentially dead) instruction object to grab
// its id.
DFSStack dfs_stack;
dfs_stack.emplace_back(root->unique_id(), root);
@ -2505,6 +2525,7 @@ bool HloInstruction::IsElementwiseBinary() const {
// Binary elementwise operations. If you update this, please update
// IsElementwise() accordingly.
case HloOpcode::kAdd:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
@ -2537,6 +2558,7 @@ bool HloInstruction::IsElementwise() const {
// Unary elementwise operations.
case HloOpcode::kAbs:
case HloOpcode::kAtan2:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kCeil:
case HloOpcode::kConvert:
@ -2544,10 +2566,12 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kSign:
case HloOpcode::kSin:
@ -2557,6 +2581,7 @@ bool HloInstruction::IsElementwise() const {
// Binary elementwise operations, the same as in IsElementwiseBinary().
// If you update this, please update IsElementwiseBinary() accordingly.
case HloOpcode::kAdd:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:

View File

@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "abs";
case HloOpcode::kAdd:
return "add";
case HloOpcode::kAtan2:
return "atan2";
case HloOpcode::kBatchNormTraining:
return "batch-norm-training";
case HloOpcode::kBatchNormInference:
@ -47,6 +49,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "call";
case HloOpcode::kClamp:
return "clamp";
case HloOpcode::kComplex:
return "complex";
case HloOpcode::kConcatenate:
return "concatenate";
case HloOpcode::kConstant:
@ -87,6 +91,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "get-tuple-element";
case HloOpcode::kGt:
return "greater-than";
case HloOpcode::kImag:
return "imag";
case HloOpcode::kIndex:
return "index";
case HloOpcode::kInfeed:
@ -125,6 +131,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "parameter";
case HloOpcode::kPower:
return "power";
case HloOpcode::kReal:
return "real";
case HloOpcode::kRecv:
return "recv";
case HloOpcode::kReduce:

View File

@ -31,6 +31,7 @@ namespace xla {
enum class HloOpcode {
kAbs,
kAdd,
kAtan2,
kBatchNormGrad,
kBatchNormInference,
kBatchNormTraining,
@ -39,6 +40,7 @@ enum class HloOpcode {
kCall,
kCeil,
kClamp,
kComplex,
kConcatenate,
kConstant,
kConvert,
@ -58,6 +60,7 @@ enum class HloOpcode {
kGe,
kGetTupleElement,
kGt,
kImag,
kIndex,
kInfeed,
kIsFinite,
@ -77,6 +80,7 @@ enum class HloOpcode {
kPad,
kParameter,
kPower,
kReal,
kRecv,
kReduce,
kReducePrecision,

View File

@ -59,6 +59,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
for (auto& invariant_checker : invariant_checkers_) {
VLOG(1) << " Invariant checker " << invariant_checker->name();
StatusOr<bool> changed_status = invariant_checker->Run(module);
VLOG(1) << " Invariant checker done " << invariant_checker->name();
if (!changed_status.ok()) {
VLOG(2) << "Module failed invariant check:";
XLA_VLOG_LINES(2, module->ToString());

View File

@ -64,6 +64,10 @@ class ShapeVerifier : public DfsHloVisitor {
}
Status HandleConvert(HloInstruction* convert) override {
if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) {
TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape()))
<< "Unsupported complex->real kConvert";
}
return CheckShape(convert, ShapeInference::InferConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));

View File

@ -32,17 +32,16 @@ namespace xla {
const HloInstruction& instruction) {
switch (instruction.opcode()) {
// Cheap instructions.
case HloOpcode::kAbs:
case HloOpcode::kAdd:
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConstant:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
@ -50,6 +49,7 @@ namespace xla {
case HloOpcode::kGe:
case HloOpcode::kGetTupleElement:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kInfeed:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
@ -64,6 +64,7 @@ namespace xla {
case HloOpcode::kNegate:
case HloOpcode::kOutfeed:
case HloOpcode::kPad:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
@ -72,15 +73,21 @@ namespace xla {
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
return false;
// Cheap instructions for reals, but expensive for complex.
case HloOpcode::kAbs:
case HloOpcode::kCos:
case HloOpcode::kSign:
case HloOpcode::kSin:
return ShapeUtil::ElementIsComplex(instruction.shape());
// Expensive instructions.
case HloOpcode::kAtan2:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:

View File

@ -75,7 +75,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
Status FusedIrEmitter::HandleConstant(HloInstruction* constant,
const Literal& literal) {
llvm::Constant* initializer =
llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_);
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
llvm::GlobalVariable* global = new llvm::GlobalVariable(
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
@ -101,7 +101,7 @@ Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
get_tuple_element->shape(), get_tuple_element->tuple_index(),
/*alignment=*/1, it->second, ir_builder_);
/*alignment=*/1, it->second, ir_builder_, module_);
gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
// Emit code to read base tuple element array (if non-tuple shaped).
if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
@ -134,7 +134,7 @@ Status FusedIrEmitter::HandleTuple(
std::vector<llvm::Type*> operand_elemental_ir_types;
for (HloInstruction* operand : operands) {
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
operand->shape().element_type(), ir_builder_));
operand->shape().element_type(), module_));
}
generators_[tuple] =
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {

View File

@ -42,7 +42,8 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
ElementalIrEmitter* elemental_emitter)
: parameter_arrays_(parameter_arrays),
elemental_emitter_(elemental_emitter),
ir_builder_(elemental_emitter->ir_builder()) {}
ir_builder_(elemental_emitter->ir_builder()),
module_(elemental_emitter->module()) {}
Status DefaultAction(HloInstruction* hlo) override;
@ -85,6 +86,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
// Borrowed
llvm::IRBuilder<>* ir_builder_;
llvm::Module* module_;
// Map from instruction pointers to functions to generate elements of their
// outputs

View File

@ -229,9 +229,11 @@ llvm::Value* IrArray::EmitArrayElementAddress(
}
if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
llvm::Module* module =
ir_builder->GetInsertBlock()->getParent()->getParent();
return ir_builder->CreateInBoundsGEP(
ir_builder->CreateBitCast(
base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder)
base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module)
->getPointerTo()),
{index.linear()}, llvm_ir::AsStringRef(name));
}
@ -281,7 +283,8 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
IrArray IrArray::CastToShape(const Shape& new_shape,
llvm::IRBuilder<>* ir_builder) const {
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder);
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
return IrArray(
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
new_shape);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Operator.h"
#include "llvm/Target/TargetOptions.h"
@ -38,6 +39,19 @@ limitations under the License.
namespace xla {
namespace llvm_ir {
namespace {
// Note, this function is only useful in an insertion context; in a global
// (e.g. constants) context it will CHECK fail.
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) {
auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock());
auto fn = CHECK_NOTNULL(block->getParent());
auto module = CHECK_NOTNULL(fn->getParent());
return module;
}
} // namespace
string AsString(const std::string& str) {
return string(str.data(), str.length());
}
@ -63,7 +77,7 @@ llvm::Value* EmitCallToIntrinsic(
for (auto type : overloaded_types) {
types.push_back(type);
}
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
llvm::Module* module = ModuleFromIRBuilder(ir_builder);
llvm::Function* intrinsic =
llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
std::vector<llvm::Value*> operands_vec;
@ -119,38 +133,53 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
}
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
llvm::IRBuilder<>* ir_builder) {
llvm::Module* module) {
switch (element_type) {
case PRED:
case S8:
case U8:
return ir_builder->getInt8Ty();
return llvm::Type::getInt8Ty(module->getContext());
case S16:
case U16:
return ir_builder->getInt16Ty();
return llvm::Type::getInt16Ty(module->getContext());
case S32:
case U32:
return ir_builder->getInt32Ty();
return llvm::Type::getInt32Ty(module->getContext());
case S64:
case U64:
return ir_builder->getInt64Ty();
return llvm::Type::getInt64Ty(module->getContext());
case F32:
return ir_builder->getFloatTy();
return llvm::Type::getFloatTy(module->getContext());
case F64:
return ir_builder->getDoubleTy();
return llvm::Type::getDoubleTy(module->getContext());
case C64: {
auto cplx_t = module->getTypeByName("complex64");
if (cplx_t == nullptr) {
// C++ standard dictates the memory layout of std::complex is contiguous
// real followed by imaginary. C++11 section 26.4 [complex.numbers]:
// If z is an lvalue expression of type cv std::complex<T> then the
// expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
// reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
// z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
// imaginary part of z.
return llvm::StructType::create(
"complex64", llvm::Type::getFloatTy(module->getContext()),
llvm::Type::getFloatTy(module->getContext()));
}
return cplx_t;
}
// A Tuple contains an array of pointers. Use i8*.
case TUPLE:
// An Opaque is like a void*, use i8*.
case OPAQUE:
return ir_builder->getInt8PtrTy();
return llvm::Type::getInt8PtrTy(module->getContext());
default:
LOG(FATAL) << "unsupported type " << element_type;
}
}
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) {
llvm::Type* result_type =
PrimitiveTypeToIrType(shape.element_type(), ir_builder);
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
if (ShapeUtil::IsTuple(shape)) {
// A tuple buffer is an array of pointers.
result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
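With PrimitiveTypeToIrType now keyed on the module, the C64 case resolves to a single named struct per LLVMContext: the first call creates "complex64" and later calls find it via getTypeByName. A rough sketch of the intended invariant (illustrative, not code from this change; `module` is assumed to be any llvm::Module being emitted into):
llvm::Type* t1 = llvm_ir::PrimitiveTypeToIrType(C64, module);
llvm::Type* t2 = llvm_ir::PrimitiveTypeToIrType(C64, module);
CHECK_EQ(t1, t2);  // the same named struct "complex64" is returned both times
// Its body is { float, float }: real part first, imaginary part second,
// matching the std::complex<float> layout cited in the comment above.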
@ -197,10 +226,10 @@ namespace {
// value down to zero).
llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
std::vector<int64>* multi_index,
llvm::IRBuilder<>* ir_builder) {
llvm::Module* module) {
const Shape& shape = literal.shape();
llvm::Type* ir_element_type =
llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder);
llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module);
if (dimension_index == -1) {
// Base case of the recursion. Index into the data field of the protobuf
// with the multi index.
@ -238,6 +267,16 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
value = llvm::ConstantFP::get(ir_element_type,
literal.Get<double>(*multi_index));
break;
case C64: {
complex64 x = literal.Get<complex64>(*multi_index);
value = llvm::ConstantStruct::get(
static_cast<llvm::StructType*>(ir_element_type),
llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
x.real()),
llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
x.imag()));
break;
}
default:
LOG(FATAL) << "unsupported type " << shape.element_type();
}
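Concretely, the C64 branch turns a complex64 literal such as (1.5, -2.0) into a two-field constant of the named struct type. A small standalone sketch of the same construction (illustrative only; `module` is assumed to be the llvm::Module in use):
complex64 x(1.5f, -2.0f);
llvm::Type* f32_ty = llvm_ir::PrimitiveTypeToIrType(F32, module);
auto* c64_ty = llvm::cast<llvm::StructType>(
    llvm_ir::PrimitiveTypeToIrType(C64, module));
llvm::Constant* c = llvm::ConstantStruct::get(
    c64_ty, llvm::ConstantFP::get(f32_ty, x.real()),
    llvm::ConstantFP::get(f32_ty, x.imag()));
// `c` is the IR constant { float 1.5, float -2.0 } of type %complex64.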
@ -256,8 +295,8 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
std::vector<llvm::Constant*> elements;
for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
(*multi_index)[dimension] = i;
elements.push_back(LiteralToConstant(literal, dimension_index - 1,
multi_index, ir_builder));
elements.push_back(
LiteralToConstant(literal, dimension_index - 1, multi_index, module));
}
llvm::Type* element_type;
@ -279,11 +318,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
} // namespace
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::IRBuilder<>* ir_builder) {
llvm::Module* module) {
std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
llvm::Constant* value = LiteralToConstant(
literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
&multi_index, ir_builder);
&multi_index, module);
return value;
}
@ -380,7 +419,8 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
// arrays. So we extend it to i8 so that it's addressable.
return ir_builder->CreateZExt(
comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder));
comparison_result,
llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder)));
}
// Internal helper that is called from emitted code to log an int64 value with a

View File

@ -127,11 +127,11 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
// Returns the LLVM type which represents the given XLA primitive type.
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
llvm::IRBuilder<>* ir_builder);
llvm::Module* module);
// Returns the LLVM type which represents the given XLA shape. For example,
// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder);
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
// Returns a value that represents a pointer to a global string constant that
// encodes the shape as a serialized protobuf.
@ -149,7 +149,7 @@ StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
// Converts a given literal to an IR Constant. Literals have known constant
// values at IR emission time.
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::IRBuilder<>* ir_builder);
llvm::Module* module);
// Inserts an allocate of the requested type at the entry point of the
// function that the builder is currently building. The insert point

View File

@ -31,14 +31,15 @@ namespace xla {
namespace llvm_ir {
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) {
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
llvm::Module* module) {
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
llvm::LoadInst* pred_value =
ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
llvm::Value* pred_cond = ir_builder->CreateICmpNE(
pred_value,
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0),
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0),
"boolean_predicate");
VLOG(2) << "HandleSelect for tuple:";
@ -71,11 +72,11 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
void EmitTuple(IrArray tuple,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
llvm::IRBuilder<>* ir_builder) {
llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
for (size_t i = 0; i < operands.size(); ++i) {
auto* store = ir_builder->CreateStore(
ir_builder->CreatePointerCast(operands[i],
PrimitiveTypeToIrType(TUPLE, ir_builder)),
PrimitiveTypeToIrType(TUPLE, module)),
ir_builder->CreateInBoundsGEP(
tuple.GetBasePointer(),
{ir_builder->getInt64(0), ir_builder->getInt64(i)}));
@ -85,7 +86,8 @@ void EmitTuple(IrArray tuple,
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* ir_builder) {
llvm::IRBuilder<>* ir_builder,
llvm::Module* module) {
llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
@ -98,7 +100,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
}
SetAlignmentMetadataForLoad(src_buffer, alignment);
llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder);
llvm::Type* element_type = ShapeToIrType(target_shape, module);
llvm::Value* ret_val =
ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
return ret_val;

View File

@ -60,13 +60,14 @@ namespace llvm_ir {
// tuple_on_true or tuple_on_false:
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder);
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
llvm::Module* module);
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand.
void EmitTuple(IrArray tuple,
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
llvm::IRBuilder<>* ir_builder);
llvm::IRBuilder<>* ir_builder, llvm::Module* module);
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand. A GetTupleElement instruction
@ -74,7 +75,8 @@ void EmitTuple(IrArray tuple,
// Returns an llvm value representing a pointer to the tuple element buffer.
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* ir_builder);
llvm::IRBuilder<>* ir_builder,
llvm::Module* module);
} // namespace llvm_ir
} // namespace xla

View File

@ -53,6 +53,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
return UNOP_EXP;
case HloOpcode::kFloor:
return UNOP_FLOOR;
case HloOpcode::kImag:
return UNOP_IMAG;
case HloOpcode::kIsFinite:
return UNOP_IS_FINITE;
case HloOpcode::kLog:
@ -61,6 +63,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
return UNOP_NOT;
case HloOpcode::kNegate:
return UNOP_NEGATE;
case HloOpcode::kReal:
return UNOP_REAL;
case HloOpcode::kRoundNearestAfz:
return UNOP_ROUND_NEAREST_AFZ;
case HloOpcode::kSign:
@ -81,6 +85,10 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
// opcode.
BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kAtan2:
return BINOP_ATAN2;
case HloOpcode::kComplex:
return BINOP_COMPLEX;
case HloOpcode::kDot:
return BINOP_DOT;
case HloOpcode::kMultiply:
@ -307,19 +315,41 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
switch (operation) {
case UNOP_FLOOR:
case UNOP_CEIL:
if (!ShapeUtil::ElementIsFloating(arg)) {
return InvalidArgument(
"expected element type in shape to be floating for floor/ceil "
"operation; got %s",
PrimitiveType_Name(arg.element_type()).c_str());
}
return arg;
case UNOP_COS:
case UNOP_SIN:
case UNOP_EXP:
case UNOP_LOG:
case UNOP_TANH:
if (!ShapeUtil::ElementIsFloating(arg)) {
if (!ShapeUtil::ElementIsFloating(arg) &&
!ShapeUtil::ElementIsComplex(arg)) {
return InvalidArgument(
"expected element type in shape to be floating for exp/log/tanh "
"operation; got %s",
"expected element type in shape to be floating or complex for "
"sin/cos/exp/log/tanh operation; got %s",
PrimitiveType_Name(arg.element_type()).c_str());
}
return arg;
case UNOP_REAL:
case UNOP_IMAG:
if (!ShapeUtil::ElementIsComplex(arg)) {
return InvalidArgument(
"expected element type in shape to be complex for real/imag "
"operation; got %s",
PrimitiveType_Name(arg.element_type()).c_str());
}
return ShapeUtil::ChangeElementType(arg, F32);
case UNOP_ABS:
if (ShapeUtil::ElementIsComplex(arg)) {
return ShapeUtil::ChangeElementType(
arg, primitive_util::ComplexComponentType(arg.element_type()));
}
return arg;
case UNOP_NEGATE:
case UNOP_ROUND_NEAREST_AFZ:
case UNOP_SIGN:
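In shape-inference terms, Real, Imag, and (for complex operands) Abs all map a C64 shape onto its F32 component type while preserving dimensions, whereas the transcendental unops keep the complex element type. The expected results, sketched as comments rather than as test code from this change:
Shape c64_vec = ShapeUtil::MakeShape(C64, {10});
// InferUnaryOpShape(UNOP_REAL, c64_vec)  => f32[10]
// InferUnaryOpShape(UNOP_IMAG, c64_vec)  => f32[10]
// InferUnaryOpShape(UNOP_ABS,  c64_vec)  => f32[10]
// InferUnaryOpShape(UNOP_EXP,  c64_vec)  => c64[10]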
@ -751,6 +781,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
case BINOP_MIN:
case BINOP_SUB:
case BINOP_ADD:
case BINOP_ATAN2:
case BINOP_POW:
case BINOP_DIV:
case BINOP_REM:
@ -761,6 +792,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions);
case BINOP_COMPLEX: {
if (!ShapeUtil::ElementIsFloating(lhs)) {
return InvalidArgument(
"expected element type in shape to be floating for complex compose "
"operation; got %s",
PrimitiveType_Name(lhs.element_type()).c_str());
}
TF_ASSIGN_OR_RETURN(const Shape& shape,
InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions));
if (lhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
} else {
return Unimplemented("complex component type not supported");
}
}
case BINOP_AND:
case BINOP_OR:
if (lhs.element_type() != PRED &&

View File

@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test {
// Some handy scalar shapes.
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
const Shape f64_ = ShapeUtil::MakeShape(F64, {});
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
// Some handy vector and matrix shapes of F32 type.
@ -251,6 +252,44 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
.ok());
}
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
const tensorflow::gtl::ArraySlice<int64>& bcast) {
return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
lhs, rhs, bcast);
};
// Inputs must be FP.
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
// Component types must match.
ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
// Only F32->C64 supported.
ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
// Validate correct uses.
Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
TF_ASSERT_OK_AND_ASSIGN(result,
complex_shape(vector_64_, matrix_32_64_, {1}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
TF_ASSERT_OK_AND_ASSIGN(result,
complex_shape(matrix_32_64_, vector_64_, {1}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
TF_ASSERT_OK_AND_ASSIGN(result,
complex_shape(matrix_32_64_, matrix_32_64_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
}
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});

View File

@ -55,6 +55,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
return HloOpcode::kExp;
case UNOP_FLOOR:
return HloOpcode::kFloor;
case UNOP_IMAG:
return HloOpcode::kImag;
case UNOP_IS_FINITE:
return HloOpcode::kIsFinite;
case UNOP_LOG:
@ -63,6 +65,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
return HloOpcode::kNot;
case UNOP_NEGATE:
return HloOpcode::kNegate;
case UNOP_REAL:
return HloOpcode::kReal;
case UNOP_ROUND_NEAREST_AFZ:
return HloOpcode::kRoundNearestAfz;
case UNOP_SIGN:
@ -80,6 +84,10 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
switch (binop) {
case BINOP_ATAN2:
return HloOpcode::kAtan2;
case BINOP_COMPLEX:
return HloOpcode::kComplex;
case BINOP_DOT:
return HloOpcode::kDot;
case BINOP_MUL:
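With these mappings in place, a client computation can compose a complex value from two real operands and project it back out. A rough sketch of the kind of graph this enables, assuming the corresponding ComputationBuilder methods (Complex, Real, Imag, Abs, Atan2) are exposed; the names below are illustrative, not taken from this change:
ComputationBuilder builder(client, "complex_demo");
auto re = builder.ConstantR1<float>({1.0f, 0.0f});
auto im = builder.ConstantR1<float>({0.0f, 1.0f});
auto z = builder.Complex(re, im);                              // c64[2]
auto magnitude = builder.Abs(z);                               // f32[2]
auto angle = builder.Atan2(builder.Imag(z), builder.Real(z));  // f32[2]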

View File

@ -272,6 +272,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
case U16:
case U32:
case U64:
case C64:
case TUPLE:
case OPAQUE:
return false;

View File

@ -361,8 +361,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
ComputationBuilder* builder, const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value,
"Floating point type required when specifying an ErrorSpec");
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@ -384,8 +385,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
ComputationBuilder* builder, const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value,
"Floating point type required when specifying an ErrorSpec");
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
@ -407,8 +409,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
ComputationBuilder* builder, const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value,
"Floating point type required when specifying an ErrorSpec");
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
Literal::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,

View File

@ -347,7 +347,7 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
}
TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
constexpr bool kLhsRowMajor = true;
constexpr bool kRhsRowMajor = true;
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
@ -357,7 +357,11 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
TestNonsquareMatrixDot<double>();
}
TEST_F(DotOperationTest, ConcurrentMatMul) {
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
TestNonsquareMatrixDot<complex64>();
}
XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
ComputationBuilder builder(client_, TestName());
auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});

View File

@ -41,7 +41,11 @@ class UnaryOpTest : public ClientLibraryTestBase {
auto arg = builder.ConstantR1<T>({});
auto abs = builder.Abs(arg);
ComputeAndCompareR1<T>(&builder, {}, {});
if (primitive_util::NativeToPrimitiveType<T>() == C64) {
ComputeAndCompareR1<float>(&builder, {}, {});
} else {
ComputeAndCompareR1<T>(&builder, {}, {});
}
}
template <typename T>
@ -80,14 +84,58 @@ int UnaryOpTest::inf<int>() {
return 2147483647;
}
template <>
void UnaryOpTest::AbsTestHelper<complex64>() {
ComputationBuilder builder(client_, TestName());
auto arg = builder.ConstantR1<complex64>({{-2, 0},
{0, 25},
{0, 0},
{-0.3f, 0.4f},
{0, inf<float>()},
{-inf<float>(), 0}});
auto abs = builder.Abs(arg);
std::unique_ptr<Literal> expected =
Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
template <>
void UnaryOpTest::SignTestHelper<complex64>() {
ComputationBuilder builder(client_, TestName());
auto arg = builder.ConstantR1<complex64>(
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
auto sign = builder.Sign(arg);
std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
template <>
void UnaryOpTest::SignAbsTestHelper<complex64>() {
ComputationBuilder builder(client_, TestName());
auto arg =
builder.ConstantR1<complex64>({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
auto sign = builder.Sign(arg);
auto abs = builder.Abs(arg);
builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg);
std::unique_ptr<Literal> expected =
Literal::CreateR1<complex64>({0, 0, 0, 0});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
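The helper above exercises the defining identity of the complex sign function: for nonzero z, sign(z) = z / |z|, so the expression it builds should vanish (and sign(0) = 0 keeps the zero entry at zero), which is exactly the all-zeros literal compared against within the given error bound:
sign(z) * |z| - z = (z / |z|) * |z| - z = 0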
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
AbsSize0TestHelper<int>();
AbsSize0TestHelper<float>();
AbsSize0TestHelper<complex64>();
}
XLA_TEST_F(UnaryOpTest, AbsTestR1) {
AbsTestHelper<int>();
AbsTestHelper<float>();
AbsTestHelper<complex64>();
}
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
@ -98,34 +146,44 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) {
auto absf = builder.Abs(argf);
auto argf0 = builder.ConstantR0<float>(-0.0f);
auto absf0 = builder.Abs(argf0);
builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
absi, PrimitiveType::F32)));
auto argc = builder.ConstantR0<complex64>({-0.3f, 0.4f});
auto absc = builder.Abs(argc);
builder.Add(builder.Add(absc, absf0),
builder.Add(absf, builder.ConvertElementType(absi, F32)));
ComputeAndCompareR0<float>(&builder, 8.0f, {});
ComputeAndCompareR0<float>(&builder, 8.5f, {});
}
XLA_TEST_F(UnaryOpTest, SignTestR0) {
ComputationBuilder builder(client_, TestName());
auto argi = builder.ConstantR0<int>(-5);
auto absi = builder.Sign(argi);
auto sgni = builder.Sign(argi); // -1
auto argf = builder.ConstantR0<float>(-4.0f);
auto absf = builder.Sign(argf);
auto sgnf = builder.Sign(argf); // -1
auto argf0 = builder.ConstantR0<float>(-0.0f);
auto absf0 = builder.Sign(argf0);
builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
absi, PrimitiveType::F32)));
auto sgnf0 = builder.Sign(argf0); // 0
auto argc = builder.ConstantR0<complex64>({-.3, .4});
auto sgnc = builder.Sign(argc); // (-.6, .8)
builder.Add(sgnc, builder.ConvertElementType(
builder.Add(builder.Add(sgnf0, sgnf),
builder.ConvertElementType(sgni, F32)),
C64));
ComputeAndCompareR0<float>(&builder, -2.0f, {});
std::unique_ptr<Literal> expected =
Literal::CreateR0<complex64>({-2.6f, 0.8f});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
XLA_TEST_F(UnaryOpTest, SignTestR1) {
SignTestHelper<int>();
SignTestHelper<float>();
SignTestHelper<complex64>();
}
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
SignAbsTestHelper<int>();
SignAbsTestHelper<float>();
SignAbsTestHelper<complex64>();
}
XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {

View File

@ -235,11 +235,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSort:
@ -256,6 +258,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
#define TENSORFLOW_COMPILER_XLA_TYPES_H_
#include <complex>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/types.h"
@ -35,7 +37,7 @@ using ::tensorflow::uint16;
using ::tensorflow::uint32;
using ::tensorflow::uint64;
typedef std::complex<float> complex64;
using complex64 = std::complex<float>;
using ::Eigen::half;
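The complex64 alias leans on the packed real-then-imaginary layout of std::complex<float>, which is what the IR-side "complex64" struct mirrors. A minimal compile-time sketch of that assumption (illustrative only, not part of this change; assumes this header is included):
static_assert(sizeof(xla::complex64) == 2 * sizeof(float),
              "complex64 packs exactly two F32 values");
// Per C++11 [complex.numbers], for an lvalue z of this type,
// reinterpret_cast<float(&)[2]>(z)[0] designates the real part and [1] the
// imaginary part.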

View File

@ -49,7 +49,7 @@ enum PrimitiveType {
F64 = 12;
// Complex values of fixed width.
C64 = 15;
C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
// A tuple is a polymorphic sequence; e.g. a shape that holds different
// sub-shapes. They are used for things like returning multiple values from a
@ -667,6 +667,12 @@ enum UnaryOperation {
// Elementwise, rounds x to nearest integral value, rounding half-way cases
// away from zero.
UNOP_ROUND_NEAREST_AFZ = 14;
// Elementwise, extract real component of complex x.
UNOP_REAL = 15;
// Elementwise, extract imaginary component of complex x.
UNOP_IMAG = 16;
}
message UnaryOpRequest {
@ -721,6 +727,12 @@ enum BinaryOperation {
BINOP_SHIFT_LEFT = 20;
BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
BINOP_SHIFT_RIGHT_LOGICAL = 22;
// Complex from real, imag.
BINOP_COMPLEX = 23;
// Computes the 4-quadrant arctangent of the y, x input arguments.
BINOP_ATAN2 = 24;
}
message BinaryOpRequest {