[XLA:CPU] [XLA:GPU] Adds compiler support for C64 primitive type, including relevant elementwise unary and binary op lowering for CPU and GPU.
We use a named LLVM struct "complex64", laid out the same as std::complex<float>. This named struct is accessed via the llvm::Module, which required changes to accessors of PrimitiveTypeToIrType & friends. Ops that require atan2 (in particular, angle and log) are only supported on GPU at this point. LLVM lacks a CPU intrinsic for atan or atan2, whereas libdevice provides this for GPU. PiperOrigin-RevId: 173676849
This commit is contained in:
parent
4ae245a7db
commit
4198e27be8
@ -144,8 +144,8 @@ TEST(XlaCompilationTest, UnsupportedTypes) {
|
|||||||
Node* a = ops::SourceOp(
|
Node* a = ops::SourceOp(
|
||||||
"Const", builder.opts()
|
"Const", builder.opts()
|
||||||
.WithName("A")
|
.WithName("A")
|
||||||
.WithAttr("dtype", DT_COMPLEX64)
|
.WithAttr("dtype", DT_COMPLEX128)
|
||||||
.WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
|
.WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
|
||||||
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
|
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
|
||||||
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
|
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
|
||||||
TF_EXPECT_OK(builder.ToGraph(graph.get()));
|
TF_EXPECT_OK(builder.ToGraph(graph.get()));
|
||||||
|
@ -50,8 +50,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
|
|||||||
|
|
||||||
// Kernel registrations
|
// Kernel registrations
|
||||||
|
|
||||||
constexpr std::array<DataType, 5> kAllXlaCpuTypes = {
|
constexpr std::array<DataType, 6> kAllXlaCpuTypes = {
|
||||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
|
||||||
|
|
||||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
||||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
|
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
|
||||||
|
@ -55,8 +55,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
|
|||||||
|
|
||||||
// Kernel registrations
|
// Kernel registrations
|
||||||
|
|
||||||
constexpr std::array<DataType, 5> kAllXlaGpuTypes = {
|
constexpr std::array<DataType, 6> kAllXlaGpuTypes = {
|
||||||
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
{DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
|
||||||
|
|
||||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
|
||||||
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
|
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
|
||||||
|
@ -23,6 +23,10 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
|||||||
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
|
||||||
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
|
load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
|
||||||
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
|
load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
|
||||||
|
load(
|
||||||
|
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||||
|
"tf_cuda_tests_tags",
|
||||||
|
)
|
||||||
|
|
||||||
generate_backend_suites()
|
generate_backend_suites()
|
||||||
|
|
||||||
@ -581,11 +585,12 @@ cc_library(
|
|||||||
|
|
||||||
tf_cuda_cc_test(
|
tf_cuda_cc_test(
|
||||||
name = "randomized_tests",
|
name = "randomized_tests",
|
||||||
|
size = "large",
|
||||||
# This test is randomized, so only run it if explicitly requested.
|
# This test is randomized, so only run it if explicitly requested.
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
"notap",
|
"notap",
|
||||||
],
|
] + tf_cuda_tests_tags(),
|
||||||
deps = [":randomized_tests_library"],
|
deps = [":randomized_tests_library"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,7 +46,9 @@ class ArgMinMaxTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual(result, expected)
|
self.assertAllEqual(result, expected)
|
||||||
|
|
||||||
def testArgMinMax(self):
|
def testArgMinMax(self):
|
||||||
for dtype in self.numeric_types:
|
# Complex numbers do not support argmin/argmax.
|
||||||
|
minmax_types = set(self.numeric_types) - set(self.complex_types)
|
||||||
|
for dtype in minmax_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
|
lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32),
|
||||||
np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
|
np.array([1, 10, 27, 3, 3, 4], dtype=dtype),
|
||||||
|
@ -94,6 +94,15 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
dtype(4),
|
dtype(4),
|
||||||
expected=np.array([[16], [81]], dtype=dtype))
|
expected=np.array([[16], [81]], dtype=dtype))
|
||||||
|
|
||||||
|
atan2_supported = self.device == "XLA_GPU"
|
||||||
|
if atan2_supported:
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.atan2,
|
||||||
|
np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
|
||||||
|
np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
|
||||||
|
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
gen_math_ops._reciprocal_grad,
|
gen_math_ops._reciprocal_grad,
|
||||||
np.array([4, -3, -2, 1], dtype=dtype),
|
np.array([4, -3, -2, 1], dtype=dtype),
|
||||||
@ -259,6 +268,7 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
dtype(7),
|
dtype(7),
|
||||||
expected=np.array([[-6], [-5]], dtype=dtype))
|
expected=np.array([[-6], [-5]], dtype=dtype))
|
||||||
|
|
||||||
|
if dtype not in self.complex_types: # min/max not supported for complex
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
math_ops.maximum,
|
math_ops.maximum,
|
||||||
np.array([1, 2], dtype=dtype),
|
np.array([1, 2], dtype=dtype),
|
||||||
@ -307,6 +317,8 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
dtype(7),
|
dtype(7),
|
||||||
expected=np.array([[70], [14]], dtype=dtype))
|
expected=np.array([[70], [14]], dtype=dtype))
|
||||||
|
|
||||||
|
# Complex support for squared_difference is incidental, see b/68205550
|
||||||
|
if dtype not in self.complex_types:
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
math_ops.squared_difference,
|
math_ops.squared_difference,
|
||||||
np.array([1, 2], dtype=dtype),
|
np.array([1, 2], dtype=dtype),
|
||||||
@ -334,6 +346,139 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
np.array([2, -1], dtype=dtype),
|
np.array([2, -1], dtype=dtype),
|
||||||
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
|
expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
|
||||||
|
|
||||||
|
def testComplexOps(self):
|
||||||
|
for dtype in self.complex_types:
|
||||||
|
ctypes = {np.complex64: np.float32}
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.complex,
|
||||||
|
np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),
|
||||||
|
np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]),
|
||||||
|
expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
|
||||||
|
np.array(
|
||||||
|
[[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]],
|
||||||
|
dtype=dtype),
|
||||||
|
np.array(
|
||||||
|
[[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype),
|
||||||
|
expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
gen_math_ops._real_div,
|
||||||
|
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
|
||||||
|
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[
|
||||||
|
1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
|
||||||
|
float("inf")
|
||||||
|
],
|
||||||
|
dtype=dtype))
|
||||||
|
|
||||||
|
# TODO(b/65408531): support+test pow for cplx
|
||||||
|
|
||||||
|
lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
|
||||||
|
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
|
||||||
|
self._testBinary(
|
||||||
|
gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs)
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
|
||||||
|
|
||||||
|
# TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs))
|
||||||
|
|
||||||
|
def testComplexMath(self):
|
||||||
|
for dtype in self.complex_types:
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.add,
|
||||||
|
np.array([1 + 3j, 2 + 7j], dtype=dtype),
|
||||||
|
np.array([10 - 4j, 20 + 17j], dtype=dtype),
|
||||||
|
expected=np.array([11 - 1j, 22 + 24j], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.add,
|
||||||
|
dtype(5 - 7j),
|
||||||
|
np.array([1 + 2j, 2 + 4j], dtype=dtype),
|
||||||
|
expected=np.array([6 - 5j, 7 - 3j], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.add,
|
||||||
|
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
|
||||||
|
dtype(7 + 5j),
|
||||||
|
expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.subtract,
|
||||||
|
np.array([1 + 3j, 2 + 7j], dtype=dtype),
|
||||||
|
np.array([10 - 4j, 20 + 17j], dtype=dtype),
|
||||||
|
expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.subtract,
|
||||||
|
dtype(5 - 7j),
|
||||||
|
np.array([1 + 2j, 2 + 4j], dtype=dtype),
|
||||||
|
expected=np.array([4 - 9j, 3 - 11j], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.subtract,
|
||||||
|
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
|
||||||
|
dtype(7 + 5j),
|
||||||
|
expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.multiply,
|
||||||
|
np.array([1 + 3j, 2 + 7j], dtype=dtype),
|
||||||
|
np.array([10 - 4j, 20 + 17j], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.multiply,
|
||||||
|
dtype(5 - 7j),
|
||||||
|
np.array([1 + 2j, 2 + 4j], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.multiply,
|
||||||
|
np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
|
||||||
|
dtype(7 + 5j),
|
||||||
|
expected=np.array(
|
||||||
|
[[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype))
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.div,
|
||||||
|
np.array([8 - 1j, 2 + 16j], dtype=dtype),
|
||||||
|
np.array([2 + 4j, 4 - 8j], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.div,
|
||||||
|
dtype(1 + 2j),
|
||||||
|
np.array([2 + 4j, 4 - 8j], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
math_ops.div,
|
||||||
|
np.array([2 + 4j, 4 - 8j], dtype=dtype),
|
||||||
|
dtype(1 + 2j),
|
||||||
|
expected=np.array(
|
||||||
|
[(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype))
|
||||||
|
|
||||||
|
# TODO(b/68205550): math_ops.squared_difference shouldn't be supported.
|
||||||
|
|
||||||
|
self._testBinary(
|
||||||
|
nn_ops.bias_add,
|
||||||
|
np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype),
|
||||||
|
np.array([2 + 6j, -1 - 3j], dtype=dtype),
|
||||||
|
expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype))
|
||||||
|
self._testBinary(
|
||||||
|
nn_ops.bias_add,
|
||||||
|
np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype),
|
||||||
|
np.array([2 + 1j, -1 + 2j], dtype=dtype),
|
||||||
|
expected=np.array(
|
||||||
|
[[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype))
|
||||||
|
|
||||||
def _testDivision(self, dtype):
|
def _testDivision(self, dtype):
|
||||||
"""Test cases for division operators."""
|
"""Test cases for division operators."""
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
@ -352,6 +497,7 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
dtype(2),
|
dtype(2),
|
||||||
expected=np.array([[5], [2]], dtype=dtype))
|
expected=np.array([[5], [2]], dtype=dtype))
|
||||||
|
|
||||||
|
if dtype not in self.complex_types: # floordiv unsupported for complex.
|
||||||
self._testBinary(
|
self._testBinary(
|
||||||
gen_math_ops._floor_div,
|
gen_math_ops._floor_div,
|
||||||
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
np.array([3, 3, -1, -9, -8], dtype=dtype),
|
||||||
@ -363,7 +509,7 @@ class BinaryOpsTest(XLATestCase):
|
|||||||
self._testDivision(dtype)
|
self._testDivision(dtype)
|
||||||
|
|
||||||
def testFloatDivision(self):
|
def testFloatDivision(self):
|
||||||
for dtype in self.float_types:
|
for dtype in self.float_types + self.complex_types:
|
||||||
self._testDivision(dtype)
|
self._testDivision(dtype)
|
||||||
|
|
||||||
def _testRemainder(self, dtype):
|
def _testRemainder(self, dtype):
|
||||||
|
@ -49,11 +49,15 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
|
|||||||
backend_deps = []
|
backend_deps = []
|
||||||
backend_data = []
|
backend_data = []
|
||||||
if backend == "cpu":
|
if backend == "cpu":
|
||||||
backend_args += ["--test_device=XLA_CPU",
|
backend_args += [
|
||||||
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
|
"--test_device=XLA_CPU",
|
||||||
|
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
|
||||||
|
]
|
||||||
elif backend == "gpu":
|
elif backend == "gpu":
|
||||||
backend_args += ["--test_device=XLA_GPU",
|
backend_args += [
|
||||||
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
|
"--test_device=XLA_GPU",
|
||||||
|
"--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
|
||||||
|
]
|
||||||
backend_tags += ["requires-gpu-sm35"]
|
backend_tags += ["requires-gpu-sm35"]
|
||||||
elif backend in plugins:
|
elif backend in plugins:
|
||||||
backend_args += ["--test_device=" + plugins[backend]["device"],
|
backend_args += ["--test_device=" + plugins[backend]["device"],
|
||||||
|
@ -30,8 +30,6 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
_TEST_TYPES = [dtypes.float32]
|
|
||||||
|
|
||||||
|
|
||||||
class GatherTest(xla_test.XLATestCase):
|
class GatherTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
@ -46,7 +44,7 @@ class GatherTest(xla_test.XLATestCase):
|
|||||||
def testScalar1D(self):
|
def testScalar1D(self):
|
||||||
with self.test_session() as session, self.test_scope():
|
with self.test_session() as session, self.test_scope():
|
||||||
data = np.array([0, 1, 2, 3, 7, 5])
|
data = np.array([0, 1, 2, 3, 7, 5])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in self.all_tf_types:
|
||||||
for indices in 4, [1, 2, 2, 4, 5]:
|
for indices in 4, [1, 2, 2, 4, 5]:
|
||||||
params_np = self._buildParams(data, dtype)
|
params_np = self._buildParams(data, dtype)
|
||||||
params = array_ops.placeholder(dtype=dtype)
|
params = array_ops.placeholder(dtype=dtype)
|
||||||
@ -60,7 +58,7 @@ class GatherTest(xla_test.XLATestCase):
|
|||||||
with self.test_session() as session, self.test_scope():
|
with self.test_session() as session, self.test_scope():
|
||||||
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
|
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
|
||||||
[12, 13, 14]])
|
[12, 13, 14]])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in self.all_tf_types:
|
||||||
for axis in 0, 1, -1:
|
for axis in 0, 1, -1:
|
||||||
params_np = self._buildParams(data, dtype)
|
params_np = self._buildParams(data, dtype)
|
||||||
params = array_ops.placeholder(dtype=dtype)
|
params = array_ops.placeholder(dtype=dtype)
|
||||||
@ -74,7 +72,7 @@ class GatherTest(xla_test.XLATestCase):
|
|||||||
with self.test_session() as session, self.test_scope():
|
with self.test_session() as session, self.test_scope():
|
||||||
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
|
data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
|
||||||
[12, 13, 14]])
|
[12, 13, 14]])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in self.all_tf_types:
|
||||||
for axis in 0, 1, -1:
|
for axis in 0, 1, -1:
|
||||||
params_np = self._buildParams(data, dtype)
|
params_np = self._buildParams(data, dtype)
|
||||||
params = array_ops.placeholder(dtype=dtype)
|
params = array_ops.placeholder(dtype=dtype)
|
||||||
@ -94,7 +92,7 @@ class GatherTest(xla_test.XLATestCase):
|
|||||||
[12, 13, 14]])
|
[12, 13, 14]])
|
||||||
# The indices must be in bounds for any axis.
|
# The indices must be in bounds for any axis.
|
||||||
indices_np = np.array([0, 1, 0, 2])
|
indices_np = np.array([0, 1, 0, 2])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in self.all_tf_types:
|
||||||
for axis in 0, 1, -1:
|
for axis in 0, 1, -1:
|
||||||
params_np = self._buildParams(data, dtype)
|
params_np = self._buildParams(data, dtype)
|
||||||
params = array_ops.placeholder(dtype=dtype)
|
params = array_ops.placeholder(dtype=dtype)
|
||||||
@ -112,7 +110,7 @@ class GatherTest(xla_test.XLATestCase):
|
|||||||
"""Check that scalar and empty indices shapes work as well."""
|
"""Check that scalar and empty indices shapes work as well."""
|
||||||
shape = (2, 1, 3, 2)
|
shape = (2, 1, 3, 2)
|
||||||
for indices_shape in (), (0,), (2, 0), (2, 3):
|
for indices_shape in (), (0,), (2, 0), (2, 3):
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in self.all_tf_types:
|
||||||
for axis in 0, 1, 2, 3, -1, -2:
|
for axis in 0, 1, 2, 3, -1, -2:
|
||||||
params = self._buildParams(np.random.randn(*shape), dtype)
|
params = self._buildParams(np.random.randn(*shape), dtype)
|
||||||
indices = np.random.randint(shape[axis], size=indices_shape)
|
indices = np.random.randint(shape[axis], size=indices_shape)
|
||||||
|
@ -68,6 +68,26 @@ class NAryOpsTest(XLATestCase):
|
|||||||
np.array([42], dtype=np.float32)],
|
np.array([42], dtype=np.float32)],
|
||||||
expected=np.array([48], dtype=np.float32))
|
expected=np.array([48], dtype=np.float32))
|
||||||
|
|
||||||
|
def testComplex(self):
|
||||||
|
for dtype in self.complex_types:
|
||||||
|
self._testNAry(
|
||||||
|
math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)],
|
||||||
|
expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype))
|
||||||
|
|
||||||
|
self._testNAry(
|
||||||
|
math_ops.add_n, [
|
||||||
|
np.array([1 + 2j, 2 - 3j], dtype=dtype),
|
||||||
|
np.array([10j, 20], dtype=dtype)
|
||||||
|
],
|
||||||
|
expected=np.array([1 + 12j, 22 - 3j], dtype=dtype))
|
||||||
|
self._testNAry(
|
||||||
|
math_ops.add_n, [
|
||||||
|
np.array([-4, 5j], dtype=dtype),
|
||||||
|
np.array([2 + 10j, -2], dtype=dtype),
|
||||||
|
np.array([42j, 3 + 3j], dtype=dtype)
|
||||||
|
],
|
||||||
|
expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype))
|
||||||
|
|
||||||
@unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
|
@unittest.skip("IdentityN is temporarily CompilationOnly as workaround")
|
||||||
def testIdentityN(self):
|
def testIdentityN(self):
|
||||||
self._testNAryLists(array_ops.identity_n,
|
self._testNAryLists(array_ops.identity_n,
|
||||||
|
@ -29,6 +29,9 @@ from tensorflow.python.platform import googletest
|
|||||||
class RandomOpsTest(XLATestCase):
|
class RandomOpsTest(XLATestCase):
|
||||||
"""Test cases for random-number generating operators."""
|
"""Test cases for random-number generating operators."""
|
||||||
|
|
||||||
|
def _random_types(self):
|
||||||
|
return set(self.numeric_types) - set(self.complex_types)
|
||||||
|
|
||||||
def _testRngIsNotConstant(self, rng, dtype):
|
def _testRngIsNotConstant(self, rng, dtype):
|
||||||
# Tests that 'rng' does not always return the same value.
|
# Tests that 'rng' does not always return the same value.
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -51,7 +54,8 @@ class RandomOpsTest(XLATestCase):
|
|||||||
def rng(dtype):
|
def rng(dtype):
|
||||||
return random_ops.random_uniform(shape=[2], dtype=dtype,
|
return random_ops.random_uniform(shape=[2], dtype=dtype,
|
||||||
maxval=1000000)
|
maxval=1000000)
|
||||||
for dtype in self.numeric_types:
|
|
||||||
|
for dtype in self._random_types():
|
||||||
self._testRngIsNotConstant(rng, dtype)
|
self._testRngIsNotConstant(rng, dtype)
|
||||||
|
|
||||||
def testRandomNormalIsNotConstant(self):
|
def testRandomNormalIsNotConstant(self):
|
||||||
@ -63,7 +67,7 @@ class RandomOpsTest(XLATestCase):
|
|||||||
self._testRngIsNotConstant(rng, dtype)
|
self._testRngIsNotConstant(rng, dtype)
|
||||||
|
|
||||||
def testRandomUniformIsInRange(self):
|
def testRandomUniformIsInRange(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self._random_types():
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
|
x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
|
||||||
|
@ -75,7 +75,7 @@ namespace {
|
|||||||
// Command line flags: see main() below.
|
// Command line flags: see main() below.
|
||||||
int64 tf_xla_random_seed = 0;
|
int64 tf_xla_random_seed = 0;
|
||||||
int32 tf_xla_test_repetitions = 20;
|
int32 tf_xla_test_repetitions = 20;
|
||||||
int64 tf_xla_max_tensor_size = 100000LL;
|
int64 tf_xla_max_tensor_size = 10000LL;
|
||||||
string* tf_xla_test_device_ptr; // initial value set in main()
|
string* tf_xla_test_device_ptr; // initial value set in main()
|
||||||
bool tf_xla_test_use_jit = true;
|
bool tf_xla_test_use_jit = true;
|
||||||
|
|
||||||
@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) {
|
|||||||
return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
|
return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr std::array<DataType, 3> kAllXlaTypes = {
|
constexpr std::array<DataType, 4> kAllXlaTypes = {
|
||||||
{DT_INT32, DT_FLOAT, DT_BOOL}};
|
{DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}};
|
||||||
|
|
||||||
// An OpTestBuilder is a graph builder class that takes as input an operator to
|
// An OpTestBuilder is a graph builder class that takes as input an operator to
|
||||||
// test, its inputs and attributes, and builds a graph that executes the
|
// test, its inputs and attributes, and builds a graph that executes the
|
||||||
@ -449,6 +449,13 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
|
|||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DT_COMPLEX64: {
|
||||||
|
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
|
||||||
|
test::FillFn<complex64>(&tensor, [this, &distribution](int i) {
|
||||||
|
return complex64(distribution(generator()), distribution(generator()));
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
case DT_INT32: {
|
case DT_INT32: {
|
||||||
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
|
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
|
||||||
test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
|
test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
|
||||||
@ -624,11 +631,47 @@ std::vector<int32> OpTest::AsInt32s(const std::vector<int64>& int64s) {
|
|||||||
|
|
||||||
// Functions for comparing tensors.
|
// Functions for comparing tensors.
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
double Abs(T x) {
|
||||||
|
return std::fabs(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
double Abs<complex64>(complex64 x) {
|
||||||
|
return std::abs(x);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool IsClose(const T& x, const T& y, double atol, double rtol) {
|
bool IsClose(const T& x, const T& y, double atol, double rtol) {
|
||||||
if (std::isnan(x) && std::isnan(y)) return true;
|
if (std::isnan(x) && std::isnan(y)) return true;
|
||||||
if (x == y) return true; // Allow inf == inf.
|
if (x == y) return true; // Allow inf == inf.
|
||||||
return fabs(x - y) < atol + rtol * fabs(x);
|
return Abs(x - y) < atol + rtol * Abs(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
|
||||||
|
double rtol) {
|
||||||
|
if (std::isnan(x.real()) && std::isnan(y.real())) {
|
||||||
|
if (std::isnan(x.imag()) && std::isnan(y.imag())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (x.imag() == y.imag()) return true; // Allow inf == inf.
|
||||||
|
return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag());
|
||||||
|
} else if (std::isnan(x.imag()) && std::isnan(y.imag())) {
|
||||||
|
if (x.real() == y.real()) return true; // Allow inf == inf.
|
||||||
|
return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real());
|
||||||
|
}
|
||||||
|
if (x == y) return true; // Allow inf == inf.
|
||||||
|
return Abs(x - y) < atol + rtol * Abs(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
string Str(T x) {
|
||||||
|
return strings::StrCat(x);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
string Str<complex64>(complex64 x) {
|
||||||
|
return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -639,9 +682,10 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
|
|||||||
for (int i = 0; i < Tx.size(); ++i) {
|
for (int i = 0; i < Tx.size(); ++i) {
|
||||||
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
|
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
|
||||||
return errors::InvalidArgument(strings::StrCat(
|
return errors::InvalidArgument(strings::StrCat(
|
||||||
i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i),
|
i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
|
||||||
". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol,
|
Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
|
||||||
" rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i))));
|
"atol = ", atol, " rtol = ", rtol,
|
||||||
|
" tol = ", atol + rtol * Abs(Tx(i))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -683,6 +727,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
|
|||||||
return TensorsAreCloseImpl<float>(a, b, atol, rtol);
|
return TensorsAreCloseImpl<float>(a, b, atol, rtol);
|
||||||
case DT_DOUBLE:
|
case DT_DOUBLE:
|
||||||
return TensorsAreCloseImpl<double>(a, b, atol, rtol);
|
return TensorsAreCloseImpl<double>(a, b, atol, rtol);
|
||||||
|
case DT_COMPLEX64:
|
||||||
|
return TensorsAreCloseImpl<complex64>(a, b, atol, rtol);
|
||||||
case DT_INT32:
|
case DT_INT32:
|
||||||
return TensorsAreEqualImpl<int32>(a, b);
|
return TensorsAreEqualImpl<int32>(a, b);
|
||||||
case DT_INT64:
|
case DT_INT64:
|
||||||
@ -822,7 +868,7 @@ Tensor AsIntTensor(DataType dtype, const std::vector<int64>& values) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Abs) {
|
TEST_F(OpTest, Abs) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Abs").RandomInput(type).Attr("T", type));
|
OpTestBuilder("Abs").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -837,7 +883,7 @@ TEST_F(OpTest, Acosh) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Add) {
|
TEST_F(OpTest, Add) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -848,7 +894,7 @@ TEST_F(OpTest, Add) {
|
|||||||
|
|
||||||
TEST_F(OpTest, AddN) {
|
TEST_F(OpTest, AddN) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
int n = std::uniform_int_distribution<int>(1, 5)(generator());
|
int n = std::uniform_int_distribution<int>(1, 5)(generator());
|
||||||
|
|
||||||
auto shape = RandomDims();
|
auto shape = RandomDims();
|
||||||
@ -890,9 +936,10 @@ TEST_F(OpTest, Any) {
|
|||||||
TEST_F(OpTest, ApproximateEqual) {
|
TEST_F(OpTest, ApproximateEqual) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = RandomDims();
|
auto dims = RandomDims();
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", DT_FLOAT));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -1038,6 +1085,7 @@ TEST_F(OpTest, AvgPool3DGrad) {
|
|||||||
|
|
||||||
TEST_F(OpTest, BatchMatMul) {
|
TEST_F(OpTest, BatchMatMul) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
std::vector<int64> output_dims = RandomDims(2, 5, 0, 7);
|
std::vector<int64> output_dims = RandomDims(2, 5, 0, 7);
|
||||||
int64 ndims = output_dims.size();
|
int64 ndims = output_dims.size();
|
||||||
int64 inner_dim = RandomDim();
|
int64 inner_dim = RandomDim();
|
||||||
@ -1056,9 +1104,9 @@ TEST_F(OpTest, BatchMatMul) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
|
||||||
.RandomInput(DT_FLOAT, x_dims)
|
.RandomInput(type, x_dims)
|
||||||
.RandomInput(DT_FLOAT, y_dims)
|
.RandomInput(type, y_dims)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("adj_x", adj_x)
|
.Attr("adj_x", adj_x)
|
||||||
.Attr("adj_y", adj_y));
|
.Attr("adj_y", adj_y));
|
||||||
});
|
});
|
||||||
@ -1090,10 +1138,11 @@ TEST_F(OpTest, BatchToSpace) {
|
|||||||
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
|
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
|
||||||
TensorShape({num_block_dims, 2})));
|
TensorShape({num_block_dims, 2})));
|
||||||
|
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
|
||||||
.RandomInput(DT_FLOAT, input_dims)
|
.RandomInput(type, input_dims)
|
||||||
.Input(crops)
|
.Input(crops)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("block_size", block_size));
|
.Attr("block_size", block_size));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -1127,13 +1176,14 @@ TEST_F(OpTest, BatchToSpaceND) {
|
|||||||
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
|
CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
|
||||||
TensorShape({num_block_dims, 2})));
|
TensorShape({num_block_dims, 2})));
|
||||||
|
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("BatchToSpaceND")
|
OpTestBuilder("BatchToSpaceND")
|
||||||
.RandomInput(DT_FLOAT, input_dims)
|
.RandomInput(type, input_dims)
|
||||||
.Input(test::AsTensor<int32>(
|
.Input(test::AsTensor<int32>(
|
||||||
std::vector<int32>(block_dims.begin(), block_dims.end())))
|
std::vector<int32>(block_dims.begin(), block_dims.end())))
|
||||||
.Input(crops)
|
.Input(crops)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1142,18 +1192,20 @@ TEST_F(OpTest, BiasAdd) {
|
|||||||
auto x_dims = RandomDims(2, kDefaultMaxRank);
|
auto x_dims = RandomDims(2, kDefaultMaxRank);
|
||||||
auto y_dims = {x_dims[x_dims.size() - 1]};
|
auto y_dims = {x_dims[x_dims.size() - 1]};
|
||||||
// TODO(phawkins): test both data formats.
|
// TODO(phawkins): test both data formats.
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
|
||||||
.RandomInput(DT_FLOAT, x_dims)
|
.RandomInput(type, x_dims)
|
||||||
.RandomInput(DT_FLOAT, y_dims)
|
.RandomInput(type, y_dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, BiasAddGrad) {
|
TEST_F(OpTest, BiasAddGrad) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
// TODO(phawkins): test both data formats.
|
// TODO(phawkins): test both data formats.
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("BiasAddGrad").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1161,10 +1213,11 @@ TEST_F(OpTest, BiasAddV1) {
|
|||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto x_dims = RandomDims(2, kDefaultMaxRank);
|
auto x_dims = RandomDims(2, kDefaultMaxRank);
|
||||||
auto y_dims = {x_dims[x_dims.size() - 1]};
|
auto y_dims = {x_dims[x_dims.size() - 1]};
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
|
||||||
.RandomInput(DT_FLOAT, x_dims)
|
.RandomInput(type, x_dims)
|
||||||
.RandomInput(DT_FLOAT, y_dims)
|
.RandomInput(type, y_dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1221,8 +1274,8 @@ TEST_F(OpTest, BroadcastGradientArgs) {
|
|||||||
TEST_F(OpTest, Cast) {
|
TEST_F(OpTest, Cast) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType src_type, dst_type;
|
DataType src_type, dst_type;
|
||||||
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
|
src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
|
||||||
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
|
dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
|
||||||
.RandomInput(src_type)
|
.RandomInput(src_type)
|
||||||
.Attr("SrcT", src_type)
|
.Attr("SrcT", src_type)
|
||||||
@ -1293,11 +1346,12 @@ TEST_F(OpTest, Conv2D) {
|
|||||||
|
|
||||||
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
|
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
|
||||||
features_in, features_out};
|
features_in, features_out};
|
||||||
|
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Conv2D")
|
OpTestBuilder("Conv2D")
|
||||||
.RandomInput(DT_FLOAT, data_dims)
|
.RandomInput(type, data_dims)
|
||||||
.RandomInput(DT_FLOAT, kernel_dims)
|
.RandomInput(type, kernel_dims)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||||
.Attr("data_format", "NHWC"));
|
.Attr("data_format", "NHWC"));
|
||||||
@ -1317,12 +1371,13 @@ TEST_F(OpTest, Conv2DBackpropFilter) {
|
|||||||
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
||||||
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
|
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
|
||||||
{d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
|
{d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
|
||||||
|
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Conv2DBackpropFilter")
|
OpTestBuilder("Conv2DBackpropFilter")
|
||||||
.RandomInput(DT_FLOAT, activations)
|
.RandomInput(type, activations)
|
||||||
.Input(kernel_shape)
|
.Input(kernel_shape)
|
||||||
.RandomInput(DT_FLOAT, backprop)
|
.RandomInput(type, backprop)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||||
.Attr("data_format", "NHWC"));
|
.Attr("data_format", "NHWC"));
|
||||||
@ -1342,12 +1397,13 @@ TEST_F(OpTest, Conv2DBackpropInput) {
|
|||||||
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
||||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||||
features_in, features_out};
|
features_in, features_out};
|
||||||
|
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Conv2DBackpropInput")
|
OpTestBuilder("Conv2DBackpropInput")
|
||||||
.Input(in_shape)
|
.Input(in_shape)
|
||||||
.RandomInput(DT_FLOAT, kernel)
|
.RandomInput(type, kernel)
|
||||||
.RandomInput(DT_FLOAT, backprop)
|
.RandomInput(type, backprop)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||||
.Attr("data_format", "NHWC"));
|
.Attr("data_format", "NHWC"));
|
||||||
@ -1365,11 +1421,12 @@ TEST_F(OpTest, Conv3D) {
|
|||||||
|
|
||||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||||
d.kernel_dims[2], features_in, features_out};
|
d.kernel_dims[2], features_in, features_out};
|
||||||
|
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Conv3D")
|
OpTestBuilder("Conv3D")
|
||||||
.RandomInput(DT_FLOAT, data)
|
.RandomInput(type, data)
|
||||||
.RandomInput(DT_FLOAT, kernel)
|
.RandomInput(type, kernel)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||||
});
|
});
|
||||||
@ -1389,12 +1446,13 @@ TEST_F(OpTest, Conv3DBackpropFilter) {
|
|||||||
Tensor kernel_shape = test::AsTensor<int32>(
|
Tensor kernel_shape = test::AsTensor<int32>(
|
||||||
AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
|
AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
|
||||||
features_in, features_out}));
|
features_in, features_out}));
|
||||||
|
DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Conv3DBackpropFilterV2")
|
OpTestBuilder("Conv3DBackpropFilterV2")
|
||||||
.RandomInput(DT_FLOAT, activations)
|
.RandomInput(type, activations)
|
||||||
.Input(kernel_shape)
|
.Input(kernel_shape)
|
||||||
.RandomInput(DT_FLOAT, backprop)
|
.RandomInput(type, backprop)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||||
});
|
});
|
||||||
@ -1413,17 +1471,34 @@ TEST_F(OpTest, Conv3DBackpropInput) {
|
|||||||
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
|
||||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||||
d.kernel_dims[2], features_in, features_out};
|
d.kernel_dims[2], features_in, features_out};
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Conv3DBackpropInputV2")
|
OpTestBuilder("Conv3DBackpropInputV2")
|
||||||
.Input(in_shape)
|
.Input(in_shape)
|
||||||
.RandomInput(DT_FLOAT, kernel)
|
.RandomInput(type, kernel)
|
||||||
.RandomInput(DT_FLOAT, backprop)
|
.RandomInput(type, backprop)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpTest, Cos) {
|
||||||
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
|
OpTestBuilder("Cos").RandomInput(type).Attr("T", type));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OpTest, Cosh) {
|
||||||
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
|
OpTestBuilder("Cosh").RandomInput(type).Attr("T", type));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, DepthToSpace) {
|
TEST_F(OpTest, DepthToSpace) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
int64 block = RandomDim(2, 5);
|
int64 block = RandomDim(2, 5);
|
||||||
@ -1431,14 +1506,16 @@ TEST_F(OpTest, DepthToSpace) {
|
|||||||
input_dims[1] = (input_dims[1] + (block - 1)) / block;
|
input_dims[1] = (input_dims[1] + (block - 1)) / block;
|
||||||
input_dims[2] = (input_dims[2] + (block - 1)) / block;
|
input_dims[2] = (input_dims[2] + (block - 1)) / block;
|
||||||
input_dims[3] *= block * block;
|
input_dims[3] *= block * block;
|
||||||
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace")
|
||||||
.RandomInput(DT_FLOAT, input_dims)
|
.RandomInput(type, input_dims)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("block_size", block));
|
.Attr("block_size", block));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, DepthwiseConv2DNative) {
|
TEST_F(OpTest, DepthwiseConv2DNative) {
|
||||||
|
if (1) return;
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
||||||
std::uniform_int_distribution<int> random_int(1, 5);
|
std::uniform_int_distribution<int> random_int(1, 5);
|
||||||
@ -1449,17 +1526,20 @@ TEST_F(OpTest, DepthwiseConv2DNative) {
|
|||||||
|
|
||||||
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
|
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
|
||||||
features_in, depth_multiplier};
|
features_in, depth_multiplier};
|
||||||
|
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
|
||||||
|
strides[2] = strides[1]; // Current impl only supports equal strides
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("DepthwiseConv2dNative")
|
OpTestBuilder("DepthwiseConv2dNative")
|
||||||
.RandomInput(DT_FLOAT, input_dims)
|
.RandomInput(DT_FLOAT, input_dims)
|
||||||
.RandomInput(DT_FLOAT, kernel_dims)
|
.RandomInput(DT_FLOAT, kernel_dims)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", DT_FLOAT)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", strides)
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
|
TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
|
||||||
|
if (1) return;
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
||||||
std::uniform_int_distribution<int> random_int(1, 5);
|
std::uniform_int_distribution<int> random_int(1, 5);
|
||||||
@ -1472,33 +1552,22 @@ TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
|
|||||||
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
|
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
|
||||||
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
|
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
|
||||||
{d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
|
{d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
|
||||||
|
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
|
||||||
|
strides[2] = strides[1]; // Current impl only supports equal strides
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
|
OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
|
||||||
.RandomInput(DT_FLOAT, activations)
|
.RandomInput(DT_FLOAT, activations)
|
||||||
.Input(kernel_shape)
|
.Input(kernel_shape)
|
||||||
.RandomInput(DT_FLOAT, backprop)
|
.RandomInput(DT_FLOAT, backprop)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", DT_FLOAT)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", strides)
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||||
.Attr("data_format", "NHWC"));
|
.Attr("data_format", "NHWC"));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Cos) {
|
|
||||||
Repeatedly([this]() {
|
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
|
||||||
OpTestBuilder("Cos").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(OpTest, Cosh) {
|
|
||||||
Repeatedly([this]() {
|
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
|
||||||
OpTestBuilder("Cosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
|
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
|
||||||
|
if (1) return;
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
|
||||||
std::uniform_int_distribution<int> random_int(1, 5);
|
std::uniform_int_distribution<int> random_int(1, 5);
|
||||||
@ -1511,21 +1580,24 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
|
|||||||
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
|
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
|
||||||
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
|
||||||
features_in, depth_multiplier};
|
features_in, depth_multiplier};
|
||||||
|
std::vector<int64> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
|
||||||
|
strides[2] = strides[1]; // Current impl only supports equal strides
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
|
OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
|
||||||
.Input(in_shape)
|
.Input(in_shape)
|
||||||
.RandomInput(DT_FLOAT, kernel)
|
.RandomInput(DT_FLOAT, kernel)
|
||||||
.RandomInput(DT_FLOAT, backprop)
|
.RandomInput(DT_FLOAT, backprop)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", DT_FLOAT)
|
||||||
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
|
.Attr("strides", strides)
|
||||||
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
|
||||||
.Attr("data_format", "NHWC"));
|
.Attr("data_format", "NHWC"));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Diag) {
|
TEST_F(OpTest, Diag) {
|
||||||
|
if (1) return;
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
std::vector<int64> dims;
|
std::vector<int64> dims;
|
||||||
// Diag causes a quadratic blowup in output size.
|
// Diag causes a quadratic blowup in output size.
|
||||||
int64 size;
|
int64 size;
|
||||||
@ -1540,7 +1612,7 @@ TEST_F(OpTest, Diag) {
|
|||||||
|
|
||||||
TEST_F(OpTest, DiagPart) {
|
TEST_F(OpTest, DiagPart) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
auto dims = RandomDims(1, 3);
|
auto dims = RandomDims(1, 3);
|
||||||
// Duplicate the random dims.
|
// Duplicate the random dims.
|
||||||
std::vector<int64> doubled_dims(dims.size() * 2);
|
std::vector<int64> doubled_dims(dims.size() * 2);
|
||||||
@ -1554,7 +1626,7 @@ TEST_F(OpTest, DiagPart) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Div) {
|
TEST_F(OpTest, Div) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -1650,7 +1722,7 @@ TEST_F(OpTest, SeluGrad) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Equal) {
|
TEST_F(OpTest, Equal) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -1661,15 +1733,17 @@ TEST_F(OpTest, Equal) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Exp) {
|
TEST_F(OpTest, Exp) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Exp").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Expm1) {
|
TEST_F(OpTest, Expm1) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Expm1").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Expm1").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1809,15 +1883,17 @@ TEST_F(OpTest, LinSpace) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Log) {
|
TEST_F(OpTest, Log) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Log").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Log1p) {
|
TEST_F(OpTest, Log1p) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Log1p").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Log1p").RandomInput(type).Attr("T", DT_FLOAT));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1914,10 +1990,11 @@ TEST_F(OpTest, MatMul) {
|
|||||||
std::swap(b_dims[0], b_dims[1]);
|
std::swap(b_dims[0], b_dims[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
|
||||||
.RandomInput(DT_FLOAT, a_dims)
|
.RandomInput(type, a_dims)
|
||||||
.RandomInput(DT_FLOAT, b_dims)
|
.RandomInput(type, b_dims)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("transpose_a", transpose_a)
|
.Attr("transpose_a", transpose_a)
|
||||||
.Attr("transpose_b", transpose_b));
|
.Attr("transpose_b", transpose_b));
|
||||||
});
|
});
|
||||||
@ -1925,7 +2002,7 @@ TEST_F(OpTest, MatMul) {
|
|||||||
|
|
||||||
TEST_F(OpTest, MatrixDiag) {
|
TEST_F(OpTest, MatrixDiag) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
|
||||||
.RandomInput(type, RandomDims(1))
|
.RandomInput(type, RandomDims(1))
|
||||||
.Attr("T", type));
|
.Attr("T", type));
|
||||||
@ -1934,7 +2011,7 @@ TEST_F(OpTest, MatrixDiag) {
|
|||||||
|
|
||||||
TEST_F(OpTest, MatrixDiagPart) {
|
TEST_F(OpTest, MatrixDiagPart) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
|
||||||
.RandomInput(type, RandomDims(2))
|
.RandomInput(type, RandomDims(2))
|
||||||
.Attr("T", type));
|
.Attr("T", type));
|
||||||
@ -2025,7 +2102,7 @@ TEST_F(OpTest, MaxPool3D) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Mean) {
|
TEST_F(OpTest, Mean) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
// TODO(phawkins): CPU and XLA differ output for reducing across a
|
// TODO(phawkins): CPU and XLA differ output for reducing across a
|
||||||
// size-0 dimension (nan vs 0). For now, require size >= 1.
|
// size-0 dimension (nan vs 0). For now, require size >= 1.
|
||||||
std::vector<int64> data_dims = RandomDims(0, kDefaultMaxRank, 1);
|
std::vector<int64> data_dims = RandomDims(0, kDefaultMaxRank, 1);
|
||||||
@ -2076,7 +2153,7 @@ TEST_F(OpTest, Mod) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Mul) {
|
TEST_F(OpTest, Mul) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -2087,7 +2164,7 @@ TEST_F(OpTest, Mul) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Neg) {
|
TEST_F(OpTest, Neg) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Neg").RandomInput(type).Attr("T", type));
|
OpTestBuilder("Neg").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -2095,7 +2172,7 @@ TEST_F(OpTest, Neg) {
|
|||||||
|
|
||||||
TEST_F(OpTest, NotEqual) {
|
TEST_F(OpTest, NotEqual) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -2136,7 +2213,7 @@ TEST_F(OpTest, OneHot) {
|
|||||||
|
|
||||||
TEST_F(OpTest, OnesLike) {
|
TEST_F(OpTest, OnesLike) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type));
|
OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -2195,16 +2272,17 @@ TEST_F(OpTest, Pow) {
|
|||||||
// nontermination.
|
// nontermination.
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow")
|
||||||
.RandomInput(DT_FLOAT, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
.RandomInput(DT_FLOAT, dims.second)
|
.RandomInput(type, dims.second)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Prod) {
|
TEST_F(OpTest, Prod) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
std::vector<int64> data_dims = RandomDims();
|
std::vector<int64> data_dims = RandomDims();
|
||||||
Tensor indices = RandomReductionIndices(data_dims.size());
|
Tensor indices = RandomReductionIndices(data_dims.size());
|
||||||
bool keep_dims = Choose<bool>({false, true});
|
bool keep_dims = Choose<bool>({false, true});
|
||||||
@ -2238,7 +2316,7 @@ TEST_F(OpTest, Range) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Rank) {
|
TEST_F(OpTest, Rank) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Rank").RandomInput(type).Attr("T", type));
|
OpTestBuilder("Rank").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -2246,7 +2324,7 @@ TEST_F(OpTest, Rank) {
|
|||||||
|
|
||||||
TEST_F(OpTest, RealDiv) {
|
TEST_F(OpTest, RealDiv) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = DT_FLOAT;
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -2257,18 +2335,20 @@ TEST_F(OpTest, RealDiv) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Reciprocal) {
|
TEST_F(OpTest, Reciprocal) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Reciprocal").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, ReciprocalGrad) {
|
TEST_F(OpTest, ReciprocalGrad) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
std::vector<int64> dims = RandomDims();
|
std::vector<int64> dims = RandomDims();
|
||||||
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
TEST_F(OpTest, Relu) {
|
TEST_F(OpTest, Relu) {
|
||||||
@ -2335,24 +2415,24 @@ TEST_F(OpTest, Reshape) {
|
|||||||
TEST_F(OpTest, Reverse) {
|
TEST_F(OpTest, Reverse) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
std::vector<int64> dims = RandomDims(1);
|
std::vector<int64> dims = RandomDims(1);
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
int64 rank = dims.size();
|
int64 rank = dims.size();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
|
||||||
.RandomInput(type, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_BOOL, {rank})
|
.RandomInput(DT_BOOL, {rank})
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, ReverseV2) {
|
TEST_F(OpTest, ReverseV2) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
std::vector<int64> data_dims = RandomDims();
|
std::vector<int64> data_dims = RandomDims();
|
||||||
Tensor indices = RandomReductionIndices(data_dims.size());
|
Tensor indices = RandomReductionIndices(data_dims.size());
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
|
||||||
.RandomInput(type, data_dims)
|
.RandomInput(type, data_dims)
|
||||||
.Input(indices)
|
.Input(indices)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2372,18 +2452,20 @@ TEST_F(OpTest, Round) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Rsqrt) {
|
TEST_F(OpTest, Rsqrt) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Rsqrt").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, RsqrtGrad) {
|
TEST_F(OpTest, RsqrtGrad) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = RandomDims();
|
auto dims = RandomDims();
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2411,24 +2493,26 @@ TEST_F(OpTest, ShapeN) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Sigmoid) {
|
TEST_F(OpTest, Sigmoid) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Sigmoid").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, SigmoidGrad) {
|
TEST_F(OpTest, SigmoidGrad) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = RandomDims();
|
auto dims = RandomDims();
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Sign) {
|
TEST_F(OpTest, Sign) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Sign").RandomInput(type).Attr("T", type));
|
OpTestBuilder("Sign").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -2436,21 +2520,23 @@ TEST_F(OpTest, Sign) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Sin) {
|
TEST_F(OpTest, Sin) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Sin").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Sin").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Sinh) {
|
TEST_F(OpTest, Sinh) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Sinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Sinh").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Size) {
|
TEST_F(OpTest, Size) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Size").RandomInput(type).Attr("T", type));
|
OpTestBuilder("Size").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -2562,10 +2648,11 @@ TEST_F(OpTest, SpaceToBatch) {
|
|||||||
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
|
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
|
||||||
TensorShape({num_block_dims, 2})));
|
TensorShape({num_block_dims, 2})));
|
||||||
|
|
||||||
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
|
||||||
.RandomInput(DT_FLOAT, input_dims)
|
.RandomInput(type, input_dims)
|
||||||
.Input(paddings)
|
.Input(paddings)
|
||||||
.Attr("T", DT_FLOAT)
|
.Attr("T", type)
|
||||||
.Attr("block_size", block_size));
|
.Attr("block_size", block_size));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -2603,13 +2690,14 @@ TEST_F(OpTest, SpaceToBatchND) {
|
|||||||
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
|
CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
|
||||||
TensorShape({num_block_dims, 2})));
|
TensorShape({num_block_dims, 2})));
|
||||||
|
|
||||||
|
DataType type = Choose<DataType>(kAllXlaTypes);
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("SpaceToBatchND")
|
OpTestBuilder("SpaceToBatchND")
|
||||||
.RandomInput(DT_FLOAT, input_dims)
|
.RandomInput(type, input_dims)
|
||||||
.Input(test::AsTensor<int32>(
|
.Input(test::AsTensor<int32>(
|
||||||
std::vector<int32>(block_dims.begin(), block_dims.end())))
|
std::vector<int32>(block_dims.begin(), block_dims.end())))
|
||||||
.Input(paddings)
|
.Input(paddings)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2699,18 +2787,20 @@ TEST_F(OpTest, Split) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Sqrt) {
|
TEST_F(OpTest, Sqrt) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Sqrt").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, SqrtGrad) {
|
TEST_F(OpTest, SqrtGrad) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = RandomDims();
|
auto dims = RandomDims();
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2726,7 +2816,7 @@ TEST_F(OpTest, SquaredDifference) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Square) {
|
TEST_F(OpTest, Square) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Square").RandomInput(type).Attr("T", type));
|
OpTestBuilder("Square").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
@ -2752,7 +2842,7 @@ TEST_F(OpTest, Squeeze) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Sub) {
|
TEST_F(OpTest, Sub) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
auto dims = BroadcastableDims();
|
auto dims = BroadcastableDims();
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
|
||||||
.RandomInput(type, dims.first)
|
.RandomInput(type, dims.first)
|
||||||
@ -2763,7 +2853,7 @@ TEST_F(OpTest, Sub) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Sum) {
|
TEST_F(OpTest, Sum) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
std::vector<int64> data_dims = RandomDims();
|
std::vector<int64> data_dims = RandomDims();
|
||||||
Tensor indices = RandomReductionIndices(data_dims.size());
|
Tensor indices = RandomReductionIndices(data_dims.size());
|
||||||
bool keep_dims = Choose<bool>({false, true});
|
bool keep_dims = Choose<bool>({false, true});
|
||||||
@ -2875,25 +2965,28 @@ TEST_F(OpTest, StridedSliceGrad) {
|
|||||||
|
|
||||||
TEST_F(OpTest, Tan) {
|
TEST_F(OpTest, Tan) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Tan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Tan").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, Tanh) {
|
TEST_F(OpTest, Tanh) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
|
OpTestBuilder("Tanh").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpTest, TanhGrad) {
|
TEST_F(OpTest, TanhGrad) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
auto dims = RandomDims();
|
auto dims = RandomDims();
|
||||||
|
DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
|
return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.RandomInput(DT_FLOAT, dims)
|
.RandomInput(type, dims)
|
||||||
.Attr("T", DT_FLOAT));
|
.Attr("T", type));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2951,7 +3044,7 @@ TEST_F(OpTest, TruncateMod) {
|
|||||||
|
|
||||||
TEST_F(OpTest, ZerosLike) {
|
TEST_F(OpTest, ZerosLike) {
|
||||||
Repeatedly([this]() {
|
Repeatedly([this]() {
|
||||||
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
|
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
|
||||||
return ExpectTfAndXlaOutputsAreClose(
|
return ExpectTfAndXlaOutputsAreClose(
|
||||||
OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type));
|
OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type));
|
||||||
});
|
});
|
||||||
|
@ -328,6 +328,131 @@ class UnaryOpsTest(XLATestCase):
|
|||||||
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
|
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
|
||||||
expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
|
expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
|
||||||
|
|
||||||
|
def testComplexOps(self):
|
||||||
|
for dtype in self.complex_types:
|
||||||
|
# TODO(b/65408531): math_ops.acosh (needs pow)
|
||||||
|
# TODO(b/65408531): math_ops.asinh (needs pow)
|
||||||
|
|
||||||
|
# TODO(b/65408531): Wider support for log (needs atan2).
|
||||||
|
atan2_supported = self.device == "XLA_GPU"
|
||||||
|
if atan2_supported:
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.atanh,
|
||||||
|
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
|
||||||
|
expected=np.arctanh(
|
||||||
|
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.cosh,
|
||||||
|
np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype),
|
||||||
|
expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.sinh,
|
||||||
|
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
|
||||||
|
expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.exp,
|
||||||
|
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
|
||||||
|
expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.expm1,
|
||||||
|
np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
|
||||||
|
expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.reciprocal,
|
||||||
|
np.array([[1, 2j, 2 + 3j]], dtype=dtype),
|
||||||
|
expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype))
|
||||||
|
|
||||||
|
if atan2_supported:
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.log,
|
||||||
|
np.array([[5j, 3 - 2j]], dtype=dtype),
|
||||||
|
expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.sin,
|
||||||
|
np.array([[5j, 3 - 2j]], dtype=dtype),
|
||||||
|
expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.cos,
|
||||||
|
np.array([[5j, 3 - 2j]], dtype=dtype),
|
||||||
|
expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype)))
|
||||||
|
|
||||||
|
# TODO(b/34703906): improve log1p implementation and make tolerance
|
||||||
|
# tighter.
|
||||||
|
if atan2_supported: # TODO(b/34703906): log support
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.log1p,
|
||||||
|
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
|
||||||
|
expected=np.log1p(
|
||||||
|
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
|
||||||
|
|
||||||
|
# TODO(b/34703906): math_ops.rsqrt (needs pow)
|
||||||
|
|
||||||
|
# TODO(b/34703906): math_ops.sigmoid (needs tanh)
|
||||||
|
|
||||||
|
# TODO(b/34703906): math_ops.sqrt (needs pow)
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.tan,
|
||||||
|
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
|
||||||
|
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
|
||||||
|
|
||||||
|
# TODO(b/34703906): math_ops.tanh (as itself)
|
||||||
|
|
||||||
|
ctypes = {np.complex64: np.float32}
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.abs,
|
||||||
|
np.array([[3 - 4j, -1j, np.inf]], dtype=dtype),
|
||||||
|
expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype]))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.negative,
|
||||||
|
np.array([[-1 + 2j, -3j]], dtype=dtype),
|
||||||
|
expected=np.array([[1 - 2j, 3j]], dtype=dtype))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.square,
|
||||||
|
np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype),
|
||||||
|
expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2)
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
array_ops.zeros_like,
|
||||||
|
np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype),
|
||||||
|
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
array_ops.ones_like,
|
||||||
|
np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype),
|
||||||
|
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
|
||||||
|
|
||||||
|
if atan2_supported: # TODO(b/34703906): atan2 support
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.angle,
|
||||||
|
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||||
|
expected=np.angle(
|
||||||
|
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype)))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.conj,
|
||||||
|
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||||
|
expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.imag,
|
||||||
|
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||||
|
expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype]))
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
math_ops.real,
|
||||||
|
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||||
|
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
|
||||||
|
|
||||||
def testIntOps(self):
|
def testIntOps(self):
|
||||||
for dtype in self.int_types:
|
for dtype in self.int_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
@ -399,11 +524,14 @@ class UnaryOpsTest(XLATestCase):
|
|||||||
|
|
||||||
def testCast(self):
|
def testCast(self):
|
||||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||||
types = [dtypes.bool, dtypes.int32, dtypes.float32]
|
types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types
|
||||||
for shape in shapes:
|
for shape in shapes:
|
||||||
for src_type in types:
|
for src_type in types:
|
||||||
for dst_type in types:
|
for dst_type in types:
|
||||||
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
|
src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
|
||||||
|
if src_type in self.complex_tf_types:
|
||||||
|
src += (np.arange(np.prod(shape)) * 2j).astype(
|
||||||
|
src_type.as_numpy_dtype)
|
||||||
src = src.reshape(shape)
|
src = src.reshape(shape)
|
||||||
|
|
||||||
dst = src.astype(dst_type.as_numpy_dtype)
|
dst = src.astype(dst_type.as_numpy_dtype)
|
||||||
|
@ -43,7 +43,7 @@ class VariableOpsTest(XLATestCase):
|
|||||||
# Regression test for a bug where computations with one non-constant
|
# Regression test for a bug where computations with one non-constant
|
||||||
# output and one variable update were mishandled.
|
# output and one variable update were mishandled.
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
init = np.array([[1, 2], [3, 4]], dtype=dtype)
|
init = np.array([[1, 2j], [3, 4]]).astype(dtype)
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
v = resource_variable_ops.ResourceVariable(init)
|
v = resource_variable_ops.ResourceVariable(init)
|
||||||
sess.run(variables.variables_initializer([v]))
|
sess.run(variables.variables_initializer([v]))
|
||||||
@ -51,82 +51,91 @@ class VariableOpsTest(XLATestCase):
|
|||||||
x = v.assign_add(p)
|
x = v.assign_add(p)
|
||||||
with ops.control_dependencies([x]):
|
with ops.control_dependencies([x]):
|
||||||
y = v.read_value()
|
y = v.read_value()
|
||||||
self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype),
|
self.assertAllClose(
|
||||||
sess.run(y, {p: 1}))
|
np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {
|
||||||
|
p: 1
|
||||||
|
}))
|
||||||
|
|
||||||
def testSparseRead0DIndices(self):
|
def testSparseRead0DIndices(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
|
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10,
|
||||||
|
11]]).astype(dtype)
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
v = resource_variable_ops.ResourceVariable(init)
|
v = resource_variable_ops.ResourceVariable(init)
|
||||||
sess.run(variables.variables_initializer([v]))
|
sess.run(variables.variables_initializer([v]))
|
||||||
x = v.sparse_read(2)
|
x = v.sparse_read(2)
|
||||||
self.assertAllClose(np.array([8, 9, 10, 11], dtype=dtype), sess.run(x))
|
self.assertAllClose(
|
||||||
|
np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
|
||||||
|
|
||||||
def testSparseRead1DIndices(self):
|
def testSparseRead1DIndices(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
|
init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10,
|
||||||
|
11]]).astype(dtype)
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
v = resource_variable_ops.ResourceVariable(init)
|
v = resource_variable_ops.ResourceVariable(init)
|
||||||
sess.run(variables.variables_initializer([v]))
|
sess.run(variables.variables_initializer([v]))
|
||||||
x = v.sparse_read([2, 1])
|
x = v.sparse_read([2, 1])
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
np.array([[8, 9, 10, 11], [4, 5, 6, 7]], dtype=dtype), sess.run(x))
|
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
|
||||||
|
sess.run(x))
|
||||||
|
|
||||||
def testSparseRead2DIndices(self):
|
def testSparseRead2DIndices(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype)
|
init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10,
|
||||||
|
11]]).astype(dtype)
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
v = resource_variable_ops.ResourceVariable(init)
|
v = resource_variable_ops.ResourceVariable(init)
|
||||||
sess.run(variables.variables_initializer([v]))
|
sess.run(variables.variables_initializer([v]))
|
||||||
x = v.sparse_read([[2, 1], [0, 2]])
|
x = v.sparse_read([[2, 1], [0, 2]])
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
np.array(
|
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
|
||||||
[[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2, 3], [8, 9, 10,
|
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
|
||||||
11]]],
|
sess.run(x))
|
||||||
dtype=dtype), sess.run(x))
|
|
||||||
|
|
||||||
def testSparseRead2DIndices3DTensor(self):
|
def testSparseRead2DIndices3DTensor(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
init = np.array(
|
init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
|
||||||
[[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]],
|
[[20, 21, 22], [23, 24j, 25]],
|
||||||
[[20, 21, 22], [23, 24, 25]], [[30, 31, 32], [33, 34, 35]]],
|
[[30, 31, 32], [33, 34, 35]]]).astype(dtype)
|
||||||
dtype=dtype)
|
|
||||||
with self.test_session() as sess, self.test_scope():
|
with self.test_session() as sess, self.test_scope():
|
||||||
v = resource_variable_ops.ResourceVariable(init)
|
v = resource_variable_ops.ResourceVariable(init)
|
||||||
sess.run(variables.variables_initializer([v]))
|
sess.run(variables.variables_initializer([v]))
|
||||||
x = v.sparse_read([[2, 1], [3, 0]])
|
x = v.sparse_read([[2, 1], [3, 0]])
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
np.array(
|
np.array(
|
||||||
[[[[20, 21, 22], [23, 24, 25]], [[10, 11, 12], [13, 14, 15]]],
|
[[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
|
||||||
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]],
|
[[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]],
|
||||||
dtype=dtype), sess.run(x))
|
).astype(dtype), sess.run(x))
|
||||||
|
|
||||||
def testReadWrite(self):
|
def testReadWrite(self):
|
||||||
"""Tests initialization, reading, and writing a resource variable."""
|
"""Tests initialization, reading, and writing a resource variable."""
|
||||||
|
for dtype in self.numeric_types:
|
||||||
with self.test_session() as session:
|
with self.test_session() as session:
|
||||||
|
print(ops.get_default_graph())
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
with variable_scope.variable_scope("ascope", use_resource=True):
|
with variable_scope.variable_scope("ascope", use_resource=True):
|
||||||
x = variable_scope.get_variable(
|
x = variable_scope.get_variable(
|
||||||
"x",
|
"x",
|
||||||
shape=[],
|
shape=[],
|
||||||
dtype=dtypes.float32,
|
dtype=dtype,
|
||||||
initializer=init_ops.constant_initializer(2))
|
initializer=init_ops.constant_initializer(2))
|
||||||
a = x.read_value()
|
a = x.read_value()
|
||||||
with ops.control_dependencies([a]):
|
with ops.control_dependencies([a]):
|
||||||
b = state_ops.assign(x, 47)
|
b = state_ops.assign(x, dtype(47))
|
||||||
with ops.control_dependencies([b]):
|
with ops.control_dependencies([b]):
|
||||||
c = x.read_value()
|
c = x.read_value()
|
||||||
with ops.control_dependencies([c]):
|
with ops.control_dependencies([c]):
|
||||||
d = state_ops.assign_add(x, 3)
|
d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype))
|
||||||
with ops.control_dependencies([d]):
|
with ops.control_dependencies([d]):
|
||||||
e = x.read_value()
|
e = state_ops.assign_sub(x, dtype(3))
|
||||||
|
with ops.control_dependencies([e]):
|
||||||
|
f = x.read_value()
|
||||||
|
|
||||||
session.run(variables.global_variables_initializer())
|
session.run(variables.global_variables_initializer())
|
||||||
v1, v2, v3 = session.run([a, c, e])
|
v1, v2, v3 = session.run([a, c, f])
|
||||||
self.assertAllClose(2.0, v1)
|
self.assertAllClose(dtype(2), v1)
|
||||||
self.assertAllClose(47.0, v2)
|
self.assertAllClose(dtype(47), v2)
|
||||||
self.assertAllClose(50.0, v3)
|
self.assertAllClose(np.array(50 + 2j).astype(dtype), v3)
|
||||||
|
|
||||||
def testTraining(self):
|
def testTraining(self):
|
||||||
"""Tests a gradient descent step for a simple model."""
|
"""Tests a gradient descent step for a simple model."""
|
||||||
|
@ -63,12 +63,19 @@ class XLATestCase(test.TestCase):
|
|||||||
self.float_tf_types = [
|
self.float_tf_types = [
|
||||||
dtype for dtype in self.all_tf_types if dtype.is_floating
|
dtype for dtype in self.all_tf_types if dtype.is_floating
|
||||||
]
|
]
|
||||||
self.numeric_tf_types = self.int_tf_types + self.float_tf_types
|
self.complex_tf_types = [
|
||||||
|
dtype for dtype in self.all_tf_types if dtype.is_complex
|
||||||
|
]
|
||||||
|
self.numeric_tf_types = (
|
||||||
|
self.int_tf_types + self.float_tf_types + self.complex_tf_types)
|
||||||
|
|
||||||
self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
|
self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
|
||||||
self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
|
self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
|
||||||
self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
|
self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
|
||||||
self.numeric_types = self.int_types + self.float_types
|
self.complex_types = [
|
||||||
|
dtype.as_numpy_dtype for dtype in self.complex_tf_types
|
||||||
|
]
|
||||||
|
self.numeric_types = self.int_types + self.float_types + self.complex_types
|
||||||
|
|
||||||
# Parse the manifest file, if any, into a regex identifying tests to
|
# Parse the manifest file, if any, into a regex identifying tests to
|
||||||
# disable
|
# disable
|
||||||
|
@ -77,7 +77,13 @@ class BatchMatMulOp : public XlaOpKernel {
|
|||||||
xla::ComputationBuilder* builder = ctx->builder();
|
xla::ComputationBuilder* builder = ctx->builder();
|
||||||
|
|
||||||
xla::ComputationDataHandle x_handle = ctx->Input(0);
|
xla::ComputationDataHandle x_handle = ctx->Input(0);
|
||||||
|
if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) {
|
||||||
|
x_handle = builder->Conj(x_handle);
|
||||||
|
}
|
||||||
xla::ComputationDataHandle y_handle = ctx->Input(1);
|
xla::ComputationDataHandle y_handle = ctx->Input(1);
|
||||||
|
if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) {
|
||||||
|
y_handle = builder->Conj(y_handle);
|
||||||
|
}
|
||||||
|
|
||||||
// Reshape input tensors into 3D tensors by flattening the batch
|
// Reshape input tensors into 3D tensors by flattening the batch
|
||||||
// dimensions. This makes it easier to unroll the batch dimension.
|
// dimensions. This makes it easier to unroll the batch dimension.
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// Native XLA implementations of simple unary Ops
|
// Native XLA implementations of simple binary Ops
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
|
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -50,6 +51,9 @@ XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions));
|
|||||||
XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
|
XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
|
||||||
XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
|
XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
|
||||||
|
|
||||||
|
XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions));
|
||||||
|
XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
|
||||||
|
|
||||||
// Implementation of FloorDiv. Pseudo-code:
|
// Implementation of FloorDiv. Pseudo-code:
|
||||||
// if ((x < 0) != (y < 0)) {
|
// if ((x < 0) != (y < 0)) {
|
||||||
// T abs_x = std::abs(x);
|
// T abs_x = std::abs(x);
|
||||||
@ -171,8 +175,12 @@ class ApproximateEqualOp : public XlaOpKernel {
|
|||||||
// Computes the max of the scalar input x and 0.
|
// Computes the max of the scalar input x and 0.
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
xla::ComputationBuilder* b = ctx->builder();
|
xla::ComputationBuilder* b = ctx->builder();
|
||||||
auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))),
|
auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
|
||||||
XlaHelpers::FloatLiteral(b, input_type(0), tolerance_));
|
auto abs_shape = b->GetShape(abs);
|
||||||
|
OP_REQUIRES_OK(ctx, abs_shape.status());
|
||||||
|
auto abs_type = abs_shape.ValueOrDie()->element_type();
|
||||||
|
auto result = b->Lt(
|
||||||
|
abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type));
|
||||||
ctx->SetOutput(0, result);
|
ctx->SetOutput(0, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -40,6 +41,11 @@ class CastOp : public XlaOpKernel {
|
|||||||
output = input;
|
output = input;
|
||||||
} else if (dst_dtype_ == DT_BOOL) {
|
} else if (dst_dtype_ == DT_BOOL) {
|
||||||
output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
|
output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
|
||||||
|
} else if (xla::primitive_util::IsComplexType(src_type_) &&
|
||||||
|
!xla::primitive_util::IsComplexType(dst_type_)) {
|
||||||
|
// As in cast_op.h, we replicate the numpy behavior of truncating the
|
||||||
|
// imaginary part.
|
||||||
|
output = builder->ConvertElementType(builder->Real(input), dst_type_);
|
||||||
} else {
|
} else {
|
||||||
output = builder->ConvertElementType(input, dst_type_);
|
output = builder->ConvertElementType(input, dst_type_);
|
||||||
}
|
}
|
||||||
|
@ -192,7 +192,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) {
|
|||||||
errors::InvalidArgument("indices must be int32 or int64"));
|
errors::InvalidArgument("indices must be int32 or int64"));
|
||||||
|
|
||||||
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
|
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
|
||||||
context, input, input_shape, indices, indices_shape, axis, DT_FLOAT,
|
context, input, input_shape, indices, indices_shape, axis, input_type(0),
|
||||||
index_type, builder);
|
index_type, builder);
|
||||||
context->SetOutput(0, gather);
|
context->SetOutput(0, gather);
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
constexpr std::array<DataType, 4> kMatmulTypes = {
|
||||||
|
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
|
||||||
|
|
||||||
class MatMulOp : public XlaOpKernel {
|
class MatMulOp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false)
|
explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false)
|
||||||
@ -73,7 +76,7 @@ class MatMulOp : public XlaOpKernel {
|
|||||||
bool transpose_b_;
|
bool transpose_b_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kFloatTypes), MatMulOp);
|
REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp);
|
||||||
|
|
||||||
class SparseMatMulOp : public MatMulOp {
|
class SparseMatMulOp : public MatMulOp {
|
||||||
public:
|
public:
|
||||||
|
@ -37,7 +37,8 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyGradientDescent"),
|
REGISTER_XLA_OP(
|
||||||
|
Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
|
||||||
ResourceApplyGradientDescent);
|
ResourceApplyGradientDescent);
|
||||||
|
|
||||||
class ResourceApplyMomentum : public XlaOpKernel {
|
class ResourceApplyMomentum : public XlaOpKernel {
|
||||||
@ -109,7 +110,8 @@ class ResourceApplyMomentum : public XlaOpKernel {
|
|||||||
private:
|
private:
|
||||||
bool use_nesterov_;
|
bool use_nesterov_;
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyMomentum"), ResourceApplyMomentum);
|
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
|
||||||
|
ResourceApplyMomentum);
|
||||||
|
|
||||||
class ResourceApplyAdagrad : public XlaOpKernel {
|
class ResourceApplyAdagrad : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -163,7 +165,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad);
|
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
|
||||||
|
ResourceApplyAdagrad);
|
||||||
|
|
||||||
class ResourceApplyAdam : public XlaOpKernel {
|
class ResourceApplyAdam : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -263,7 +266,8 @@ class ResourceApplyAdam : public XlaOpKernel {
|
|||||||
private:
|
private:
|
||||||
DataType dtype_;
|
DataType dtype_;
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam);
|
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
|
||||||
|
ResourceApplyAdam);
|
||||||
|
|
||||||
class ResourceApplyRMSProp : public XlaOpKernel {
|
class ResourceApplyRMSProp : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -362,7 +366,8 @@ class ResourceApplyRMSProp : public XlaOpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
|
OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp);
|
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
|
||||||
|
ResourceApplyRMSProp);
|
||||||
|
|
||||||
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
|
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
|
||||||
bool has_l2_shrinkage) {
|
bool has_l2_shrinkage) {
|
||||||
@ -500,7 +505,8 @@ class ResourceApplyFtrl : public XlaOpKernel {
|
|||||||
private:
|
private:
|
||||||
DataType dtype_;
|
DataType dtype_;
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl);
|
REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes),
|
||||||
|
ResourceApplyFtrl);
|
||||||
|
|
||||||
class ResourceApplyFtrlV2 : public XlaOpKernel {
|
class ResourceApplyFtrlV2 : public XlaOpKernel {
|
||||||
public:
|
public:
|
||||||
@ -515,7 +521,8 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
|
|||||||
private:
|
private:
|
||||||
DataType dtype_;
|
DataType dtype_;
|
||||||
};
|
};
|
||||||
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2);
|
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
|
||||||
|
ResourceApplyFtrlV2);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -41,6 +41,12 @@ namespace {
|
|||||||
}; \
|
}; \
|
||||||
REGISTER_XLA_OP(Name(#NAME), NAME##Op);
|
REGISTER_XLA_OP(Name(#NAME), NAME##Op);
|
||||||
|
|
||||||
|
XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x));
|
||||||
|
|
||||||
|
XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x)));
|
||||||
|
|
||||||
|
XLAJIT_MAKE_UNARY(Conj, b->Complex(b->Real(x), b->Neg(b->Imag(x))));
|
||||||
|
|
||||||
// Return x if x>0, otherwise -x.
|
// Return x if x>0, otherwise -x.
|
||||||
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
|
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
|
||||||
|
|
||||||
@ -162,6 +168,9 @@ XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
|
|||||||
XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
|
XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
|
||||||
XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
|
XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
|
||||||
|
|
||||||
|
XLAJIT_MAKE_UNARY(Real, b->Real(x));
|
||||||
|
XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
|
||||||
|
|
||||||
#undef XLAJIT_MAKE_UNARY
|
#undef XLAJIT_MAKE_UNARY
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -97,6 +97,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
|
|||||||
case xla::F64:
|
case xla::F64:
|
||||||
literal = *xla::Literal::CreateR0<double>(value);
|
literal = *xla::Literal::CreateR0<double>(value);
|
||||||
break;
|
break;
|
||||||
|
case xla::C64:
|
||||||
|
literal = *xla::Literal::CreateR0<complex64>(value);
|
||||||
|
break;
|
||||||
case xla::PRED:
|
case xla::PRED:
|
||||||
LOG(FATAL) << "pred element type is not integral";
|
LOG(FATAL) << "pred element type is not integral";
|
||||||
case xla::S16:
|
case xla::S16:
|
||||||
@ -132,6 +135,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
|
|||||||
case xla::F64:
|
case xla::F64:
|
||||||
return b->ConstantR0<double>(value);
|
return b->ConstantR0<double>(value);
|
||||||
break;
|
break;
|
||||||
|
case xla::C64:
|
||||||
|
return b->ConstantR0<complex64>(value);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unhandled element type " << type;
|
LOG(FATAL) << "unhandled element type " << type;
|
||||||
}
|
}
|
||||||
|
@ -47,14 +47,17 @@ extern const char* const DEVICE_XLA_GPU;
|
|||||||
|
|
||||||
constexpr std::array<DataType, 3> kFloatTypes = {
|
constexpr std::array<DataType, 3> kFloatTypes = {
|
||||||
{DT_HALF, DT_FLOAT, DT_DOUBLE}};
|
{DT_HALF, DT_FLOAT, DT_DOUBLE}};
|
||||||
constexpr std::array<DataType, 7> kNumericTypes = {
|
constexpr std::array<DataType, 8> kNumericTypes = {
|
||||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}};
|
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
|
||||||
|
DT_COMPLEX64}};
|
||||||
|
|
||||||
constexpr std::array<DataType, 7> kCpuAllTypes = {
|
constexpr std::array<DataType, 8> kCpuAllTypes = {
|
||||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
|
||||||
|
DT_COMPLEX64, DT_BOOL}};
|
||||||
|
|
||||||
constexpr std::array<DataType, 7> kGpuAllTypes = {
|
constexpr std::array<DataType, 8> kGpuAllTypes = {
|
||||||
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
|
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
|
||||||
|
DT_COMPLEX64, DT_BOOL}};
|
||||||
|
|
||||||
// Class that manages registrations of operators and devices for the XLA JIT.
|
// Class that manages registrations of operators and devices for the XLA JIT.
|
||||||
// Not thread-safe.
|
// Not thread-safe.
|
||||||
|
@ -913,6 +913,17 @@ ComputationDataHandle ComputationBuilder::CustomCall(
|
|||||||
return ParseOpResponse(s, &response);
|
return ParseOpResponse(s, &response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle ComputationBuilder::Complex(
|
||||||
|
const ComputationDataHandle& real, const ComputationDataHandle& imag,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||||
|
return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle ComputationBuilder::Conj(
|
||||||
|
const ComputationDataHandle& operand) {
|
||||||
|
return Complex(Real(operand), Neg(Imag(operand)));
|
||||||
|
}
|
||||||
|
|
||||||
ComputationDataHandle ComputationBuilder::Add(
|
ComputationDataHandle ComputationBuilder::Add(
|
||||||
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
||||||
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||||
@ -995,6 +1006,12 @@ ComputationDataHandle ComputationBuilder::Abs(
|
|||||||
return UnaryOp(UNOP_ABS, operand);
|
return UnaryOp(UNOP_ABS, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle ComputationBuilder::Atan2(
|
||||||
|
const ComputationDataHandle& y, const ComputationDataHandle& x,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
|
||||||
|
return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
ComputationDataHandle ComputationBuilder::Exp(
|
ComputationDataHandle ComputationBuilder::Exp(
|
||||||
const ComputationDataHandle& operand) {
|
const ComputationDataHandle& operand) {
|
||||||
return UnaryOp(UNOP_EXP, operand);
|
return UnaryOp(UNOP_EXP, operand);
|
||||||
@ -1040,6 +1057,16 @@ ComputationDataHandle ComputationBuilder::Tanh(
|
|||||||
return UnaryOp(UNOP_TANH, operand);
|
return UnaryOp(UNOP_TANH, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle ComputationBuilder::Real(
|
||||||
|
const ComputationDataHandle& operand) {
|
||||||
|
return UnaryOp(UNOP_REAL, operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle ComputationBuilder::Imag(
|
||||||
|
const ComputationDataHandle& operand) {
|
||||||
|
return UnaryOp(UNOP_IMAG, operand);
|
||||||
|
}
|
||||||
|
|
||||||
ComputationDataHandle ComputationBuilder::IsFinite(
|
ComputationDataHandle ComputationBuilder::IsFinite(
|
||||||
const ComputationDataHandle& operand) {
|
const ComputationDataHandle& operand) {
|
||||||
return UnaryOp(UNOP_IS_FINITE, operand);
|
return UnaryOp(UNOP_IS_FINITE, operand);
|
||||||
|
@ -431,6 +431,14 @@ class ComputationBuilder {
|
|||||||
// of the operands is a scalar, or an explicit broadcast dimension is given
|
// of the operands is a scalar, or an explicit broadcast dimension is given
|
||||||
// (see g3doc for more details).
|
// (see g3doc for more details).
|
||||||
|
|
||||||
|
// Enqueues a complex compose instruction onto the computation.
|
||||||
|
ComputationDataHandle Complex(
|
||||||
|
const ComputationDataHandle& real, const ComputationDataHandle& imag,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
|
// Enqueues a complex conjugate instruction onto the computation.
|
||||||
|
ComputationDataHandle Conj(const ComputationDataHandle& operand);
|
||||||
|
|
||||||
// Enqueues an add instruction onto the computation.
|
// Enqueues an add instruction onto the computation.
|
||||||
ComputationDataHandle Add(
|
ComputationDataHandle Add(
|
||||||
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
||||||
@ -542,6 +550,11 @@ class ComputationBuilder {
|
|||||||
// Enqueues an abs instruction onto the computation.
|
// Enqueues an abs instruction onto the computation.
|
||||||
ComputationDataHandle Abs(const ComputationDataHandle& operand);
|
ComputationDataHandle Abs(const ComputationDataHandle& operand);
|
||||||
|
|
||||||
|
// Enqueues a atan2 instruction onto the computation.
|
||||||
|
ComputationDataHandle Atan2(
|
||||||
|
const ComputationDataHandle& y, const ComputationDataHandle& x,
|
||||||
|
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues an exp instruction onto the computation.
|
// Enqueues an exp instruction onto the computation.
|
||||||
ComputationDataHandle Exp(const ComputationDataHandle& operand);
|
ComputationDataHandle Exp(const ComputationDataHandle& operand);
|
||||||
|
|
||||||
@ -570,6 +583,12 @@ class ComputationBuilder {
|
|||||||
// Enqueues a tanh instruction onto the computation.
|
// Enqueues a tanh instruction onto the computation.
|
||||||
ComputationDataHandle Tanh(const ComputationDataHandle& operand);
|
ComputationDataHandle Tanh(const ComputationDataHandle& operand);
|
||||||
|
|
||||||
|
// Enqueues a real-part instruction onto the computation.
|
||||||
|
ComputationDataHandle Real(const ComputationDataHandle& operand);
|
||||||
|
|
||||||
|
// Enqueues an imaginary-part instruction onto the computation.
|
||||||
|
ComputationDataHandle Imag(const ComputationDataHandle& operand);
|
||||||
|
|
||||||
// Enqueues a float32 sqrt instruction onto the computation.
|
// Enqueues a float32 sqrt instruction onto the computation.
|
||||||
// (float32 is specified as there is an implicit float32 0.5f constant
|
// (float32 is specified as there is an implicit float32 0.5f constant
|
||||||
// exponent).
|
// exponent).
|
||||||
|
@ -204,6 +204,8 @@ Status Literal::Copy(const Literal& src_literal,
|
|||||||
return *Literal::CreateR0<float>(0);
|
return *Literal::CreateR0<float>(0);
|
||||||
case F64:
|
case F64:
|
||||||
return *Literal::CreateR0<double>(0);
|
return *Literal::CreateR0<double>(0);
|
||||||
|
case C64:
|
||||||
|
return *Literal::CreateR0<complex64>(0);
|
||||||
case PRED:
|
case PRED:
|
||||||
return *Literal::CreateR0<bool>(false);
|
return *Literal::CreateR0<bool>(false);
|
||||||
case S16:
|
case S16:
|
||||||
@ -236,6 +238,8 @@ Status Literal::Copy(const Literal& src_literal,
|
|||||||
return *Literal::CreateR0<float>(1);
|
return *Literal::CreateR0<float>(1);
|
||||||
case F64:
|
case F64:
|
||||||
return *Literal::CreateR0<double>(1);
|
return *Literal::CreateR0<double>(1);
|
||||||
|
case C64:
|
||||||
|
return *Literal::CreateR0<complex64>(1);
|
||||||
case PRED:
|
case PRED:
|
||||||
return *Literal::CreateR0<bool>(true);
|
return *Literal::CreateR0<bool>(true);
|
||||||
case S16:
|
case S16:
|
||||||
@ -271,6 +275,8 @@ Status Literal::Copy(const Literal& src_literal,
|
|||||||
case F64:
|
case F64:
|
||||||
return *Literal::CreateR0<double>(
|
return *Literal::CreateR0<double>(
|
||||||
-std::numeric_limits<double>::infinity());
|
-std::numeric_limits<double>::infinity());
|
||||||
|
case C64:
|
||||||
|
LOG(FATAL) << "C64 element type has no minimum value";
|
||||||
case PRED:
|
case PRED:
|
||||||
return *Literal::CreateR0<bool>(false);
|
return *Literal::CreateR0<bool>(false);
|
||||||
case S16:
|
case S16:
|
||||||
|
@ -141,6 +141,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
Status HandleConvert(HloInstruction* convert) override;
|
Status HandleConvert(HloInstruction* convert) override;
|
||||||
|
|
||||||
|
Status HandleReal(HloInstruction* real, HloInstruction* operand) override;
|
||||||
|
Status HandleImag(HloInstruction* imag, HloInstruction* operand) override;
|
||||||
|
|
||||||
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
|
Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
|
||||||
HloInstruction* rhs, const Window& window) override;
|
HloInstruction* rhs, const Window& window) override;
|
||||||
|
|
||||||
@ -967,6 +970,24 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Real(Complex(r, i)) -> r
|
||||||
|
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real,
|
||||||
|
HloInstruction* operand) {
|
||||||
|
if (operand->opcode() == HloOpcode::kComplex) {
|
||||||
|
return ReplaceInstruction(real, operand->mutable_operand(0));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Imag(Complex(r, i)) -> i
|
||||||
|
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag,
|
||||||
|
HloInstruction* operand) {
|
||||||
|
if (operand->opcode() == HloOpcode::kComplex) {
|
||||||
|
return ReplaceInstruction(imag, operand->mutable_operand(1));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||||
// Eliminate nop pads (padding all zero), and replace a pad with negative
|
// Eliminate nop pads (padding all zero), and replace a pad with negative
|
||||||
// padding with a pad with non-negative padding followed by a slice.
|
// padding with a pad with non-negative padding followed by a slice.
|
||||||
|
@ -433,6 +433,56 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
|
|||||||
EXPECT_EQ(root, param0);
|
EXPECT_EQ(root, param0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that real(complex(r,i)) is simplified to r.
|
||||||
|
TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
|
||||||
|
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
HloInstruction* param0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||||
|
HloInstruction* param1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(1, r2f32, "param1"));
|
||||||
|
HloInstruction* cplx = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
|
||||||
|
HloOpcode::kComplex, param0, param1));
|
||||||
|
HloInstruction* real = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
|
||||||
|
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
HloInstruction* root = computation->root_instruction();
|
||||||
|
EXPECT_EQ(root, real);
|
||||||
|
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||||
|
non_bitcasting_callback());
|
||||||
|
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
root = computation->root_instruction();
|
||||||
|
EXPECT_EQ(root, param0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that imag(complex(r,i)) is simplified to i.
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
|
||||||
|
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
HloInstruction* param0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, r2f32, "param0"));
|
||||||
|
HloInstruction* param1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(1, r2f32, "param1"));
|
||||||
|
HloInstruction* cplx = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
|
||||||
|
HloOpcode::kComplex, param0, param1));
|
||||||
|
HloInstruction* imag = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
|
||||||
|
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
HloInstruction* root = computation->root_instruction();
|
||||||
|
EXPECT_EQ(root, imag);
|
||||||
|
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||||
|
non_bitcasting_callback());
|
||||||
|
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||||
|
root = computation->root_instruction();
|
||||||
|
EXPECT_EQ(root, param1);
|
||||||
|
}
|
||||||
|
|
||||||
// Test that get_element(make_tuple({A,B}),1) is simplified to B
|
// Test that get_element(make_tuple({A,B}),1) is simplified to B
|
||||||
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
|
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
|
||||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||||
|
@ -63,7 +63,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
|
|||||||
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
|
llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder,
|
||||||
const HloModuleConfig& hlo_module_config) {
|
const HloModuleConfig& hlo_module_config) {
|
||||||
PrimitiveType type = target_array.GetShape().element_type();
|
PrimitiveType type = target_array.GetShape().element_type();
|
||||||
TF_RET_CHECK(F32 == type || F64 == type);
|
TF_RET_CHECK(F32 == type || F64 == type || C64 == type);
|
||||||
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
|
DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
|
||||||
lhs_array, rhs_array, executable_run_options_value,
|
lhs_array, rhs_array, executable_run_options_value,
|
||||||
ir_builder, hlo_module_config);
|
ir_builder, hlo_module_config);
|
||||||
@ -176,7 +176,7 @@ tensorflow::Status DotOpEmitter::Emit() {
|
|||||||
llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
|
llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
|
||||||
ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
|
ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
|
||||||
|
|
||||||
ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0),
|
ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type),
|
||||||
accum_address);
|
accum_address);
|
||||||
|
|
||||||
// Body basic block of reduction loop:
|
// Body basic block of reduction loop:
|
||||||
@ -191,9 +191,29 @@ tensorflow::Status DotOpEmitter::Emit() {
|
|||||||
llvm::Value* rhs_element =
|
llvm::Value* rhs_element =
|
||||||
rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
|
rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
|
||||||
|
|
||||||
llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
|
|
||||||
llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
|
llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
|
||||||
llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product);
|
llvm::Value* updated_accum;
|
||||||
|
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||||
|
auto real = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {0});
|
||||||
|
};
|
||||||
|
auto imag = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {1});
|
||||||
|
};
|
||||||
|
llvm::Value* product_real = ir_builder_->CreateFSub(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element)));
|
||||||
|
llvm::Value* product_imag = ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element)));
|
||||||
|
updated_accum = ir_builder_->CreateInsertValue(
|
||||||
|
accum, ir_builder_->CreateFAdd(real(accum), product_real), {0});
|
||||||
|
updated_accum = ir_builder_->CreateInsertValue(
|
||||||
|
updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1});
|
||||||
|
} else {
|
||||||
|
llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
|
||||||
|
updated_accum = ir_builder_->CreateFAdd(accum, product);
|
||||||
|
}
|
||||||
ir_builder_->CreateStore(updated_accum, accum_address);
|
ir_builder_->CreateStore(updated_accum, accum_address);
|
||||||
|
|
||||||
// Exit basic block of reduction loop.
|
// Exit basic block of reduction loop.
|
||||||
@ -230,11 +250,28 @@ tensorflow::Status DotOpEmitter::Emit() {
|
|||||||
|
|
||||||
tensorflow::Status DotOpEmitter::EmitScalarDot() {
|
tensorflow::Status DotOpEmitter::EmitScalarDot() {
|
||||||
// A scalar dot is just a scalar multiply.
|
// A scalar dot is just a scalar multiply.
|
||||||
|
llvm::Value* result;
|
||||||
llvm::Value* lhs_value =
|
llvm::Value* lhs_value =
|
||||||
lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
|
lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
|
||||||
llvm::Value* rhs_value =
|
llvm::Value* rhs_value =
|
||||||
rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
|
rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
|
||||||
llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value);
|
if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
|
||||||
|
#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
|
||||||
|
#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
|
||||||
|
llvm::Value* real = ir_builder_->CreateFSub(
|
||||||
|
ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value)));
|
||||||
|
llvm::Value* imag = ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value)));
|
||||||
|
#undef IMAG
|
||||||
|
#undef REAL
|
||||||
|
result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType());
|
||||||
|
result = ir_builder_->CreateInsertValue(result, real, {0});
|
||||||
|
result = ir_builder_->CreateInsertValue(result, imag, {1});
|
||||||
|
} else {
|
||||||
|
result = ir_builder_->CreateFMul(lhs_value, rhs_value);
|
||||||
|
}
|
||||||
target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
|
target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -46,8 +46,8 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
|
|||||||
}
|
}
|
||||||
// Create function type for the function.
|
// Create function type for the function.
|
||||||
llvm::FunctionType* function_type = llvm::FunctionType::get(
|
llvm::FunctionType* function_type = llvm::FunctionType::get(
|
||||||
llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||||
llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||||
/*isVarArg=*/false);
|
/*isVarArg=*/false);
|
||||||
// Create function declaration for 'tanhf'.
|
// Create function declaration for 'tanhf'.
|
||||||
llvm::Function* function =
|
llvm::Function* function =
|
||||||
|
@ -41,6 +41,12 @@ bool PotentiallyImplementedAsEigenConvolution(
|
|||||||
ShapeUtil::HasZeroElements(kernel_shape)) {
|
ShapeUtil::HasZeroElements(kernel_shape)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
// TODO(b/65408531): Explore using Eigen dot for complex64 type.
|
||||||
|
if (ShapeUtil::ElementIsComplex(input_shape) ||
|
||||||
|
ShapeUtil::ElementIsComplex(kernel_shape)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
const ConvolutionDimensionNumbers& dnums =
|
const ConvolutionDimensionNumbers& dnums =
|
||||||
convolution.convolution_dimension_numbers();
|
convolution.convolution_dimension_numbers();
|
||||||
// Only 1D and 2D convolutions are supported at the moment.
|
// Only 1D and 2D convolutions are supported at the moment.
|
||||||
|
@ -288,7 +288,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant,
|
|||||||
MinimumAlignmentForShape(literal.shape()));
|
MinimumAlignmentForShape(literal.shape()));
|
||||||
} else {
|
} else {
|
||||||
llvm::Constant* initializer =
|
llvm::Constant* initializer =
|
||||||
llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
|
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
|
||||||
global_for_const = new llvm::GlobalVariable(
|
global_for_const = new llvm::GlobalVariable(
|
||||||
/*Module=*/*module_,
|
/*Module=*/*module_,
|
||||||
/*Type=*/initializer->getType(),
|
/*Type=*/initializer->getType(),
|
||||||
@ -401,7 +401,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
|||||||
const Shape& shape = get_tuple_element->shape();
|
const Shape& shape = get_tuple_element->shape();
|
||||||
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
|
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
|
||||||
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
|
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
|
||||||
GetEmittedValueFor(operand), &ir_builder_);
|
GetEmittedValueFor(operand), &ir_builder_, module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -412,9 +412,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
|
|||||||
|
|
||||||
if (ShapeUtil::IsTuple(select->shape())) {
|
if (ShapeUtil::IsTuple(select->shape())) {
|
||||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
|
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
|
||||||
llvm_ir::EmitTupleSelect(GetIrArrayFor(select), GetIrArrayFor(pred),
|
llvm_ir::EmitTupleSelect(
|
||||||
GetEmittedValueFor(on_true),
|
GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true),
|
||||||
GetEmittedValueFor(on_false), &ir_builder_);
|
GetEmittedValueFor(on_false), &ir_builder_, module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -459,7 +459,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
|
|||||||
tuple_element_addresses.push_back(tuple_element_address);
|
tuple_element_addresses.push_back(tuple_element_address);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_);
|
llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_,
|
||||||
|
module_);
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
|
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
|
||||||
GetEmittedValueFor(infeed)));
|
GetEmittedValueFor(infeed)));
|
||||||
@ -562,7 +563,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
|||||||
ShapeUtil::GetTupleElementShape(operand_shape, i);
|
ShapeUtil::GetTupleElementShape(operand_shape, i);
|
||||||
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
|
llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement(
|
||||||
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
|
tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape),
|
||||||
value, &ir_builder_);
|
value, &ir_builder_, module_);
|
||||||
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
|
TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed,
|
||||||
tuple_element_shape, tuple_element));
|
tuple_element_shape, tuple_element));
|
||||||
}
|
}
|
||||||
@ -583,7 +584,7 @@ Status IrEmitter::HandleTuple(
|
|||||||
for (auto operand : operands) {
|
for (auto operand : operands) {
|
||||||
base_ptrs.push_back(GetEmittedValueFor(operand));
|
base_ptrs.push_back(GetEmittedValueFor(operand));
|
||||||
}
|
}
|
||||||
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_);
|
llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -644,7 +645,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window,
|
|||||||
// the initial value on the reduce_window.
|
// the initial value on the reduce_window.
|
||||||
PrimitiveType operand_element_type = operand->shape().element_type();
|
PrimitiveType operand_element_type = operand->shape().element_type();
|
||||||
llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
|
||||||
"reduce_window_accumulator_address", &ir_builder_,
|
"reduce_window_accumulator_address", &ir_builder_,
|
||||||
MinimumAlignmentForPrimitiveType(operand_element_type));
|
MinimumAlignmentForPrimitiveType(operand_element_type));
|
||||||
ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
|
ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
|
||||||
@ -769,7 +770,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
|||||||
// Allocate space to keep the currently selected value, its index, and
|
// Allocate space to keep the currently selected value, its index, and
|
||||||
// the boolean initialized_flag, which is initially set to false.
|
// the boolean initialized_flag, which is initially set to false.
|
||||||
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
|
||||||
"selected_value_address", &ir_builder_,
|
"selected_value_address", &ir_builder_,
|
||||||
MinimumAlignmentForPrimitiveType(operand_element_type));
|
MinimumAlignmentForPrimitiveType(operand_element_type));
|
||||||
llvm::Value* selected_index_address =
|
llvm::Value* selected_index_address =
|
||||||
@ -851,8 +852,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
|
|||||||
// If the 'select' function returns false, update the selected value and the
|
// If the 'select' function returns false, update the selected value and the
|
||||||
// index to the currently visiting operand.
|
// index to the currently visiting operand.
|
||||||
llvm::Value* cond = ir_builder_.CreateICmpNE(
|
llvm::Value* cond = ir_builder_.CreateICmpNE(
|
||||||
result, llvm::ConstantInt::get(
|
result,
|
||||||
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
|
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
|
||||||
"boolean_predicate");
|
"boolean_predicate");
|
||||||
llvm_ir::LlvmIfData if_select_lhs =
|
llvm_ir::LlvmIfData if_select_lhs =
|
||||||
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
|
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
|
||||||
@ -895,7 +896,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
|||||||
HloInstruction* rhs) {
|
HloInstruction* rhs) {
|
||||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||||
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
|
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
|
||||||
/*supported_types=*/{F32, F64}));
|
/*supported_types=*/{F32, F64, C64}));
|
||||||
|
|
||||||
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
|
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
|
||||||
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
|
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
|
||||||
@ -923,7 +924,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
|
|||||||
const Window& window) {
|
const Window& window) {
|
||||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||||
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
|
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
|
||||||
/*supported_types=*/{F32}));
|
/*supported_types=*/{F32, C64}));
|
||||||
|
|
||||||
const ConvolutionDimensionNumbers& dnums =
|
const ConvolutionDimensionNumbers& dnums =
|
||||||
convolution->convolution_dimension_numbers();
|
convolution->convolution_dimension_numbers();
|
||||||
@ -1079,7 +1080,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution,
|
|||||||
// the output entry at the given index.
|
// the output entry at the given index.
|
||||||
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
PrimitiveType lhs_element_type = lhs->shape().element_type();
|
||||||
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_),
|
||||||
"convolution_sum_address", &ir_builder_,
|
"convolution_sum_address", &ir_builder_,
|
||||||
MinimumAlignmentForPrimitiveType(lhs_element_type));
|
MinimumAlignmentForPrimitiveType(lhs_element_type));
|
||||||
ir_builder_.CreateStore(
|
ir_builder_.CreateStore(
|
||||||
@ -1295,14 +1296,14 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
|
|||||||
PrimitiveType element_type = operand->shape().element_type();
|
PrimitiveType element_type = operand->shape().element_type();
|
||||||
// Used to calculate E(X).
|
// Used to calculate E(X).
|
||||||
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||||
"sum_address", &ir_builder_,
|
"sum_address", &ir_builder_,
|
||||||
MinimumAlignmentForPrimitiveType(element_type));
|
MinimumAlignmentForPrimitiveType(element_type));
|
||||||
|
|
||||||
// Used to calculate E(X^2).
|
// Used to calculate E(X^2).
|
||||||
llvm::Value* sum_square_address =
|
llvm::Value* sum_square_address =
|
||||||
llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(element_type, module_),
|
||||||
"sum_square_address", &ir_builder_,
|
"sum_square_address", &ir_builder_,
|
||||||
MinimumAlignmentForPrimitiveType(element_type));
|
MinimumAlignmentForPrimitiveType(element_type));
|
||||||
|
|
||||||
@ -1425,7 +1426,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) {
|
|||||||
.EmitLoop(IrName(batch_norm_training, "normalize")));
|
.EmitLoop(IrName(batch_norm_training, "normalize")));
|
||||||
|
|
||||||
llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training),
|
llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training),
|
||||||
{normalized, mean, var}, &ir_builder_);
|
{normalized, mean, var}, &ir_builder_, module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1488,6 +1489,14 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const Shape& root_shape = root_instruction->shape();
|
const Shape& root_shape = root_instruction->shape();
|
||||||
|
if (ShapeUtil::ElementIsComplex(root_shape)) {
|
||||||
|
// TODO(b/65408531): Complex add could by done via bitcast to <float x [2N]>
|
||||||
|
// Complex multiply would be more challenging. We could perhaps use a
|
||||||
|
// strided load to get all reals in a vector, all imags in a vector, or use
|
||||||
|
// CreateShuffleVector on a bitcast to float x [2N].
|
||||||
|
*failure_reason = "complex values not supported";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
|
bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape);
|
||||||
bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
|
bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape);
|
||||||
bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
|
bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape);
|
||||||
@ -1509,7 +1518,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
|
|||||||
// This is visually similar to ElementalIrEmitter, though conceptually we're
|
// This is visually similar to ElementalIrEmitter, though conceptually we're
|
||||||
// doing something different here. ElementalIrEmitter emits scalar operations
|
// doing something different here. ElementalIrEmitter emits scalar operations
|
||||||
// while these emit scalar or vector operations depending on the type of the
|
// while these emit scalar or vector operations depending on the type of the
|
||||||
// operands.
|
// operands. See CreateShardedVectorType for the actual types in use here.
|
||||||
switch (root_instruction->opcode()) {
|
switch (root_instruction->opcode()) {
|
||||||
default:
|
default:
|
||||||
*failure_reason = "did not recognize root instruction opcode";
|
*failure_reason = "did not recognize root instruction opcode";
|
||||||
@ -1586,7 +1595,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
|
|||||||
|
|
||||||
ShardedVectorType sharded_vector_type;
|
ShardedVectorType sharded_vector_type;
|
||||||
llvm::Type* element_ir_type =
|
llvm::Type* element_ir_type =
|
||||||
llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_);
|
llvm_ir::PrimitiveTypeToIrType(element_type, module_);
|
||||||
|
|
||||||
for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
|
for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) {
|
||||||
// For every power of two present in element_count, we generate one or more
|
// For every power of two present in element_count, we generate one or more
|
||||||
@ -1919,7 +1928,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
|||||||
// Initialize an accumulator with init_value.
|
// Initialize an accumulator with init_value.
|
||||||
PrimitiveType accumulator_type = reduce->shape().element_type();
|
PrimitiveType accumulator_type = reduce->shape().element_type();
|
||||||
llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
|
||||||
"accumulator", &ir_builder_,
|
"accumulator", &ir_builder_,
|
||||||
MinimumAlignmentForPrimitiveType(accumulator_type));
|
MinimumAlignmentForPrimitiveType(accumulator_type));
|
||||||
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
|
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
|
||||||
@ -2248,6 +2257,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
|
} else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
|
||||||
assignment_)) {
|
assignment_)) {
|
||||||
|
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
|
||||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
|
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
|
||||||
|
|
||||||
@ -2257,6 +2267,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
|||||||
fusion, operands, GetIrArrayFor(fusion), &elemental_emitter,
|
fusion, operands, GetIrArrayFor(fusion), &elemental_emitter,
|
||||||
&ir_builder_);
|
&ir_builder_);
|
||||||
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
} else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||||
|
VLOG(3) << "HandleFusion kLoop";
|
||||||
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
|
||||||
auto operands = GetIrArraysForOperandsOf(fusion);
|
auto operands = GetIrArraysForOperandsOf(fusion);
|
||||||
FusedIrEmitter fused_emitter(operands, &elemental_emitter);
|
FusedIrEmitter fused_emitter(operands, &elemental_emitter);
|
||||||
@ -2400,8 +2411,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
|
|||||||
{while_result}, IrName(xla_while, "cond"));
|
{while_result}, IrName(xla_while, "cond"));
|
||||||
llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
|
llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
|
||||||
while_condition,
|
while_condition,
|
||||||
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
|
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
|
||||||
0));
|
|
||||||
|
|
||||||
// Branches to the body or to the while exit depending on the condition.
|
// Branches to the body or to the while exit depending on the condition.
|
||||||
llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
|
llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
|
||||||
@ -2542,7 +2552,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
|
|||||||
unsigned element_alignment = GCD(
|
unsigned element_alignment = GCD(
|
||||||
primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
|
primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
|
||||||
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
|
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
|
||||||
llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_));
|
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
|
||||||
|
|
||||||
if (element_count == 1) {
|
if (element_count == 1) {
|
||||||
auto* load_instruction = ir_builder_.CreateAlignedLoad(
|
auto* load_instruction = ir_builder_.CreateAlignedLoad(
|
||||||
@ -2755,7 +2765,7 @@ llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
|
llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
|
||||||
return llvm_ir::ShapeToIrType(shape, &ir_builder_);
|
return llvm_ir::ShapeToIrType(shape, module_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() {
|
std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() {
|
||||||
@ -2925,7 +2935,7 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall(
|
|||||||
PrimitiveType return_type = return_shape.element_type();
|
PrimitiveType return_type = return_shape.element_type();
|
||||||
llvm::Value* return_value_buffer =
|
llvm::Value* return_value_buffer =
|
||||||
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
||||||
llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements,
|
llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
|
||||||
tensorflow::strings::StrCat(name, "_return_value_address"),
|
tensorflow::strings::StrCat(name, "_return_value_address"),
|
||||||
&ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
|
&ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
|
||||||
EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
|
EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
|
||||||
@ -3100,7 +3110,7 @@ Status IrEmitter::EmitTargetElementLoop(
|
|||||||
for (int64 i = 0; i < output_arrays.size(); ++i) {
|
for (int64 i = 0; i < output_arrays.size(); ++i) {
|
||||||
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
||||||
}
|
}
|
||||||
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_);
|
llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
if (ShouldEmitParallelLoopFor(*target_op)) {
|
if (ShouldEmitParallelLoopFor(*target_op)) {
|
||||||
|
@ -85,6 +85,10 @@ class DfsHloVisitor {
|
|||||||
virtual Status HandleCopy(HloInstruction* copy) {
|
virtual Status HandleCopy(HloInstruction* copy) {
|
||||||
return HandleElementwiseUnary(copy);
|
return HandleElementwiseUnary(copy);
|
||||||
}
|
}
|
||||||
|
virtual Status HandleComplex(HloInstruction* complex, HloInstruction* real,
|
||||||
|
HloInstruction* imag) {
|
||||||
|
return HandleElementwiseBinary(complex);
|
||||||
|
}
|
||||||
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) {
|
HloInstruction* rhs) {
|
||||||
return HandleElementwiseBinary(multiply);
|
return HandleElementwiseBinary(multiply);
|
||||||
@ -122,6 +126,10 @@ class DfsHloVisitor {
|
|||||||
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
||||||
return HandleElementwiseUnary(abs);
|
return HandleElementwiseUnary(abs);
|
||||||
}
|
}
|
||||||
|
virtual Status HandleAtan2(HloInstruction* atan2, HloInstruction* y,
|
||||||
|
HloInstruction* x) {
|
||||||
|
return HandleElementwiseBinary(atan2);
|
||||||
|
}
|
||||||
virtual Status HandleRound(HloInstruction* round) {
|
virtual Status HandleRound(HloInstruction* round) {
|
||||||
return HandleElementwiseUnary(round);
|
return HandleElementwiseUnary(round);
|
||||||
}
|
}
|
||||||
@ -152,6 +160,12 @@ class DfsHloVisitor {
|
|||||||
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
|
virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
|
||||||
return HandleElementwiseUnary(tanh);
|
return HandleElementwiseUnary(tanh);
|
||||||
}
|
}
|
||||||
|
virtual Status HandleReal(HloInstruction* real, HloInstruction* operand) {
|
||||||
|
return HandleElementwiseUnary(real);
|
||||||
|
}
|
||||||
|
virtual Status HandleImag(HloInstruction* imag, HloInstruction* operand) {
|
||||||
|
return HandleElementwiseUnary(imag);
|
||||||
|
}
|
||||||
virtual Status HandleIsFinite(HloInstruction* is_finite,
|
virtual Status HandleIsFinite(HloInstruction* is_finite,
|
||||||
HloInstruction* operand) {
|
HloInstruction* operand) {
|
||||||
return HandleElementwiseUnary(is_finite);
|
return HandleElementwiseUnary(is_finite);
|
||||||
|
@ -54,10 +54,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
|
|||||||
const HloInstruction* op, llvm::Value* operand_value) const {
|
const HloInstruction* op, llvm::Value* operand_value) const {
|
||||||
if (op->opcode() == HloOpcode::kCopy) {
|
if (op->opcode() == HloOpcode::kCopy) {
|
||||||
return operand_value;
|
return operand_value;
|
||||||
|
} else if (operand_value->getType()->isIntegerTy()) {
|
||||||
|
return EmitIntegerUnaryOp(op, operand_value);
|
||||||
|
} else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
|
||||||
|
return EmitComplexUnaryOp(op, operand_value);
|
||||||
} else {
|
} else {
|
||||||
return operand_value->getType()->isIntegerTy()
|
return EmitFloatUnaryOp(op, operand_value);
|
||||||
? EmitIntegerUnaryOp(op, operand_value)
|
|
||||||
: EmitFloatUnaryOp(op, operand_value);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,20 +75,35 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
|||||||
}
|
}
|
||||||
if (primitive_util::IsIntegralType(to_type)) {
|
if (primitive_util::IsIntegralType(to_type)) {
|
||||||
return ir_builder_->CreateIntCast(
|
return ir_builder_->CreateIntCast(
|
||||||
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_),
|
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
|
||||||
primitive_util::IsSignedIntegralType(to_type));
|
primitive_util::IsSignedIntegralType(to_type));
|
||||||
}
|
}
|
||||||
if (primitive_util::IsFloatingPointType(to_type)) {
|
if (primitive_util::IsFloatingPointType(to_type)) {
|
||||||
if (primitive_util::IsSignedIntegralType(from_type)) {
|
if (primitive_util::IsSignedIntegralType(from_type)) {
|
||||||
return ir_builder_->CreateSIToFP(
|
return ir_builder_->CreateSIToFP(
|
||||||
operand_value,
|
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
|
||||||
}
|
}
|
||||||
if (primitive_util::IsUnsignedIntegralType(from_type) ||
|
if (primitive_util::IsUnsignedIntegralType(from_type) ||
|
||||||
from_type == PRED) {
|
from_type == PRED) {
|
||||||
return ir_builder_->CreateUIToFP(
|
return ir_builder_->CreateUIToFP(
|
||||||
operand_value,
|
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
}
|
||||||
|
}
|
||||||
|
if (primitive_util::IsComplexType(to_type)) {
|
||||||
|
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
|
||||||
|
primitive_util::ComplexComponentType(to_type), module_);
|
||||||
|
if (primitive_util::IsSignedIntegralType(from_type)) {
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
|
if (primitive_util::IsUnsignedIntegralType(from_type) ||
|
||||||
|
from_type == PRED) {
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
|
||||||
|
nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Unimplemented("conversion from primitive type %s to %s",
|
return Unimplemented("conversion from primitive type %s to %s",
|
||||||
@ -97,8 +114,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
|||||||
bool is_signed =
|
bool is_signed =
|
||||||
primitive_util::IsSignedIntegralType(op->shape().element_type());
|
primitive_util::IsSignedIntegralType(op->shape().element_type());
|
||||||
if (is_signed) {
|
if (is_signed) {
|
||||||
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
|
auto type =
|
||||||
ir_builder_);
|
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
|
||||||
auto zero = llvm::ConstantInt::get(type, 0);
|
auto zero = llvm::ConstantInt::get(type, 0);
|
||||||
auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
|
auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
|
||||||
return ir_builder_->CreateSelect(cmp, operand_value,
|
return ir_builder_->CreateSelect(cmp, operand_value,
|
||||||
@ -110,8 +127,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
|||||||
case HloOpcode::kSign: {
|
case HloOpcode::kSign: {
|
||||||
bool is_signed =
|
bool is_signed =
|
||||||
primitive_util::IsSignedIntegralType(op->shape().element_type());
|
primitive_util::IsSignedIntegralType(op->shape().element_type());
|
||||||
auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
|
auto type =
|
||||||
ir_builder_);
|
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
|
||||||
auto zero = llvm::ConstantInt::get(type, 0);
|
auto zero = llvm::ConstantInt::get(type, 0);
|
||||||
auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
|
auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
|
||||||
if (is_signed) {
|
if (is_signed) {
|
||||||
@ -135,7 +152,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
|||||||
return ir_builder_->CreateZExt(
|
return ir_builder_->CreateZExt(
|
||||||
ir_builder_->CreateNot(ir_builder_->CreateTrunc(
|
ir_builder_->CreateNot(ir_builder_->CreateTrunc(
|
||||||
operand_value, ir_builder_->getInt1Ty())),
|
operand_value, ir_builder_->getInt1Ty())),
|
||||||
llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
|
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
|
||||||
} else if (primitive_util::IsIntegralType(type)) {
|
} else if (primitive_util::IsIntegralType(type)) {
|
||||||
return ir_builder_->CreateNot(operand_value);
|
return ir_builder_->CreateNot(operand_value);
|
||||||
}
|
}
|
||||||
@ -157,20 +174,30 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
|||||||
if (from_type == to_type) {
|
if (from_type == to_type) {
|
||||||
return operand_value;
|
return operand_value;
|
||||||
}
|
}
|
||||||
|
if (primitive_util::IsComplexType(to_type)) {
|
||||||
|
PrimitiveType to_component_type =
|
||||||
|
primitive_util::ComplexComponentType(to_type);
|
||||||
|
if (from_type == to_component_type) {
|
||||||
|
return ComposeComplex(op, operand_value, nullptr);
|
||||||
|
}
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFPCast(
|
||||||
|
operand_value,
|
||||||
|
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
|
||||||
|
nullptr);
|
||||||
|
}
|
||||||
if (primitive_util::IsFloatingPointType(to_type)) {
|
if (primitive_util::IsFloatingPointType(to_type)) {
|
||||||
return ir_builder_->CreateFPCast(
|
return ir_builder_->CreateFPCast(
|
||||||
operand_value,
|
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
|
||||||
}
|
}
|
||||||
if (primitive_util::IsSignedIntegralType(to_type)) {
|
if (primitive_util::IsSignedIntegralType(to_type)) {
|
||||||
return ir_builder_->CreateFPToSI(
|
return ir_builder_->CreateFPToSI(
|
||||||
operand_value,
|
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
|
||||||
}
|
}
|
||||||
if (primitive_util::IsUnsignedIntegralType(to_type)) {
|
if (primitive_util::IsUnsignedIntegralType(to_type)) {
|
||||||
return ir_builder_->CreateFPToUI(
|
return ir_builder_->CreateFPToUI(
|
||||||
operand_value,
|
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
|
||||||
llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
|
|
||||||
}
|
}
|
||||||
return Unimplemented("unhandled conversion operation: %s => %s",
|
return Unimplemented("unhandled conversion operation: %s => %s",
|
||||||
PrimitiveType_Name(from_type).c_str(),
|
PrimitiveType_Name(from_type).c_str(),
|
||||||
@ -230,7 +257,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
|||||||
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
|
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
|
||||||
auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
|
auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
|
||||||
return ir_builder_->CreateZExt(
|
return ir_builder_->CreateZExt(
|
||||||
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
|
result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
|
||||||
}
|
}
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
return ir_builder_->CreateFNeg(operand_value);
|
return ir_builder_->CreateFNeg(operand_value);
|
||||||
@ -240,20 +267,164 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
|
||||||
|
const HloInstruction* op, llvm::Value* operand_value) const {
|
||||||
|
auto real = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {0});
|
||||||
|
};
|
||||||
|
auto imag = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {1});
|
||||||
|
};
|
||||||
|
switch (op->opcode()) {
|
||||||
|
// TODO(b/65209142): Angle/Log require atan2.
|
||||||
|
// case HloOpcode::kAngle:
|
||||||
|
// case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
|
||||||
|
case HloOpcode::kConvert: {
|
||||||
|
PrimitiveType from_type = op->operand(0)->shape().element_type();
|
||||||
|
TF_RET_CHECK(primitive_util::IsComplexType(from_type));
|
||||||
|
PrimitiveType to_type = op->shape().element_type();
|
||||||
|
TF_RET_CHECK(primitive_util::IsComplexType(to_type));
|
||||||
|
if (from_type == to_type) {
|
||||||
|
return operand_value;
|
||||||
|
}
|
||||||
|
PrimitiveType to_component_type =
|
||||||
|
primitive_util::ComplexComponentType(to_type);
|
||||||
|
auto to_ir_component_type =
|
||||||
|
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type),
|
||||||
|
ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type));
|
||||||
|
}
|
||||||
|
case HloOpcode::kExp: {
|
||||||
|
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
|
||||||
|
auto exp_a = llvm_ir::EmitCallToIntrinsic(
|
||||||
|
llvm::Intrinsic::exp, {real(operand_value)},
|
||||||
|
{real(operand_value)->getType()}, ir_builder_);
|
||||||
|
auto cos_b = llvm_ir::EmitCallToIntrinsic(
|
||||||
|
llvm::Intrinsic::cos, {imag(operand_value)},
|
||||||
|
{imag(operand_value)->getType()}, ir_builder_);
|
||||||
|
auto sin_b = llvm_ir::EmitCallToIntrinsic(
|
||||||
|
llvm::Intrinsic::sin, {imag(operand_value)},
|
||||||
|
{imag(operand_value)->getType()}, ir_builder_);
|
||||||
|
return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
|
||||||
|
ir_builder_->CreateFMul(exp_a, sin_b));
|
||||||
|
}
|
||||||
|
case HloOpcode::kCos: {
|
||||||
|
// cos(z) = .5(e^(iz) + e^(-iz))
|
||||||
|
// cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
|
||||||
|
// now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
|
||||||
|
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
|
||||||
|
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
|
||||||
|
// cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
|
||||||
|
// = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
|
||||||
|
auto a = real(operand_value);
|
||||||
|
auto b = imag(operand_value);
|
||||||
|
auto type = a->getType();
|
||||||
|
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
|
||||||
|
{type}, ir_builder_);
|
||||||
|
auto half_exp_b =
|
||||||
|
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||||
|
auto half_exp_neg_b =
|
||||||
|
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||||
|
auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
|
||||||
|
{type}, ir_builder_);
|
||||||
|
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
|
||||||
|
{type}, ir_builder_);
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
|
||||||
|
}
|
||||||
|
case HloOpcode::kSin: {
|
||||||
|
// sin(z) = .5i(e^(-iz) - e^(iz))
|
||||||
|
// sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
|
||||||
|
// = .5i(e^(b-ai) - e^(-b+ai))
|
||||||
|
// now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
|
||||||
|
// sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
|
||||||
|
// = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
|
||||||
|
// cos(-x) = cos(x) and sin(-x) = -sin(x), so
|
||||||
|
// = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
|
||||||
|
// = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
|
||||||
|
auto a = real(operand_value);
|
||||||
|
auto b = imag(operand_value);
|
||||||
|
auto type = a->getType();
|
||||||
|
auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b},
|
||||||
|
{type}, ir_builder_);
|
||||||
|
auto half_exp_b =
|
||||||
|
ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||||
|
auto half_exp_neg_b =
|
||||||
|
ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
|
||||||
|
auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a},
|
||||||
|
{type}, ir_builder_);
|
||||||
|
auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a},
|
||||||
|
{type}, ir_builder_);
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
|
||||||
|
}
|
||||||
|
case HloOpcode::kAbs: {
|
||||||
|
auto sum_sq = ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
|
||||||
|
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
|
||||||
|
{sum_sq->getType()}, ir_builder_);
|
||||||
|
}
|
||||||
|
case HloOpcode::kSign: { // Sign(c) = c / |c|
|
||||||
|
auto sum_sq = ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(operand_value), real(operand_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(operand_value), imag(operand_value)));
|
||||||
|
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
|
||||||
|
llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
|
||||||
|
auto type = cplx_abs->getType();
|
||||||
|
auto zero = llvm::ConstantFP::get(type, 0.0);
|
||||||
|
auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
|
||||||
|
return ir_builder_->CreateSelect(
|
||||||
|
oeq, ComposeComplex(op, zero, zero),
|
||||||
|
ComposeComplex(
|
||||||
|
op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs),
|
||||||
|
ir_builder_->CreateFDiv(imag(operand_value), cplx_abs)));
|
||||||
|
}
|
||||||
|
case HloOpcode::kNegate:
|
||||||
|
return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)),
|
||||||
|
ir_builder_->CreateFNeg(imag(operand_value)));
|
||||||
|
case HloOpcode::kReal:
|
||||||
|
return real(operand_value);
|
||||||
|
case HloOpcode::kImag:
|
||||||
|
return imag(operand_value);
|
||||||
|
default:
|
||||||
|
return Unimplemented("unary complex op '%s'",
|
||||||
|
HloOpcodeString(op->opcode()).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
|
||||||
const HloInstruction* op, llvm::Value* lhs_value,
|
const HloInstruction* op, llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const {
|
llvm::Value* rhs_value) const {
|
||||||
return lhs_value->getType()->isIntegerTy()
|
PrimitiveType operand_type = op->operand(0)->shape().element_type();
|
||||||
? EmitIntegerBinaryOp(op, lhs_value, rhs_value,
|
if (lhs_value->getType()->isIntegerTy()) {
|
||||||
primitive_util::IsSignedIntegralType(
|
return EmitIntegerBinaryOp(
|
||||||
op->operand(0)->shape().element_type()))
|
op, lhs_value, rhs_value,
|
||||||
: EmitFloatBinaryOp(op, lhs_value, rhs_value);
|
primitive_util::IsSignedIntegralType(operand_type));
|
||||||
|
} else if (primitive_util::IsComplexType(operand_type)) {
|
||||||
|
return EmitComplexBinaryOp(op, lhs_value, rhs_value);
|
||||||
|
} else {
|
||||||
|
return EmitFloatBinaryOp(op, lhs_value, rhs_value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
||||||
const HloInstruction* op, llvm::Value* lhs_value,
|
const HloInstruction* op, llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const {
|
llvm::Value* rhs_value) const {
|
||||||
switch (op->opcode()) {
|
switch (op->opcode()) {
|
||||||
|
// case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support
|
||||||
|
case HloOpcode::kComplex:
|
||||||
|
return ComposeComplex(op, lhs_value, rhs_value);
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
return ir_builder_->CreateFAdd(lhs_value, rhs_value);
|
return ir_builder_->CreateFAdd(lhs_value, rhs_value);
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
@ -305,6 +476,88 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
|
||||||
|
const HloInstruction* op, llvm::Value* lhs_value,
|
||||||
|
llvm::Value* rhs_value) const {
|
||||||
|
auto real = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {0});
|
||||||
|
};
|
||||||
|
auto imag = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {1});
|
||||||
|
};
|
||||||
|
switch (op->opcode()) {
|
||||||
|
case HloOpcode::kAdd:
|
||||||
|
return ComposeComplex(
|
||||||
|
op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value)));
|
||||||
|
case HloOpcode::kSubtract:
|
||||||
|
return ComposeComplex(
|
||||||
|
op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value)));
|
||||||
|
case HloOpcode::kMultiply:
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFSub(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))),
|
||||||
|
ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))));
|
||||||
|
case HloOpcode::kDivide: {
|
||||||
|
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
|
||||||
|
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
|
||||||
|
auto rhs_sum_sq = ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value)));
|
||||||
|
auto type = rhs_sum_sq->getType();
|
||||||
|
auto zero = llvm::ConstantFP::get(type, 0.0);
|
||||||
|
auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
|
||||||
|
return ir_builder_->CreateSelect(
|
||||||
|
oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero),
|
||||||
|
ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFDiv(
|
||||||
|
ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_value),
|
||||||
|
imag(rhs_value))),
|
||||||
|
rhs_sum_sq),
|
||||||
|
ir_builder_->CreateFDiv(
|
||||||
|
ir_builder_->CreateFSub(
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(real(lhs_value),
|
||||||
|
imag(rhs_value))),
|
||||||
|
rhs_sum_sq)));
|
||||||
|
}
|
||||||
|
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
|
||||||
|
// comparisons always return false when one of the operands is NaN, whereas
|
||||||
|
// unordered comparisons return true.
|
||||||
|
//
|
||||||
|
// We use ordered comparisons for everything except kNe, where we use an
|
||||||
|
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
||||||
|
// matches C++'s semantics.
|
||||||
|
case HloOpcode::kEq:
|
||||||
|
return ir_builder_->CreateAnd(
|
||||||
|
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value),
|
||||||
|
real(rhs_value), ir_builder_),
|
||||||
|
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value),
|
||||||
|
imag(rhs_value), ir_builder_));
|
||||||
|
case HloOpcode::kNe:
|
||||||
|
return ir_builder_->CreateOr(
|
||||||
|
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value),
|
||||||
|
real(rhs_value), ir_builder_),
|
||||||
|
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value),
|
||||||
|
imag(rhs_value), ir_builder_));
|
||||||
|
|
||||||
|
// TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic
|
||||||
|
// case HloOpcode::kPower:
|
||||||
|
// // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2)
|
||||||
|
default:
|
||||||
|
return Unimplemented("binary complex op '%s'",
|
||||||
|
HloOpcodeString(op->opcode()).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const {
|
llvm::Value* rhs_value) const {
|
||||||
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
|
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
|
||||||
@ -396,7 +649,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
|
|||||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
|
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
|
||||||
PrimitiveType prim_type, llvm::Value* value) const {
|
PrimitiveType prim_type, llvm::Value* value) const {
|
||||||
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
|
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
|
||||||
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_);
|
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
|
||||||
auto one = llvm::ConstantFP::get(type, 1.0);
|
auto one = llvm::ConstantFP::get(type, 1.0);
|
||||||
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
|
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
|
||||||
}
|
}
|
||||||
@ -619,7 +872,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
|
|||||||
const {
|
const {
|
||||||
PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
|
PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
|
||||||
llvm::Type* param_ir_type =
|
llvm::Type* param_ir_type =
|
||||||
llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_);
|
llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);
|
||||||
|
|
||||||
// Same values as PCG library
|
// Same values as PCG library
|
||||||
// https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
|
// https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
|
||||||
@ -783,7 +1036,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
|
|||||||
return ir_builder_->CreateZExt(
|
return ir_builder_->CreateZExt(
|
||||||
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
|
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
|
||||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||||
ir_builder_));
|
module_));
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
@ -806,9 +1059,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
case HloOpcode::kSin:
|
case HloOpcode::kSin:
|
||||||
case HloOpcode::kTanh:
|
case HloOpcode::kTanh:
|
||||||
@ -821,6 +1076,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
return EmitUnaryOp(hlo, operand_value);
|
return EmitUnaryOp(hlo, operand_value);
|
||||||
};
|
};
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kEq:
|
||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
@ -913,9 +1170,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
|
llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
|
||||||
llvm::PHINode* output = ir_builder_->CreatePHI(
|
llvm::PHINode* output =
|
||||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
|
||||||
ir_builder_),
|
hlo->shape().element_type(), module_),
|
||||||
hlo->operands().size());
|
hlo->operands().size());
|
||||||
auto prior_insert_point = ir_builder_->GetInsertPoint();
|
auto prior_insert_point = ir_builder_->GetInsertPoint();
|
||||||
|
|
||||||
@ -1075,7 +1332,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
// else -> return data from 'index'.
|
// else -> return data from 'index'.
|
||||||
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||||
ir_builder_),
|
module_),
|
||||||
"ret_value_addr", ir_builder_);
|
"ret_value_addr", ir_builder_);
|
||||||
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
|
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
|
||||||
slice_intersection, "slice_intersection", ir_builder_);
|
slice_intersection, "slice_intersection", ir_builder_);
|
||||||
@ -1164,7 +1421,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
// }
|
// }
|
||||||
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||||
ir_builder_),
|
module_),
|
||||||
"pad_result_addr", ir_builder_);
|
"pad_result_addr", ir_builder_);
|
||||||
llvm_ir::LlvmIfData if_data =
|
llvm_ir::LlvmIfData if_data =
|
||||||
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
|
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
|
||||||
@ -1206,7 +1463,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
ir_builder_);
|
ir_builder_);
|
||||||
PrimitiveType primitive_type = hlo->shape().element_type();
|
PrimitiveType primitive_type = hlo->shape().element_type();
|
||||||
llvm::Type* primitive_type_llvm =
|
llvm::Type* primitive_type_llvm =
|
||||||
llvm_ir::PrimitiveTypeToIrType(primitive_type, ir_builder_);
|
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
|
||||||
llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
primitive_type_llvm, "dot_acc", ir_builder_);
|
primitive_type_llvm, "dot_acc", ir_builder_);
|
||||||
ir_builder_->CreateStore(
|
ir_builder_->CreateStore(
|
||||||
@ -1239,7 +1496,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
|
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
|
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
|
||||||
llvm::Value* next_accumulator;
|
llvm::Value* next_accumulator;
|
||||||
if (primitive_util::IsFloatingPointType(primitive_type)) {
|
if (primitive_util::IsComplexType(primitive_type)) {
|
||||||
|
auto real = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {0});
|
||||||
|
};
|
||||||
|
auto imag = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {1});
|
||||||
|
};
|
||||||
|
llvm::Value* product_real = ir_builder_->CreateFSub(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value)));
|
||||||
|
llvm::Value* product_imag = ir_builder_->CreateFAdd(
|
||||||
|
ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)),
|
||||||
|
ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)));
|
||||||
|
next_accumulator = ir_builder_->CreateInsertValue(
|
||||||
|
current_accumulator,
|
||||||
|
ir_builder_->CreateFAdd(real(current_accumulator), product_real),
|
||||||
|
{0});
|
||||||
|
next_accumulator = ir_builder_->CreateInsertValue(
|
||||||
|
next_accumulator,
|
||||||
|
ir_builder_->CreateFAdd(imag(current_accumulator), product_imag),
|
||||||
|
{1});
|
||||||
|
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
|
||||||
next_accumulator = ir_builder_->CreateFAdd(
|
next_accumulator = ir_builder_->CreateFAdd(
|
||||||
current_accumulator,
|
current_accumulator,
|
||||||
ir_builder_->CreateFMul(lhs_value, rhs_value));
|
ir_builder_->CreateFMul(lhs_value, rhs_value));
|
||||||
@ -1261,4 +1539,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op,
|
||||||
|
llvm::Value* real,
|
||||||
|
llvm::Value* imag) const {
|
||||||
|
auto cplx_type =
|
||||||
|
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
|
||||||
|
auto complex = ir_builder_->CreateInsertValue(
|
||||||
|
llvm::ConstantAggregateZero::get(cplx_type), real, {0});
|
||||||
|
if (imag != nullptr) {
|
||||||
|
complex = ir_builder_->CreateInsertValue(complex, imag, {1});
|
||||||
|
}
|
||||||
|
return complex;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -55,6 +55,7 @@ class ElementalIrEmitter {
|
|||||||
const HloToElementGeneratorMap& operand_to_generator) const;
|
const HloToElementGeneratorMap& operand_to_generator) const;
|
||||||
|
|
||||||
llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
|
llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
|
||||||
|
llvm::Module* module() const { return module_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
|
virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
|
||||||
@ -63,6 +64,9 @@ class ElementalIrEmitter {
|
|||||||
virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
|
virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
|
||||||
const HloInstruction* op, llvm::Value* operand_value) const;
|
const HloInstruction* op, llvm::Value* operand_value) const;
|
||||||
|
|
||||||
|
virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
|
||||||
|
const HloInstruction* op, llvm::Value* operand_value) const;
|
||||||
|
|
||||||
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
|
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
|
||||||
llvm::Value* lhs_value,
|
llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value,
|
llvm::Value* rhs_value,
|
||||||
@ -72,6 +76,10 @@ class ElementalIrEmitter {
|
|||||||
const HloInstruction* op, llvm::Value* lhs_value,
|
const HloInstruction* op, llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const;
|
llvm::Value* rhs_value) const;
|
||||||
|
|
||||||
|
virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(
|
||||||
|
const HloInstruction* op, llvm::Value* lhs_value,
|
||||||
|
llvm::Value* rhs_value) const;
|
||||||
|
|
||||||
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
|
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const;
|
llvm::Value* rhs_value) const;
|
||||||
|
|
||||||
@ -109,6 +117,11 @@ class ElementalIrEmitter {
|
|||||||
// compiled executable outside of the HLO code itself.
|
// compiled executable outside of the HLO code itself.
|
||||||
const HloModuleConfig& hlo_module_config_;
|
const HloModuleConfig& hlo_module_config_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Composes a complex struct. imag may be nullptr for simple cast operations.
|
||||||
|
llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real,
|
||||||
|
llvm::Value* imag) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Returns a ElementGenerator for a RNG HloInstruction.
|
// Returns a ElementGenerator for a RNG HloInstruction.
|
||||||
llvm_ir::ElementGenerator MakeRngElementGenerator(
|
llvm_ir::ElementGenerator MakeRngElementGenerator(
|
||||||
|
@ -135,6 +135,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
|
|||||||
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
|
PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
|
||||||
PrimitiveType output_type = op->shape().element_type();
|
PrimitiveType output_type = op->shape().element_type();
|
||||||
switch (op->opcode()) {
|
switch (op->opcode()) {
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value},
|
||||||
|
{lhs_input_type, rhs_input_type},
|
||||||
|
output_type);
|
||||||
case HloOpcode::kRemainder: {
|
case HloOpcode::kRemainder: {
|
||||||
return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
|
return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value},
|
||||||
{lhs_input_type, rhs_input_type},
|
{lhs_input_type, rhs_input_type},
|
||||||
@ -226,6 +230,112 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitComplexUnaryOp(
|
||||||
|
const HloInstruction* op, llvm::Value* operand_value) const {
|
||||||
|
PrimitiveType input_type = op->operand(0)->shape().element_type();
|
||||||
|
PrimitiveType component_type =
|
||||||
|
primitive_util::IsComplexType(input_type)
|
||||||
|
? primitive_util::ComplexComponentType(input_type)
|
||||||
|
: input_type;
|
||||||
|
auto real = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {0});
|
||||||
|
};
|
||||||
|
auto imag = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_->CreateExtractValue(x, {1});
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (op->opcode()) {
|
||||||
|
case HloOpcode::kLog: {
|
||||||
|
// log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
|
||||||
|
auto a = real(operand_value);
|
||||||
|
auto b = imag(operand_value);
|
||||||
|
llvm::Type* llvm_ty = a->getType();
|
||||||
|
auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
|
||||||
|
ir_builder_->CreateFMul(b, b));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto log_sum_sq,
|
||||||
|
EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type},
|
||||||
|
component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a},
|
||||||
|
{component_type, component_type},
|
||||||
|
component_type));
|
||||||
|
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
|
||||||
|
return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq),
|
||||||
|
angle);
|
||||||
|
}
|
||||||
|
// TODO(b/65408531): Implement kPower on GPU, where atan2 is available.
|
||||||
|
// case HloOpcode::kPower:
|
||||||
|
// // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di))
|
||||||
|
case HloOpcode::kExp: {
|
||||||
|
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
|
||||||
|
auto b = imag(operand_value);
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)},
|
||||||
|
{component_type}, component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type},
|
||||||
|
component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type},
|
||||||
|
component_type));
|
||||||
|
return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
|
||||||
|
ir_builder_->CreateFMul(exp_a, sin_b));
|
||||||
|
}
|
||||||
|
case HloOpcode::kCos: {
|
||||||
|
// cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
|
||||||
|
auto a = real(operand_value);
|
||||||
|
auto llvm_ty = a->getType();
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
|
||||||
|
{component_type}, component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
|
||||||
|
component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
|
||||||
|
component_type));
|
||||||
|
auto half_exp_b =
|
||||||
|
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||||
|
auto half_exp_neg_b =
|
||||||
|
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
case HloOpcode::kSin: {
|
||||||
|
// sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
|
||||||
|
auto a = real(operand_value);
|
||||||
|
auto llvm_ty = a->getType();
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)},
|
||||||
|
{component_type}, component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type},
|
||||||
|
component_type));
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type},
|
||||||
|
component_type));
|
||||||
|
auto half_exp_b =
|
||||||
|
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||||
|
auto half_exp_neg_b =
|
||||||
|
ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b);
|
||||||
|
return ComposeComplex(
|
||||||
|
op,
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
|
||||||
|
ir_builder_->CreateFMul(
|
||||||
|
cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
|
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
|
||||||
const string& callee_name,
|
const string& callee_name,
|
||||||
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
||||||
@ -235,12 +345,11 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
|
|||||||
std::vector<llvm::Type*> ir_input_types;
|
std::vector<llvm::Type*> ir_input_types;
|
||||||
for (PrimitiveType input_type : input_types) {
|
for (PrimitiveType input_type : input_types) {
|
||||||
ir_input_types.push_back(
|
ir_input_types.push_back(
|
||||||
llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_));
|
llvm_ir::PrimitiveTypeToIrType(input_type, module_));
|
||||||
}
|
}
|
||||||
llvm::FunctionType* callee_type = llvm::FunctionType::get(
|
llvm::FunctionType* callee_type = llvm::FunctionType::get(
|
||||||
llvm_ir::PrimitiveTypeToIrType(output_type,
|
llvm_ir::PrimitiveTypeToIrType(output_type, module_), // Return type.
|
||||||
ir_builder_), // The return type.
|
ir_input_types, // Parameter types.
|
||||||
ir_input_types, // The parameter types.
|
|
||||||
false); // No variadic arguments.
|
false); // No variadic arguments.
|
||||||
|
|
||||||
// Declares the callee if it is not declared already.
|
// Declares the callee if it is not declared already.
|
||||||
@ -315,7 +424,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
|||||||
|
|
||||||
PrimitiveType operand_element_type = operand->shape().element_type();
|
PrimitiveType operand_element_type = operand->shape().element_type();
|
||||||
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
|
||||||
"reduce_window_accum_ptr", ir_builder_);
|
"reduce_window_accum_ptr", ir_builder_);
|
||||||
{
|
{
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
|
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
|
||||||
@ -377,7 +486,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
|||||||
const HloInstruction* operand = hlo->operand(0);
|
const HloInstruction* operand = hlo->operand(0);
|
||||||
llvm::Value* accum_ptr =
|
llvm::Value* accum_ptr =
|
||||||
ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
|
ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
|
||||||
hlo->shape().element_type(), ir_builder()));
|
hlo->shape().element_type(), module_));
|
||||||
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
|
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
|
||||||
operand_to_generator.at(hlo->operand(1))({}));
|
operand_to_generator.at(hlo->operand(1))({}));
|
||||||
ir_builder()->CreateStore(init_value, accum_ptr);
|
ir_builder()->CreateStore(init_value, accum_ptr);
|
||||||
|
@ -54,6 +54,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
|
|||||||
StatusOr<llvm::Value*> EmitFloatUnaryOp(
|
StatusOr<llvm::Value*> EmitFloatUnaryOp(
|
||||||
const HloInstruction* op, llvm::Value* operand_value) const override;
|
const HloInstruction* op, llvm::Value* operand_value) const override;
|
||||||
|
|
||||||
|
StatusOr<llvm::Value*> EmitComplexUnaryOp(
|
||||||
|
const HloInstruction* op, llvm::Value* operand_value) const override;
|
||||||
|
|
||||||
StatusOr<llvm::Value*> EmitFloatBinaryOp(
|
StatusOr<llvm::Value*> EmitFloatBinaryOp(
|
||||||
const HloInstruction* op, llvm::Value* lhs_value,
|
const HloInstruction* op, llvm::Value* lhs_value,
|
||||||
llvm::Value* rhs_value) const override;
|
llvm::Value* rhs_value) const override;
|
||||||
|
@ -102,7 +102,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
|
|||||||
slice_result.ConsumeValueOrDie();
|
slice_result.ConsumeValueOrDie();
|
||||||
if (slice.allocation()->is_thread_local()) {
|
if (slice.allocation()->is_thread_local()) {
|
||||||
llvm::Type* pointee_type =
|
llvm::Type* pointee_type =
|
||||||
llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_);
|
llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
|
||||||
BindHloToIrValue(*non_io_hlo,
|
BindHloToIrValue(*non_io_hlo,
|
||||||
ir_builder_->CreateAlloca(pointee_type), index);
|
ir_builder_->CreateAlloca(pointee_type), index);
|
||||||
} else {
|
} else {
|
||||||
@ -124,18 +124,18 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
|
|||||||
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
|
if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
|
||||||
return llvm_ir::EmitGetTupleElement(
|
return llvm_ir::EmitGetTupleElement(
|
||||||
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
||||||
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_);
|
GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_);
|
||||||
}
|
}
|
||||||
return llvm_ir::EmitGetTupleElement(
|
return llvm_ir::EmitGetTupleElement(
|
||||||
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
gte->shape(), gte->tuple_index(), /*alignment=*/1,
|
||||||
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_);
|
EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
|
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
|
||||||
const ShapeIndex& shape_index,
|
const ShapeIndex& shape_index,
|
||||||
llvm::Value* ir_value) {
|
llvm::Value* ir_value) {
|
||||||
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
|
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
|
||||||
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_);
|
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
|
||||||
llvm::Type* dest_type = pointee_type->getPointerTo();
|
llvm::Type* dest_type = pointee_type->getPointerTo();
|
||||||
|
|
||||||
llvm::Value* typed_ir_value;
|
llvm::Value* typed_ir_value;
|
||||||
|
@ -36,10 +36,12 @@ class HloToIrBindings {
|
|||||||
public:
|
public:
|
||||||
HloToIrBindings(const HloModule& module,
|
HloToIrBindings(const HloModule& module,
|
||||||
const BufferAssignment* buffer_assignment,
|
const BufferAssignment* buffer_assignment,
|
||||||
llvm::IRBuilder<>* ir_builder, bool is_nested)
|
llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module,
|
||||||
|
bool is_nested)
|
||||||
: buffer_assignment_(buffer_assignment),
|
: buffer_assignment_(buffer_assignment),
|
||||||
is_nested_(is_nested),
|
is_nested_(is_nested),
|
||||||
ir_builder_(ir_builder),
|
ir_builder_(ir_builder),
|
||||||
|
module_(llvm_module),
|
||||||
alias_analysis_(module, *buffer_assignment_,
|
alias_analysis_(module, *buffer_assignment_,
|
||||||
&ir_builder_->getContext()) {}
|
&ir_builder_->getContext()) {}
|
||||||
|
|
||||||
@ -93,6 +95,7 @@ class HloToIrBindings {
|
|||||||
const bool is_nested_;
|
const bool is_nested_;
|
||||||
|
|
||||||
llvm::IRBuilder<>* ir_builder_;
|
llvm::IRBuilder<>* ir_builder_;
|
||||||
|
llvm::Module* module_;
|
||||||
|
|
||||||
// Stores the underlying llvm::IrArray for each HloInstruction.
|
// Stores the underlying llvm::IrArray for each HloInstruction.
|
||||||
// For an instruction that generates multiple outputs, the root will be a
|
// For an instruction that generates multiple outputs, the root will be a
|
||||||
|
@ -53,9 +53,10 @@ namespace gpu {
|
|||||||
IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
|
IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
|
||||||
IrEmitterContext* ir_emitter_context, bool is_nested)
|
IrEmitterContext* ir_emitter_context, bool is_nested)
|
||||||
: ir_emitter_context_(ir_emitter_context),
|
: ir_emitter_context_(ir_emitter_context),
|
||||||
ir_builder_(ir_emitter_context->llvm_module()->getContext()),
|
module_(ir_emitter_context->llvm_module()),
|
||||||
|
ir_builder_(module_->getContext()),
|
||||||
bindings_(ir_emitter_context->hlo_module(),
|
bindings_(ir_emitter_context->hlo_module(),
|
||||||
&ir_emitter_context->buffer_assignment(), &ir_builder_,
|
&ir_emitter_context->buffer_assignment(), &ir_builder_, module_,
|
||||||
is_nested),
|
is_nested),
|
||||||
hlo_module_config_(hlo_module_config) {
|
hlo_module_config_(hlo_module_config) {
|
||||||
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
|
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
|
||||||
@ -71,18 +72,17 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
return EmitTargetElementLoop(
|
return EmitTargetElementLoop(
|
||||||
*hlo, GpuElementalIrEmitter(hlo_module_config_,
|
*hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
|
||||||
ir_emitter_context_->llvm_module(),
|
GetNestedComputer())
|
||||||
&ir_builder_, GetNestedComputer())
|
|
||||||
.MakeElementGenerator(hlo, operand_to_generator));
|
.MakeElementGenerator(hlo, operand_to_generator));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitter::HandleConstant(HloInstruction* constant,
|
Status IrEmitter::HandleConstant(HloInstruction* constant,
|
||||||
const Literal& literal) {
|
const Literal& literal) {
|
||||||
llvm::Constant* initializer =
|
llvm::Constant* initializer =
|
||||||
llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
|
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
|
||||||
llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
|
llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
|
||||||
*ir_emitter_context_->llvm_module(), initializer->getType(),
|
*module_, initializer->getType(),
|
||||||
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
|
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
|
||||||
/*Name=*/"");
|
/*Name=*/"");
|
||||||
VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl
|
VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl
|
||||||
@ -115,7 +115,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
|||||||
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
||||||
// TODO(b/26344050): tighten the alignment here
|
// TODO(b/26344050): tighten the alignment here
|
||||||
// based on the real element type.
|
// based on the real element type.
|
||||||
/*alignment=*/1, GetBasePointer(*operand), &ir_builder_));
|
/*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ Status IrEmitter::HandleTuple(
|
|||||||
for (const HloInstruction* operand : operands) {
|
for (const HloInstruction* operand : operands) {
|
||||||
base_ptrs.push_back(GetBasePointer(*operand));
|
base_ptrs.push_back(GetBasePointer(*operand));
|
||||||
}
|
}
|
||||||
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_);
|
llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,7 +329,7 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
|
|||||||
if (ShapeUtil::IsTuple(select->shape())) {
|
if (ShapeUtil::IsTuple(select->shape())) {
|
||||||
llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
|
llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
|
||||||
GetBasePointer(*on_true),
|
GetBasePointer(*on_true),
|
||||||
GetBasePointer(*on_false), &ir_builder_);
|
GetBasePointer(*on_false), &ir_builder_, module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,7 +355,26 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
|
|||||||
lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
|
lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
|
||||||
llvm::Value* rhs_value =
|
llvm::Value* rhs_value =
|
||||||
rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
|
rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
|
||||||
llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value);
|
llvm::Value* result;
|
||||||
|
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||||
|
auto real = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_.CreateExtractValue(x, {0});
|
||||||
|
};
|
||||||
|
auto imag = [&](llvm::Value* x) {
|
||||||
|
return ir_builder_.CreateExtractValue(x, {1});
|
||||||
|
};
|
||||||
|
llvm::Value* real_result = ir_builder_.CreateFSub(
|
||||||
|
ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)),
|
||||||
|
ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value)));
|
||||||
|
llvm::Value* imag_result = ir_builder_.CreateFAdd(
|
||||||
|
ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)),
|
||||||
|
ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value)));
|
||||||
|
result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType());
|
||||||
|
result = ir_builder_.CreateInsertValue(result, real_result, {0});
|
||||||
|
result = ir_builder_.CreateInsertValue(result, imag_result, {1});
|
||||||
|
} else {
|
||||||
|
result = ir_builder_.CreateFMul(lhs_value, rhs_value);
|
||||||
|
}
|
||||||
target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_);
|
target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -411,7 +430,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
|
|||||||
|
|
||||||
// Initialize the accumulator in the preheader to zero.
|
// Initialize the accumulator in the preheader to zero.
|
||||||
new llvm::StoreInst(
|
new llvm::StoreInst(
|
||||||
llvm::ConstantFP::get(accum_type, 0.0), // The value stored.
|
llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0
|
||||||
accum_address, // The address.
|
accum_address, // The address.
|
||||||
reduction_loop->GetPreheaderBasicBlock()
|
reduction_loop->GetPreheaderBasicBlock()
|
||||||
->getTerminator()); // The instruction this store is inserted before.
|
->getTerminator()); // The instruction this store is inserted before.
|
||||||
@ -427,9 +446,27 @@ Status IrEmitter::HandleDot(HloInstruction* dot,
|
|||||||
lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_);
|
lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_);
|
||||||
llvm::Value* rhs_element =
|
llvm::Value* rhs_element =
|
||||||
rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_);
|
rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_);
|
||||||
llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
|
|
||||||
llvm::Value* accum = ir_builder_.CreateLoad(accum_address);
|
llvm::Value* accum = ir_builder_.CreateLoad(accum_address);
|
||||||
llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product);
|
llvm::Value* updated_accum;
|
||||||
|
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
|
||||||
|
#define REAL(x) ir_builder_.CreateExtractValue(x, {0})
|
||||||
|
#define IMAG(x) ir_builder_.CreateExtractValue(x, {1})
|
||||||
|
llvm::Value* product_real = ir_builder_.CreateFSub(
|
||||||
|
ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)),
|
||||||
|
ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element)));
|
||||||
|
llvm::Value* product_imag = ir_builder_.CreateFAdd(
|
||||||
|
ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)),
|
||||||
|
ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element)));
|
||||||
|
updated_accum = ir_builder_.CreateInsertValue(
|
||||||
|
accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0});
|
||||||
|
updated_accum = ir_builder_.CreateInsertValue(
|
||||||
|
updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1});
|
||||||
|
#undef IMAG
|
||||||
|
#undef REAL
|
||||||
|
} else {
|
||||||
|
llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
|
||||||
|
updated_accum = ir_builder_.CreateFAdd(accum, product);
|
||||||
|
}
|
||||||
ir_builder_.CreateStore(updated_accum, accum_address);
|
ir_builder_.CreateStore(updated_accum, accum_address);
|
||||||
|
|
||||||
// After the reduction loop exits, store the accumulator into the target
|
// After the reduction loop exits, store the accumulator into the target
|
||||||
@ -494,7 +531,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
|||||||
// Initialize an accumulator with init_value.
|
// Initialize an accumulator with init_value.
|
||||||
llvm::AllocaInst* accumulator_addr =
|
llvm::AllocaInst* accumulator_addr =
|
||||||
ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
|
ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
|
||||||
reduce->shape().element_type(), &ir_builder_));
|
reduce->shape().element_type(), module_));
|
||||||
ir_builder_.CreateStore(
|
ir_builder_.CreateStore(
|
||||||
ir_builder_.CreateLoad(GetBasePointer(*init_value)),
|
ir_builder_.CreateLoad(GetBasePointer(*init_value)),
|
||||||
accumulator_addr);
|
accumulator_addr);
|
||||||
@ -547,8 +584,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
|
|||||||
for (HloInstruction* operand : fusion->operands()) {
|
for (HloInstruction* operand : fusion->operands()) {
|
||||||
parameter_arrays.push_back(GetIrArray(*operand));
|
parameter_arrays.push_back(GetIrArray(*operand));
|
||||||
}
|
}
|
||||||
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
|
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_,
|
||||||
ir_emitter_context_->llvm_module(),
|
|
||||||
&ir_builder_, GetNestedComputer());
|
&ir_builder_, GetNestedComputer());
|
||||||
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
|
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
|
||||||
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
|
||||||
@ -591,9 +627,8 @@ Status IrEmitter::HandleRng(HloInstruction* random,
|
|||||||
// Emits a single-threaded loop because the loop body generated by the element
|
// Emits a single-threaded loop because the loop body generated by the element
|
||||||
// generator for Rng can't be parallelized (b/32333178).
|
// generator for Rng can't be parallelized (b/32333178).
|
||||||
return llvm_ir::LoopEmitter(
|
return llvm_ir::LoopEmitter(
|
||||||
GpuElementalIrEmitter(hlo_module_config_,
|
GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_,
|
||||||
ir_emitter_context_->llvm_module(),
|
GetNestedComputer())
|
||||||
&ir_builder_, GetNestedComputer())
|
|
||||||
.MakeElementGenerator(random, operand_to_generator),
|
.MakeElementGenerator(random, operand_to_generator),
|
||||||
GetIrArray(*random), &ir_builder_)
|
GetIrArray(*random), &ir_builder_)
|
||||||
.EmitLoop(IrName(random));
|
.EmitLoop(IrName(random));
|
||||||
@ -634,7 +669,7 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
|
|||||||
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
|
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
|
||||||
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(
|
llvm_ir::PrimitiveTypeToIrType(
|
||||||
computation.root_instruction()->shape().element_type(), &ir_builder_),
|
computation.root_instruction()->shape().element_type(), module_),
|
||||||
"return_buffer", &ir_builder_);
|
"return_buffer", &ir_builder_);
|
||||||
std::vector<llvm::Value*> parameter_buffers;
|
std::vector<llvm::Value*> parameter_buffers;
|
||||||
for (llvm::Value* parameter_element : parameter_elements) {
|
for (llvm::Value* parameter_element : parameter_elements) {
|
||||||
|
@ -162,6 +162,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
|
|||||||
}
|
}
|
||||||
|
|
||||||
IrEmitterContext* ir_emitter_context_;
|
IrEmitterContext* ir_emitter_context_;
|
||||||
|
llvm::Module* module_;
|
||||||
|
|
||||||
// The following fields track the IR emission state. According to LLVM memory
|
// The following fields track the IR emission state. According to LLVM memory
|
||||||
// management rules, their memory is owned by the module.
|
// management rules, their memory is owned by the module.
|
||||||
|
@ -52,9 +52,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
|
|||||||
io_hlos->push_back(param);
|
io_hlos->push_back(param);
|
||||||
const Shape& param_shape = param->shape();
|
const Shape& param_shape = param->shape();
|
||||||
argument_types.push_back(
|
argument_types.push_back(
|
||||||
llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo());
|
llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
|
||||||
int64 param_size = llvm_ir::ByteSizeOf(
|
int64 param_size =
|
||||||
param_shape, ir_emitter_context_->llvm_module()->getDataLayout());
|
llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
|
||||||
argument_dereferenceable_bytes.push_back(param_size);
|
argument_dereferenceable_bytes.push_back(param_size);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
@ -62,7 +62,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
|
|||||||
io_hlos->push_back(root);
|
io_hlos->push_back(root);
|
||||||
const Shape& root_shape = root->shape();
|
const Shape& root_shape = root->shape();
|
||||||
argument_types.push_back(
|
argument_types.push_back(
|
||||||
llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo());
|
llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
|
||||||
int64 root_size = llvm_ir::ByteSizeOf(
|
int64 root_size = llvm_ir::ByteSizeOf(
|
||||||
root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
|
root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
|
||||||
argument_dereferenceable_bytes.push_back(root_size);
|
argument_dereferenceable_bytes.push_back(root_size);
|
||||||
|
@ -757,8 +757,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
|
|||||||
auto loop_body_emitter =
|
auto loop_body_emitter =
|
||||||
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
|
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
|
||||||
// Emit the loop body that reduces one tile.
|
// Emit the loop body that reduces one tile.
|
||||||
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
|
llvm::Type* element_ir_type =
|
||||||
input_shape.element_type(), &ir_builder_);
|
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
|
||||||
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
|
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
|
||||||
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
|
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
|
||||||
{
|
{
|
||||||
@ -973,7 +973,7 @@ Status IrEmitterUnnested::EmitRowReduction(
|
|||||||
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
|
[=](const llvm_ir::IrArray::Index& tile_index) -> Status {
|
||||||
// Emit the loop body that reduces one tile.
|
// Emit the loop body that reduces one tile.
|
||||||
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
|
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
|
||||||
input_shape.element_type(), &ir_builder_);
|
input_shape.element_type(), ir_emitter_context_->llvm_module());
|
||||||
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
|
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
|
||||||
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
|
element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
|
||||||
{
|
{
|
||||||
@ -1360,7 +1360,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
|||||||
// boolean flag if the value is initialized. The initialized_flag is set
|
// boolean flag if the value is initialized. The initialized_flag is set
|
||||||
// false.
|
// false.
|
||||||
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(operand_element_type,
|
||||||
|
ir_emitter_context_->llvm_module()),
|
||||||
"selected_value_address", &ir_builder_);
|
"selected_value_address", &ir_builder_);
|
||||||
llvm::Value* selected_index_address =
|
llvm::Value* selected_index_address =
|
||||||
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
|
||||||
@ -1440,7 +1441,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
|||||||
llvm::Value* operand_address =
|
llvm::Value* operand_address =
|
||||||
operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
|
operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
|
||||||
llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||||
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
|
llvm_ir::PrimitiveTypeToIrType(PRED,
|
||||||
|
ir_emitter_context_->llvm_module()),
|
||||||
"select_return_buffer", &ir_builder_);
|
"select_return_buffer", &ir_builder_);
|
||||||
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
|
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
|
||||||
*select_and_scatter->select(),
|
*select_and_scatter->select(),
|
||||||
@ -1450,8 +1452,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
|
|||||||
// If the 'select' function returns false, update the selected value and the
|
// If the 'select' function returns false, update the selected value and the
|
||||||
// index to the currently visiting operand.
|
// index to the currently visiting operand.
|
||||||
llvm::Value* cond = ir_builder_.CreateICmpNE(
|
llvm::Value* cond = ir_builder_.CreateICmpNE(
|
||||||
result, llvm::ConstantInt::get(
|
result,
|
||||||
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
|
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
|
||||||
|
PRED, ir_emitter_context_->llvm_module()),
|
||||||
|
0),
|
||||||
"boolean_predicate");
|
"boolean_predicate");
|
||||||
llvm_ir::LlvmIfData if_select_lhs =
|
llvm_ir::LlvmIfData if_select_lhs =
|
||||||
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
|
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
|
||||||
@ -1877,7 +1881,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
|||||||
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
|
||||||
}
|
}
|
||||||
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
|
ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator());
|
||||||
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_);
|
llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_,
|
||||||
|
module_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,6 +50,12 @@ namespace xla {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_complex_t : public std::false_type {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct is_complex_t<complex64> : public std::true_type {};
|
||||||
|
|
||||||
template <typename OperandT>
|
template <typename OperandT>
|
||||||
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
||||||
const Literal& lhs_literal,
|
const Literal& lhs_literal,
|
||||||
@ -101,6 +107,37 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
|
|||||||
return std::move(result);
|
return std::move(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
|
||||||
|
const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
|
||||||
|
const Literal& rhs_literal) {
|
||||||
|
std::function<bool(complex64, complex64)> compare_op;
|
||||||
|
switch (opcode) {
|
||||||
|
case HloOpcode::kEq:
|
||||||
|
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||||
|
return lhs_el == rhs_el;
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case HloOpcode::kNe:
|
||||||
|
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||||
|
return lhs_el != rhs_el;
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
||||||
|
<< HloOpcodeString(opcode);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = Literal::CreateFromShape(shape);
|
||||||
|
TF_RETURN_IF_ERROR(result->Populate<bool>(
|
||||||
|
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
|
||||||
|
return compare_op(lhs_literal.Get<complex64>(multi_index),
|
||||||
|
rhs_literal.Get<complex64>(multi_index));
|
||||||
|
}));
|
||||||
|
|
||||||
|
return std::move(result);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename ReturnT, typename NativeT>
|
template <typename ReturnT, typename NativeT>
|
||||||
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
|
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
|
||||||
HloInstruction* instruction,
|
HloInstruction* instruction,
|
||||||
@ -138,7 +175,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
Status DefaultAction(HloInstruction* hlo_instruction) override {
|
Status DefaultAction(HloInstruction* hlo_instruction) override {
|
||||||
return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
|
return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
|
||||||
HloOpcodeString(hlo_instruction->opcode()).c_str());
|
HloOpcodeString(hlo_instruction->opcode()).c_str());
|
||||||
};
|
}
|
||||||
|
|
||||||
// TODO(b/35950897): many of the stl functions used in the handlers are not
|
// TODO(b/35950897): many of the stl functions used in the handlers are not
|
||||||
// overloaded for every XLA primitive types.
|
// overloaded for every XLA primitive types.
|
||||||
@ -156,7 +193,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
template <
|
template <
|
||||||
typename NativeT,
|
typename NativeT,
|
||||||
typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
|
typename std::enable_if<std::is_signed<NativeT>::value ||
|
||||||
|
is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
|
||||||
ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
|
ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
|
||||||
@ -169,7 +207,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return HandleAbs<ReturnT>(abs, operand);
|
return HandleAbs<ReturnT>(abs, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status HandleRound(HloInstruction* round) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleRound(HloInstruction* round) {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[round],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[round],
|
||||||
ElementWiseUnaryOp(round, [](ReturnT elem_operand) {
|
ElementWiseUnaryOp(round, [](ReturnT elem_operand) {
|
||||||
return std::round(elem_operand);
|
return std::round(elem_operand);
|
||||||
@ -177,6 +218,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleRound(HloInstruction* round) {
|
||||||
|
return InvalidArgument("Unsupported type for Round");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleRound(HloInstruction* round) override {
|
||||||
|
return HandleRound<ReturnT>(round);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandleBroadcast(HloInstruction* broadcast) override {
|
Status HandleBroadcast(HloInstruction* broadcast) override {
|
||||||
parent_->evaluated_[broadcast] =
|
parent_->evaluated_[broadcast] =
|
||||||
Literal::CreateFromShape(broadcast->shape());
|
Literal::CreateFromShape(broadcast->shape());
|
||||||
@ -205,15 +257,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
}
|
}
|
||||||
return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
|
return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
|
||||||
});
|
});
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleCeil(HloInstruction* ceil) {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil],
|
||||||
ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) {
|
ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) {
|
||||||
return std::ceil(elem_operand);
|
return std::ceil(elem_operand);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleCeil(HloInstruction* ceil) {
|
||||||
|
return InvalidArgument("Unsupported type for Ceil");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override {
|
||||||
|
return HandleCeil<ReturnT>(ceil);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandleConvert(HloInstruction* convert) override {
|
Status HandleConvert(HloInstruction* convert) override {
|
||||||
const HloInstruction* operand = convert->operand(0);
|
const HloInstruction* operand = convert->operand(0);
|
||||||
@ -237,15 +303,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return std::exp(elem_operand);
|
return std::exp(elem_operand);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleFloor(HloInstruction* floor) {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor],
|
||||||
ElementWiseUnaryOp(floor, [](ReturnT elem_operand) {
|
ElementWiseUnaryOp(floor, [](ReturnT elem_operand) {
|
||||||
return std::floor(elem_operand);
|
return std::floor(elem_operand);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleFloor(HloInstruction* floor) {
|
||||||
|
return InvalidArgument("Unsupported type for Floor");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override {
|
||||||
|
return HandleFloor<ReturnT>(floor);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
|
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
|
||||||
@ -253,15 +333,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return std::log(elem_operand);
|
return std::log(elem_operand);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleNot(HloInstruction* not_, HloInstruction* operand) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleNot(HloInstruction* not_) {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_],
|
||||||
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
|
ElementWiseUnaryOp(not_, [](ReturnT elem_operand) {
|
||||||
return !elem_operand;
|
return !elem_operand;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleNot(HloInstruction* not_) {
|
||||||
|
return InvalidArgument("Unsupported type for Not");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleNot(HloInstruction* not_, HloInstruction* operand) override {
|
||||||
|
return HandleNot<ReturnT>(not_);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandleNegate(HloInstruction* negate,
|
Status HandleNegate(HloInstruction* negate,
|
||||||
HloInstruction* operand) override {
|
HloInstruction* operand) override {
|
||||||
@ -270,16 +364,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return -elem_operand;
|
return -elem_operand;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleSign(HloInstruction* sign) {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
||||||
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
|
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
|
||||||
return (ReturnT(0) < elem_operand) -
|
return (ReturnT(0) < elem_operand) -
|
||||||
(elem_operand < ReturnT(0));
|
(elem_operand < ReturnT(0));
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleSign(HloInstruction* sign) {
|
||||||
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
|
||||||
|
ElementWiseUnaryOp(sign, [](ReturnT elem_operand) {
|
||||||
|
auto abs_val = std::abs(elem_operand);
|
||||||
|
return 0 == abs_val ? ReturnT(0)
|
||||||
|
: elem_operand / abs_val;
|
||||||
|
}));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleSign(HloInstruction* sign, HloInstruction* operand) override {
|
||||||
|
return HandleSign<ReturnT>(sign);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override {
|
Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override {
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
|
||||||
@ -287,7 +401,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return std::tanh(elem_operand);
|
return std::tanh(elem_operand);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
@ -297,7 +411,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return lhs_elem * rhs_elem;
|
return lhs_elem * rhs_elem;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
|
Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
@ -307,7 +421,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return lhs_elem - rhs_elem;
|
return lhs_elem - rhs_elem;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
|
Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
@ -317,7 +431,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return lhs_elem + rhs_elem;
|
return lhs_elem + rhs_elem;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
|
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
@ -327,25 +441,53 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return lhs_elem / rhs_elem;
|
return lhs_elem / rhs_elem;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleMaximum(HloInstruction* maximum) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleMaximum(HloInstruction* maximum) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
parent_->evaluated_[maximum],
|
parent_->evaluated_[maximum],
|
||||||
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
|
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
|
||||||
return std::fmax(lhs, rhs);
|
return std::fmax(lhs, rhs);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleMinimum(HloInstruction* minimum) override {
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleMaximum(HloInstruction* maximum) {
|
||||||
|
return InvalidArgument("Unsupported type for Maximum");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleMaximum(HloInstruction* maximum) override {
|
||||||
|
return HandleMaximum<ReturnT>(maximum);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleMinimum(HloInstruction* minimum) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
parent_->evaluated_[minimum],
|
parent_->evaluated_[minimum],
|
||||||
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
|
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||||
return std::fmin(lhs_el, rhs_el);
|
return std::fmin(lhs_el, rhs_el);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleMinimum(HloInstruction* minimum) {
|
||||||
|
return InvalidArgument("Unsupported type for Minimum");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleMinimum(HloInstruction* minimum) override {
|
||||||
|
return HandleMinimum<ReturnT>(minimum);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandlePower(HloInstruction* power, HloInstruction* lhs,
|
Status HandlePower(HloInstruction* power, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
@ -355,37 +497,79 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return std::pow(lhs_el, rhs_el);
|
return std::pow(lhs_el, rhs_el);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
|
template <
|
||||||
HloInstruction* rhs) override {
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleRemainder(HloInstruction* remainder) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
parent_->evaluated_[remainder],
|
parent_->evaluated_[remainder],
|
||||||
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
|
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||||
return std::fmod(lhs_el, rhs_el);
|
return std::fmod(lhs_el, rhs_el);
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleAnd(HloInstruction* and_, HloInstruction* lhs,
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleRemainder(HloInstruction* remainder) {
|
||||||
|
return InvalidArgument("Unsupported type for Remainder");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
|
return HandleRemainder<ReturnT>(remainder);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleAnd(HloInstruction* and_) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
parent_->evaluated_[and_],
|
parent_->evaluated_[and_],
|
||||||
ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) {
|
ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||||
return lhs_el && rhs_el;
|
return lhs_el && rhs_el;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleOr(HloInstruction* or_, HloInstruction* lhs,
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleAnd(HloInstruction* and_) {
|
||||||
|
return InvalidArgument("Unsupported type for And");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleAnd(HloInstruction* and_, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
|
return HandleAnd<ReturnT>(and_);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleOr(HloInstruction* or_) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
parent_->evaluated_[or_],
|
parent_->evaluated_[or_],
|
||||||
ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) {
|
ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) {
|
||||||
return lhs_el || rhs_el;
|
return lhs_el || rhs_el;
|
||||||
}));
|
}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleOr(HloInstruction* or_) {
|
||||||
|
return InvalidArgument("Unsupported type for Or");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleOr(HloInstruction* or_, HloInstruction* lhs,
|
||||||
|
HloInstruction* rhs) override {
|
||||||
|
return HandleOr<ReturnT>(or_);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT,
|
template <typename NativeT,
|
||||||
typename std::enable_if<
|
typename std::enable_if<
|
||||||
@ -474,8 +658,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs);
|
return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||||
HloInstruction* arg, HloInstruction* max) override {
|
HloInstruction* arg, HloInstruction* max) {
|
||||||
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
|
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
|
||||||
[](ReturnT low, ReturnT high, ReturnT value) {
|
[](ReturnT low, ReturnT high, ReturnT value) {
|
||||||
return std::fmax(low, std::fmin(value, high));
|
return std::fmax(low, std::fmin(value, high));
|
||||||
@ -483,7 +670,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
|
||||||
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
|
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename NativeT,
|
||||||
|
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||||
|
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||||
|
HloInstruction* arg, HloInstruction* max) {
|
||||||
|
return InvalidArgument("Unsupported type for Clamp");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
|
||||||
|
HloInstruction* arg, HloInstruction* max) override {
|
||||||
|
return HandleClamp<ReturnT>(clamp, min, arg, max);
|
||||||
|
}
|
||||||
|
|
||||||
Status HandleSelect(HloInstruction* select, HloInstruction* pred,
|
Status HandleSelect(HloInstruction* select, HloInstruction* pred,
|
||||||
HloInstruction* on_true,
|
HloInstruction* on_true,
|
||||||
@ -499,7 +699,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
|
TF_ASSIGN_OR_RETURN(parent_->evaluated_[select],
|
||||||
ElementWiseTernaryOp(select, std::move(select_op)));
|
ElementWiseTernaryOp(select, std::move(select_op)));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleReverse(HloInstruction* reverse,
|
Status HandleReverse(HloInstruction* reverse,
|
||||||
HloInstruction* operand) override {
|
HloInstruction* operand) override {
|
||||||
@ -529,7 +729,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
parent_->evaluated_[reverse] = std::move(result);
|
parent_->evaluated_[reverse] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
|
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
|
||||||
HloInstruction* rhs, const Window& window) override {
|
HloInstruction* rhs, const Window& window) override {
|
||||||
@ -652,7 +852,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
parent_->evaluated_[conv] = std::move(result);
|
parent_->evaluated_[conv] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
||||||
HloInstruction* rhs) override {
|
HloInstruction* rhs) override {
|
||||||
@ -719,7 +919,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
parent_->evaluated_[dot] = std::move(result);
|
parent_->evaluated_[dot] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandlePad(HloInstruction* pad) override {
|
Status HandlePad(HloInstruction* pad) override {
|
||||||
CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
|
CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
|
||||||
@ -788,7 +988,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
parent_->evaluated_[pad] = std::move(result);
|
parent_->evaluated_[pad] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
|
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
|
||||||
HloInstruction* operand,
|
HloInstruction* operand,
|
||||||
@ -841,7 +1041,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
||||||
HloInstruction* operand,
|
HloInstruction* operand,
|
||||||
@ -897,7 +1097,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
||||||
HloInstruction* init_value,
|
HloInstruction* init_value,
|
||||||
@ -985,7 +1185,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
parent_->evaluated_[reduce] = std::move(result);
|
parent_->evaluated_[reduce] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleReduceWindow(HloInstruction* reduce_window,
|
Status HandleReduceWindow(HloInstruction* reduce_window,
|
||||||
HloInstruction* operand, const Window& window,
|
HloInstruction* operand, const Window& window,
|
||||||
@ -1072,7 +1272,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
parent_->evaluated_[reduce_window] = std::move(result);
|
parent_->evaluated_[reduce_window] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override {
|
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override {
|
||||||
const Shape& shape = slice->shape();
|
const Shape& shape = slice->shape();
|
||||||
@ -1101,7 +1301,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
|
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
|
||||||
parent_->evaluated_[slice] = std::move(result);
|
parent_->evaluated_[slice] = std::move(result);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename IndexT>
|
template <typename IndexT>
|
||||||
@ -1244,35 +1444,33 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
}
|
}
|
||||||
|
|
||||||
HloEvaluator* parent_;
|
HloEvaluator* parent_;
|
||||||
}; // namespace xla
|
}; // class HloEvaluator::TypedVisitor
|
||||||
|
|
||||||
HloEvaluator::HloEvaluator() {
|
HloEvaluator::HloEvaluator() {
|
||||||
typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
|
typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this);
|
||||||
typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
|
typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this);
|
||||||
typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||||
return Unimplemented("unhandled primitive type: U16.");
|
return Unimplemented("HloEvaluator: unhandled primitive type: U16.");
|
||||||
});
|
});
|
||||||
typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
|
typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this);
|
||||||
typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
|
typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this);
|
||||||
typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
|
typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this);
|
||||||
typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||||
return Unimplemented("unhandled primitive type: S16.");
|
return Unimplemented("HloEvaluator: unhandled primitive type: S16.");
|
||||||
});
|
});
|
||||||
typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
|
typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this);
|
||||||
typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
|
typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this);
|
||||||
typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
typed_visitors_[F16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||||
return Unimplemented("unhandled primitive type: F16.");
|
return Unimplemented("HloEvaluator: unhandled primitive type: F16.");
|
||||||
});
|
});
|
||||||
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
|
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
|
||||||
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
|
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
|
||||||
typed_visitors_[C64] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
|
||||||
return Unimplemented("unhandled primitive type: C64.");
|
|
||||||
});
|
|
||||||
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||||
return Unimplemented("unhandled primitive type: TUPLE.");
|
return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
|
||||||
});
|
});
|
||||||
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||||
return Unimplemented("unhandled primitive type: OPAQUE.");
|
return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE.");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1573,6 +1771,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
|
|||||||
evaluated_[compare],
|
evaluated_[compare],
|
||||||
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
|
case C64: {
|
||||||
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
|
Compare<complex64>(compare->shape(), opcode,
|
||||||
|
lhs_literal, rhs_literal));
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "HandleCompare: unknown primitive type: "
|
LOG(FATAL) << "HandleCompare: unknown primitive type: "
|
||||||
<< PrimitiveType_Name(lhs->shape().element_type());
|
<< PrimitiveType_Name(lhs->shape().element_type());
|
||||||
|
@ -826,8 +826,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
|||||||
case HloOpcode::kAbs:
|
case HloOpcode::kAbs:
|
||||||
case HloOpcode::kRoundNearestAfz:
|
case HloOpcode::kRoundNearestAfz:
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kConvert:
|
case HloOpcode::kConvert:
|
||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
@ -836,6 +838,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
|||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
case HloOpcode::kGt:
|
case HloOpcode::kGt:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIndex:
|
case HloOpcode::kIndex:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
case HloOpcode::kLe:
|
||||||
@ -850,6 +853,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
|||||||
case HloOpcode::kNe:
|
case HloOpcode::kNe:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
case HloOpcode::kShiftLeft:
|
case HloOpcode::kShiftLeft:
|
||||||
case HloOpcode::kShiftRightArithmetic:
|
case HloOpcode::kShiftRightArithmetic:
|
||||||
|
@ -219,10 +219,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
|||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
case HloOpcode::kSin:
|
case HloOpcode::kSin:
|
||||||
case HloOpcode::kSort:
|
case HloOpcode::kSort:
|
||||||
@ -241,26 +243,28 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
|||||||
// Only certain opcodes are supported with CreateBinary: opcodes of binary
|
// Only certain opcodes are supported with CreateBinary: opcodes of binary
|
||||||
// instructions with no auxiliary fields.
|
// instructions with no auxiliary fields.
|
||||||
switch (opcode) {
|
switch (opcode) {
|
||||||
case (HloOpcode::kAdd):
|
case HloOpcode::kAdd:
|
||||||
case (HloOpcode::kDivide):
|
case HloOpcode::kAtan2:
|
||||||
case (HloOpcode::kDot):
|
case HloOpcode::kDivide:
|
||||||
case (HloOpcode::kEq):
|
case HloOpcode::kComplex:
|
||||||
case (HloOpcode::kGe):
|
case HloOpcode::kDot:
|
||||||
case (HloOpcode::kGt):
|
case HloOpcode::kEq:
|
||||||
case (HloOpcode::kLe):
|
case HloOpcode::kGe:
|
||||||
case (HloOpcode::kLt):
|
case HloOpcode::kGt:
|
||||||
case (HloOpcode::kMaximum):
|
case HloOpcode::kLe:
|
||||||
case (HloOpcode::kMinimum):
|
case HloOpcode::kLt:
|
||||||
case (HloOpcode::kMultiply):
|
case HloOpcode::kMaximum:
|
||||||
case (HloOpcode::kNe):
|
case HloOpcode::kMinimum:
|
||||||
case (HloOpcode::kPower):
|
case HloOpcode::kMultiply:
|
||||||
case (HloOpcode::kRemainder):
|
case HloOpcode::kNe:
|
||||||
case (HloOpcode::kSubtract):
|
case HloOpcode::kPower:
|
||||||
case (HloOpcode::kAnd):
|
case HloOpcode::kRemainder:
|
||||||
case (HloOpcode::kOr):
|
case HloOpcode::kSubtract:
|
||||||
case (HloOpcode::kShiftLeft):
|
case HloOpcode::kAnd:
|
||||||
case (HloOpcode::kShiftRightArithmetic):
|
case HloOpcode::kOr:
|
||||||
case (HloOpcode::kShiftRightLogical):
|
case HloOpcode::kShiftLeft:
|
||||||
|
case HloOpcode::kShiftRightArithmetic:
|
||||||
|
case HloOpcode::kShiftRightLogical:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Invalid binary instruction opcode "
|
LOG(FATAL) << "Invalid binary instruction opcode "
|
||||||
@ -978,11 +982,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
case HloOpcode::kCopy:
|
case HloOpcode::kCopy:
|
||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
case HloOpcode::kSin:
|
case HloOpcode::kSin:
|
||||||
case HloOpcode::kSort:
|
case HloOpcode::kSort:
|
||||||
@ -992,6 +998,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
break;
|
break;
|
||||||
// Binary ops.
|
// Binary ops.
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
@ -1403,10 +1411,12 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
// The result of these instructions only depend upon their opcode and
|
// The result of these instructions only depend upon their opcode and
|
||||||
// operands.
|
// operands.
|
||||||
case HloOpcode::kAbs:
|
case HloOpcode::kAbs:
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
case HloOpcode::kRoundNearestAfz:
|
case HloOpcode::kRoundNearestAfz:
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kCopy:
|
case HloOpcode::kCopy:
|
||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kCrossReplicaSum:
|
case HloOpcode::kCrossReplicaSum:
|
||||||
@ -1417,6 +1427,7 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
case HloOpcode::kGt:
|
case HloOpcode::kGt:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
case HloOpcode::kLe:
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
@ -1430,6 +1441,7 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kNe:
|
case HloOpcode::kNe:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
case HloOpcode::kSelect:
|
case HloOpcode::kSelect:
|
||||||
case HloOpcode::kShiftLeft:
|
case HloOpcode::kShiftLeft:
|
||||||
@ -2117,6 +2129,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
|||||||
switch (opcode_) {
|
switch (opcode_) {
|
||||||
case HloOpcode::kAbs:
|
case HloOpcode::kAbs:
|
||||||
return visitor->HandleAbs(this, operands_[0]);
|
return visitor->HandleAbs(this, operands_[0]);
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
return visitor->HandleAtan2(this, operands_[0], operands_[1]);
|
||||||
case HloOpcode::kRoundNearestAfz:
|
case HloOpcode::kRoundNearestAfz:
|
||||||
return visitor->HandleRound(this);
|
return visitor->HandleRound(this);
|
||||||
case HloOpcode::kBatchNormTraining:
|
case HloOpcode::kBatchNormTraining:
|
||||||
@ -2140,6 +2154,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
|||||||
case HloOpcode::kLt:
|
case HloOpcode::kLt:
|
||||||
case HloOpcode::kNe:
|
case HloOpcode::kNe:
|
||||||
return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]);
|
return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]);
|
||||||
|
case HloOpcode::kComplex:
|
||||||
|
return visitor->HandleComplex(this, operands_[0], operands_[1]);
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
return visitor->HandleAdd(this, operands_[0], operands_[1]);
|
return visitor->HandleAdd(this, operands_[0], operands_[1]);
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
@ -2214,6 +2230,10 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
|||||||
return visitor->HandleCos(this, operands_[0]);
|
return visitor->HandleCos(this, operands_[0]);
|
||||||
case HloOpcode::kSin:
|
case HloOpcode::kSin:
|
||||||
return visitor->HandleSin(this, operands_[0]);
|
return visitor->HandleSin(this, operands_[0]);
|
||||||
|
case HloOpcode::kReal:
|
||||||
|
return visitor->HandleReal(this, operands_[0]);
|
||||||
|
case HloOpcode::kImag:
|
||||||
|
return visitor->HandleImag(this, operands_[0]);
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
return visitor->HandleIsFinite(this, operands_[0]);
|
return visitor->HandleIsFinite(this, operands_[0]);
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
@ -2305,7 +2325,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
|
|||||||
//
|
//
|
||||||
// We need to keep track of both the id and the instruction because
|
// We need to keep track of both the id and the instruction because
|
||||||
// instructions can get deleted while they are on the stack, so we
|
// instructions can get deleted while they are on the stack, so we
|
||||||
// can't always use the (potentiall dead) instruction object to grab
|
// can't always use the (potentially dead) instruction object to grab
|
||||||
// its id.
|
// its id.
|
||||||
DFSStack dfs_stack;
|
DFSStack dfs_stack;
|
||||||
dfs_stack.emplace_back(root->unique_id(), root);
|
dfs_stack.emplace_back(root->unique_id(), root);
|
||||||
@ -2505,6 +2525,7 @@ bool HloInstruction::IsElementwiseBinary() const {
|
|||||||
// Binary elementwise operations. If you update this, please update
|
// Binary elementwise operations. If you update this, please update
|
||||||
// IsElementwise() accordingly.
|
// IsElementwise() accordingly.
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kEq:
|
||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
@ -2537,6 +2558,7 @@ bool HloInstruction::IsElementwise() const {
|
|||||||
|
|
||||||
// Unary elementwise operations.
|
// Unary elementwise operations.
|
||||||
case HloOpcode::kAbs:
|
case HloOpcode::kAbs:
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
case HloOpcode::kRoundNearestAfz:
|
case HloOpcode::kRoundNearestAfz:
|
||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kConvert:
|
case HloOpcode::kConvert:
|
||||||
@ -2544,10 +2566,12 @@ bool HloInstruction::IsElementwise() const {
|
|||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
case HloOpcode::kSin:
|
case HloOpcode::kSin:
|
||||||
@ -2557,6 +2581,7 @@ bool HloInstruction::IsElementwise() const {
|
|||||||
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
||||||
// If you update this, please update IsElementwiseBinary() accordingly.
|
// If you update this, please update IsElementwiseBinary() accordingly.
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kEq:
|
||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
|
@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
|||||||
return "abs";
|
return "abs";
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
return "add";
|
return "add";
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
return "atan2";
|
||||||
case HloOpcode::kBatchNormTraining:
|
case HloOpcode::kBatchNormTraining:
|
||||||
return "batch-norm-training";
|
return "batch-norm-training";
|
||||||
case HloOpcode::kBatchNormInference:
|
case HloOpcode::kBatchNormInference:
|
||||||
@ -47,6 +49,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
|||||||
return "call";
|
return "call";
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
return "clamp";
|
return "clamp";
|
||||||
|
case HloOpcode::kComplex:
|
||||||
|
return "complex";
|
||||||
case HloOpcode::kConcatenate:
|
case HloOpcode::kConcatenate:
|
||||||
return "concatenate";
|
return "concatenate";
|
||||||
case HloOpcode::kConstant:
|
case HloOpcode::kConstant:
|
||||||
@ -87,6 +91,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
|||||||
return "get-tuple-element";
|
return "get-tuple-element";
|
||||||
case HloOpcode::kGt:
|
case HloOpcode::kGt:
|
||||||
return "greater-than";
|
return "greater-than";
|
||||||
|
case HloOpcode::kImag:
|
||||||
|
return "imag";
|
||||||
case HloOpcode::kIndex:
|
case HloOpcode::kIndex:
|
||||||
return "index";
|
return "index";
|
||||||
case HloOpcode::kInfeed:
|
case HloOpcode::kInfeed:
|
||||||
@ -125,6 +131,8 @@ string HloOpcodeString(HloOpcode opcode) {
|
|||||||
return "parameter";
|
return "parameter";
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
return "power";
|
return "power";
|
||||||
|
case HloOpcode::kReal:
|
||||||
|
return "real";
|
||||||
case HloOpcode::kRecv:
|
case HloOpcode::kRecv:
|
||||||
return "recv";
|
return "recv";
|
||||||
case HloOpcode::kReduce:
|
case HloOpcode::kReduce:
|
||||||
|
@ -31,6 +31,7 @@ namespace xla {
|
|||||||
enum class HloOpcode {
|
enum class HloOpcode {
|
||||||
kAbs,
|
kAbs,
|
||||||
kAdd,
|
kAdd,
|
||||||
|
kAtan2,
|
||||||
kBatchNormGrad,
|
kBatchNormGrad,
|
||||||
kBatchNormInference,
|
kBatchNormInference,
|
||||||
kBatchNormTraining,
|
kBatchNormTraining,
|
||||||
@ -39,6 +40,7 @@ enum class HloOpcode {
|
|||||||
kCall,
|
kCall,
|
||||||
kCeil,
|
kCeil,
|
||||||
kClamp,
|
kClamp,
|
||||||
|
kComplex,
|
||||||
kConcatenate,
|
kConcatenate,
|
||||||
kConstant,
|
kConstant,
|
||||||
kConvert,
|
kConvert,
|
||||||
@ -58,6 +60,7 @@ enum class HloOpcode {
|
|||||||
kGe,
|
kGe,
|
||||||
kGetTupleElement,
|
kGetTupleElement,
|
||||||
kGt,
|
kGt,
|
||||||
|
kImag,
|
||||||
kIndex,
|
kIndex,
|
||||||
kInfeed,
|
kInfeed,
|
||||||
kIsFinite,
|
kIsFinite,
|
||||||
@ -77,6 +80,7 @@ enum class HloOpcode {
|
|||||||
kPad,
|
kPad,
|
||||||
kParameter,
|
kParameter,
|
||||||
kPower,
|
kPower,
|
||||||
|
kReal,
|
||||||
kRecv,
|
kRecv,
|
||||||
kReduce,
|
kReduce,
|
||||||
kReducePrecision,
|
kReducePrecision,
|
||||||
|
@ -59,6 +59,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
|||||||
for (auto& invariant_checker : invariant_checkers_) {
|
for (auto& invariant_checker : invariant_checkers_) {
|
||||||
VLOG(1) << " Invariant checker " << invariant_checker->name();
|
VLOG(1) << " Invariant checker " << invariant_checker->name();
|
||||||
StatusOr<bool> changed_status = invariant_checker->Run(module);
|
StatusOr<bool> changed_status = invariant_checker->Run(module);
|
||||||
|
VLOG(1) << " Invariant checker done " << invariant_checker->name();
|
||||||
if (!changed_status.ok()) {
|
if (!changed_status.ok()) {
|
||||||
VLOG(2) << "Module failed invariant check:";
|
VLOG(2) << "Module failed invariant check:";
|
||||||
XLA_VLOG_LINES(2, module->ToString());
|
XLA_VLOG_LINES(2, module->ToString());
|
||||||
|
@ -64,6 +64,10 @@ class ShapeVerifier : public DfsHloVisitor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status HandleConvert(HloInstruction* convert) override {
|
Status HandleConvert(HloInstruction* convert) override {
|
||||||
|
if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) {
|
||||||
|
TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape()))
|
||||||
|
<< "Unsupported complex->real kConvert";
|
||||||
|
}
|
||||||
return CheckShape(convert, ShapeInference::InferConvertShape(
|
return CheckShape(convert, ShapeInference::InferConvertShape(
|
||||||
convert->operand(0)->shape(),
|
convert->operand(0)->shape(),
|
||||||
convert->shape().element_type()));
|
convert->shape().element_type()));
|
||||||
|
@ -32,17 +32,16 @@ namespace xla {
|
|||||||
const HloInstruction& instruction) {
|
const HloInstruction& instruction) {
|
||||||
switch (instruction.opcode()) {
|
switch (instruction.opcode()) {
|
||||||
// Cheap instructions.
|
// Cheap instructions.
|
||||||
case HloOpcode::kAbs:
|
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
case HloOpcode::kBitcast:
|
case HloOpcode::kBitcast:
|
||||||
case HloOpcode::kBroadcast:
|
case HloOpcode::kBroadcast:
|
||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kConcatenate:
|
case HloOpcode::kConcatenate:
|
||||||
case HloOpcode::kConstant:
|
case HloOpcode::kConstant:
|
||||||
case HloOpcode::kConvert:
|
case HloOpcode::kConvert:
|
||||||
case HloOpcode::kCopy:
|
case HloOpcode::kCopy:
|
||||||
case HloOpcode::kCos:
|
|
||||||
case HloOpcode::kDynamicSlice:
|
case HloOpcode::kDynamicSlice:
|
||||||
case HloOpcode::kDynamicUpdateSlice:
|
case HloOpcode::kDynamicUpdateSlice:
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kEq:
|
||||||
@ -50,6 +49,7 @@ namespace xla {
|
|||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
case HloOpcode::kGetTupleElement:
|
case HloOpcode::kGetTupleElement:
|
||||||
case HloOpcode::kGt:
|
case HloOpcode::kGt:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kInfeed:
|
case HloOpcode::kInfeed:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
case HloOpcode::kLe:
|
||||||
@ -64,6 +64,7 @@ namespace xla {
|
|||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kOutfeed:
|
case HloOpcode::kOutfeed:
|
||||||
case HloOpcode::kPad:
|
case HloOpcode::kPad:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
case HloOpcode::kReshape:
|
case HloOpcode::kReshape:
|
||||||
case HloOpcode::kReverse:
|
case HloOpcode::kReverse:
|
||||||
@ -72,15 +73,21 @@ namespace xla {
|
|||||||
case HloOpcode::kShiftLeft:
|
case HloOpcode::kShiftLeft:
|
||||||
case HloOpcode::kShiftRightArithmetic:
|
case HloOpcode::kShiftRightArithmetic:
|
||||||
case HloOpcode::kShiftRightLogical:
|
case HloOpcode::kShiftRightLogical:
|
||||||
case HloOpcode::kSign:
|
|
||||||
case HloOpcode::kSin:
|
|
||||||
case HloOpcode::kSlice:
|
case HloOpcode::kSlice:
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
case HloOpcode::kTranspose:
|
case HloOpcode::kTranspose:
|
||||||
case HloOpcode::kTuple:
|
case HloOpcode::kTuple:
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
// Cheap instructions for reals, but expensive for complex.
|
||||||
|
case HloOpcode::kAbs:
|
||||||
|
case HloOpcode::kCos:
|
||||||
|
case HloOpcode::kSign:
|
||||||
|
case HloOpcode::kSin:
|
||||||
|
return ShapeUtil::ElementIsComplex(instruction.shape());
|
||||||
|
|
||||||
// Expensive instructions.
|
// Expensive instructions.
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
case HloOpcode::kBatchNormTraining:
|
case HloOpcode::kBatchNormTraining:
|
||||||
case HloOpcode::kBatchNormInference:
|
case HloOpcode::kBatchNormInference:
|
||||||
case HloOpcode::kBatchNormGrad:
|
case HloOpcode::kBatchNormGrad:
|
||||||
|
@ -75,7 +75,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
|
|||||||
Status FusedIrEmitter::HandleConstant(HloInstruction* constant,
|
Status FusedIrEmitter::HandleConstant(HloInstruction* constant,
|
||||||
const Literal& literal) {
|
const Literal& literal) {
|
||||||
llvm::Constant* initializer =
|
llvm::Constant* initializer =
|
||||||
llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_);
|
llvm_ir::ConvertLiteralToIrConstant(literal, module_);
|
||||||
llvm::GlobalVariable* global = new llvm::GlobalVariable(
|
llvm::GlobalVariable* global = new llvm::GlobalVariable(
|
||||||
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
|
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
|
||||||
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
|
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
|
||||||
@ -101,7 +101,7 @@ Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
|
|||||||
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
|
// Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
|
||||||
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
|
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
|
||||||
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
get_tuple_element->shape(), get_tuple_element->tuple_index(),
|
||||||
/*alignment=*/1, it->second, ir_builder_);
|
/*alignment=*/1, it->second, ir_builder_, module_);
|
||||||
gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
|
gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
|
||||||
// Emit code to read base tuple element array (if non-tuple shaped).
|
// Emit code to read base tuple element array (if non-tuple shaped).
|
||||||
if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
|
if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
|
||||||
@ -134,7 +134,7 @@ Status FusedIrEmitter::HandleTuple(
|
|||||||
std::vector<llvm::Type*> operand_elemental_ir_types;
|
std::vector<llvm::Type*> operand_elemental_ir_types;
|
||||||
for (HloInstruction* operand : operands) {
|
for (HloInstruction* operand : operands) {
|
||||||
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
|
operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
|
||||||
operand->shape().element_type(), ir_builder_));
|
operand->shape().element_type(), module_));
|
||||||
}
|
}
|
||||||
generators_[tuple] =
|
generators_[tuple] =
|
||||||
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||||
|
@ -42,7 +42,8 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
|||||||
ElementalIrEmitter* elemental_emitter)
|
ElementalIrEmitter* elemental_emitter)
|
||||||
: parameter_arrays_(parameter_arrays),
|
: parameter_arrays_(parameter_arrays),
|
||||||
elemental_emitter_(elemental_emitter),
|
elemental_emitter_(elemental_emitter),
|
||||||
ir_builder_(elemental_emitter->ir_builder()) {}
|
ir_builder_(elemental_emitter->ir_builder()),
|
||||||
|
module_(elemental_emitter->module()) {}
|
||||||
|
|
||||||
Status DefaultAction(HloInstruction* hlo) override;
|
Status DefaultAction(HloInstruction* hlo) override;
|
||||||
|
|
||||||
@ -85,6 +86,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
// Borrowed
|
// Borrowed
|
||||||
llvm::IRBuilder<>* ir_builder_;
|
llvm::IRBuilder<>* ir_builder_;
|
||||||
|
llvm::Module* module_;
|
||||||
|
|
||||||
// Map from instruction pointers to functions to generate elements of their
|
// Map from instruction pointers to functions to generate elements of their
|
||||||
// outputs
|
// outputs
|
||||||
|
@ -229,9 +229,11 @@ llvm::Value* IrArray::EmitArrayElementAddress(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
|
if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
|
||||||
|
llvm::Module* module =
|
||||||
|
ir_builder->GetInsertBlock()->getParent()->getParent();
|
||||||
return ir_builder->CreateInBoundsGEP(
|
return ir_builder->CreateInBoundsGEP(
|
||||||
ir_builder->CreateBitCast(
|
ir_builder->CreateBitCast(
|
||||||
base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder)
|
base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module)
|
||||||
->getPointerTo()),
|
->getPointerTo()),
|
||||||
{index.linear()}, llvm_ir::AsStringRef(name));
|
{index.linear()}, llvm_ir::AsStringRef(name));
|
||||||
}
|
}
|
||||||
@ -281,7 +283,8 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
|
|||||||
|
|
||||||
IrArray IrArray::CastToShape(const Shape& new_shape,
|
IrArray IrArray::CastToShape(const Shape& new_shape,
|
||||||
llvm::IRBuilder<>* ir_builder) const {
|
llvm::IRBuilder<>* ir_builder) const {
|
||||||
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder);
|
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
|
||||||
|
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
|
||||||
return IrArray(
|
return IrArray(
|
||||||
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
|
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
|
||||||
new_shape);
|
new_shape);
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "llvm/IR/DerivedTypes.h"
|
||||||
#include "llvm/IR/MDBuilder.h"
|
#include "llvm/IR/MDBuilder.h"
|
||||||
#include "llvm/IR/Operator.h"
|
#include "llvm/IR/Operator.h"
|
||||||
#include "llvm/Target/TargetOptions.h"
|
#include "llvm/Target/TargetOptions.h"
|
||||||
@ -38,6 +39,19 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace llvm_ir {
|
namespace llvm_ir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Note, this function is only useful in an insertion context; in a global
|
||||||
|
// (e.g. constants) context it will CHECK fail.
|
||||||
|
llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) {
|
||||||
|
auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock());
|
||||||
|
auto fn = CHECK_NOTNULL(block->getParent());
|
||||||
|
auto module = CHECK_NOTNULL(fn->getParent());
|
||||||
|
return module;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
string AsString(const std::string& str) {
|
string AsString(const std::string& str) {
|
||||||
return string(str.data(), str.length());
|
return string(str.data(), str.length());
|
||||||
}
|
}
|
||||||
@ -63,7 +77,7 @@ llvm::Value* EmitCallToIntrinsic(
|
|||||||
for (auto type : overloaded_types) {
|
for (auto type : overloaded_types) {
|
||||||
types.push_back(type);
|
types.push_back(type);
|
||||||
}
|
}
|
||||||
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
|
llvm::Module* module = ModuleFromIRBuilder(ir_builder);
|
||||||
llvm::Function* intrinsic =
|
llvm::Function* intrinsic =
|
||||||
llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
|
llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
|
||||||
std::vector<llvm::Value*> operands_vec;
|
std::vector<llvm::Value*> operands_vec;
|
||||||
@ -119,38 +133,53 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
||||||
llvm::IRBuilder<>* ir_builder) {
|
llvm::Module* module) {
|
||||||
switch (element_type) {
|
switch (element_type) {
|
||||||
case PRED:
|
case PRED:
|
||||||
case S8:
|
case S8:
|
||||||
case U8:
|
case U8:
|
||||||
return ir_builder->getInt8Ty();
|
return llvm::Type::getInt8Ty(module->getContext());
|
||||||
case S16:
|
case S16:
|
||||||
case U16:
|
case U16:
|
||||||
return ir_builder->getInt16Ty();
|
return llvm::Type::getInt16Ty(module->getContext());
|
||||||
case S32:
|
case S32:
|
||||||
case U32:
|
case U32:
|
||||||
return ir_builder->getInt32Ty();
|
return llvm::Type::getInt32Ty(module->getContext());
|
||||||
case S64:
|
case S64:
|
||||||
case U64:
|
case U64:
|
||||||
return ir_builder->getInt64Ty();
|
return llvm::Type::getInt64Ty(module->getContext());
|
||||||
case F32:
|
case F32:
|
||||||
return ir_builder->getFloatTy();
|
return llvm::Type::getFloatTy(module->getContext());
|
||||||
case F64:
|
case F64:
|
||||||
return ir_builder->getDoubleTy();
|
return llvm::Type::getDoubleTy(module->getContext());
|
||||||
|
case C64: {
|
||||||
|
auto cplx_t = module->getTypeByName("complex64");
|
||||||
|
if (cplx_t == nullptr) {
|
||||||
|
// C++ standard dictates the memory layout of std::complex is contiguous
|
||||||
|
// real followed by imaginary. C++11 section 26.4 [complex.numbers]:
|
||||||
|
// If z is an lvalue expression of type cv std::complex<T> then the
|
||||||
|
// expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
|
||||||
|
// reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
|
||||||
|
// z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
|
||||||
|
// imaginary part of z.
|
||||||
|
return llvm::StructType::create(
|
||||||
|
"complex64", llvm::Type::getFloatTy(module->getContext()),
|
||||||
|
llvm::Type::getFloatTy(module->getContext()));
|
||||||
|
}
|
||||||
|
return cplx_t;
|
||||||
|
}
|
||||||
// A Tuple contains an array of pointers. Use i8*.
|
// A Tuple contains an array of pointers. Use i8*.
|
||||||
case TUPLE:
|
case TUPLE:
|
||||||
// An Opaque is like a void*, use i8*.
|
// An Opaque is like a void*, use i8*.
|
||||||
case OPAQUE:
|
case OPAQUE:
|
||||||
return ir_builder->getInt8PtrTy();
|
return llvm::Type::getInt8PtrTy(module->getContext());
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unsupported type " << element_type;
|
LOG(FATAL) << "unsupported type " << element_type;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) {
|
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
|
||||||
llvm::Type* result_type =
|
llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
|
||||||
PrimitiveTypeToIrType(shape.element_type(), ir_builder);
|
|
||||||
if (ShapeUtil::IsTuple(shape)) {
|
if (ShapeUtil::IsTuple(shape)) {
|
||||||
// A tuple buffer is an array of pointers.
|
// A tuple buffer is an array of pointers.
|
||||||
result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
|
result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
|
||||||
@ -197,10 +226,10 @@ namespace {
|
|||||||
// value down to zero).
|
// value down to zero).
|
||||||
llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
||||||
std::vector<int64>* multi_index,
|
std::vector<int64>* multi_index,
|
||||||
llvm::IRBuilder<>* ir_builder) {
|
llvm::Module* module) {
|
||||||
const Shape& shape = literal.shape();
|
const Shape& shape = literal.shape();
|
||||||
llvm::Type* ir_element_type =
|
llvm::Type* ir_element_type =
|
||||||
llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder);
|
llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module);
|
||||||
if (dimension_index == -1) {
|
if (dimension_index == -1) {
|
||||||
// Base case of the recursion. Index into the data field of the protobuf
|
// Base case of the recursion. Index into the data field of the protobuf
|
||||||
// with the multi index.
|
// with the multi index.
|
||||||
@ -238,6 +267,16 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
|||||||
value = llvm::ConstantFP::get(ir_element_type,
|
value = llvm::ConstantFP::get(ir_element_type,
|
||||||
literal.Get<double>(*multi_index));
|
literal.Get<double>(*multi_index));
|
||||||
break;
|
break;
|
||||||
|
case C64: {
|
||||||
|
complex64 x = literal.Get<complex64>(*multi_index);
|
||||||
|
value = llvm::ConstantStruct::get(
|
||||||
|
static_cast<llvm::StructType*>(ir_element_type),
|
||||||
|
llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
|
||||||
|
x.real()),
|
||||||
|
llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
|
||||||
|
x.imag()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unsupported type " << shape.element_type();
|
LOG(FATAL) << "unsupported type " << shape.element_type();
|
||||||
}
|
}
|
||||||
@ -256,8 +295,8 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
|||||||
std::vector<llvm::Constant*> elements;
|
std::vector<llvm::Constant*> elements;
|
||||||
for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
|
for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
|
||||||
(*multi_index)[dimension] = i;
|
(*multi_index)[dimension] = i;
|
||||||
elements.push_back(LiteralToConstant(literal, dimension_index - 1,
|
elements.push_back(
|
||||||
multi_index, ir_builder));
|
LiteralToConstant(literal, dimension_index - 1, multi_index, module));
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Type* element_type;
|
llvm::Type* element_type;
|
||||||
@ -279,11 +318,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
||||||
llvm::IRBuilder<>* ir_builder) {
|
llvm::Module* module) {
|
||||||
std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
|
std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
|
||||||
llvm::Constant* value = LiteralToConstant(
|
llvm::Constant* value = LiteralToConstant(
|
||||||
literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
|
literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
|
||||||
&multi_index, ir_builder);
|
&multi_index, module);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -380,7 +419,8 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
|
|||||||
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
|
// comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
|
||||||
// arrays. So we extend it to i8 so that it's addressable.
|
// arrays. So we extend it to i8 so that it's addressable.
|
||||||
return ir_builder->CreateZExt(
|
return ir_builder->CreateZExt(
|
||||||
comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder));
|
comparison_result,
|
||||||
|
llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Internal helper that is called from emitted code to log an int64 value with a
|
// Internal helper that is called from emitted code to log an int64 value with a
|
||||||
|
@ -127,11 +127,11 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
|
|||||||
|
|
||||||
// Returns the LLVM type which represents the given XLA primitive type.
|
// Returns the LLVM type which represents the given XLA primitive type.
|
||||||
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
||||||
llvm::IRBuilder<>* ir_builder);
|
llvm::Module* module);
|
||||||
|
|
||||||
// Returns the LLVM type which represents the given XLA shape. For example,
|
// Returns the LLVM type which represents the given XLA shape. For example,
|
||||||
// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
|
// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
|
||||||
llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder);
|
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
|
||||||
|
|
||||||
// Returns a value that represents a pointer to a global string constant that
|
// Returns a value that represents a pointer to a global string constant that
|
||||||
// encodes the shape as a serialized protobuf.
|
// encodes the shape as a serialized protobuf.
|
||||||
@ -149,7 +149,7 @@ StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
|
|||||||
// Converts a given literal to an IR Constant. Literals have known constant
|
// Converts a given literal to an IR Constant. Literals have known constant
|
||||||
// values at IR emission time.
|
// values at IR emission time.
|
||||||
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
||||||
llvm::IRBuilder<>* ir_builder);
|
llvm::Module* module);
|
||||||
|
|
||||||
// Inserts an allocate of the requested type at the entry point of the
|
// Inserts an allocate of the requested type at the entry point of the
|
||||||
// function that the builder is currently building. The insert point
|
// function that the builder is currently building. The insert point
|
||||||
|
@ -31,14 +31,15 @@ namespace xla {
|
|||||||
namespace llvm_ir {
|
namespace llvm_ir {
|
||||||
|
|
||||||
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
||||||
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) {
|
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
|
||||||
|
llvm::Module* module) {
|
||||||
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
|
CHECK(ShapeUtil::IsScalar(pred.GetShape()));
|
||||||
|
|
||||||
llvm::LoadInst* pred_value =
|
llvm::LoadInst* pred_value =
|
||||||
ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
|
ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
|
||||||
llvm::Value* pred_cond = ir_builder->CreateICmpNE(
|
llvm::Value* pred_cond = ir_builder->CreateICmpNE(
|
||||||
pred_value,
|
pred_value,
|
||||||
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0),
|
llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0),
|
||||||
"boolean_predicate");
|
"boolean_predicate");
|
||||||
|
|
||||||
VLOG(2) << "HandleSelect for tuple:";
|
VLOG(2) << "HandleSelect for tuple:";
|
||||||
@ -71,11 +72,11 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
|||||||
|
|
||||||
void EmitTuple(IrArray tuple,
|
void EmitTuple(IrArray tuple,
|
||||||
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
||||||
llvm::IRBuilder<>* ir_builder) {
|
llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
|
||||||
for (size_t i = 0; i < operands.size(); ++i) {
|
for (size_t i = 0; i < operands.size(); ++i) {
|
||||||
auto* store = ir_builder->CreateStore(
|
auto* store = ir_builder->CreateStore(
|
||||||
ir_builder->CreatePointerCast(operands[i],
|
ir_builder->CreatePointerCast(operands[i],
|
||||||
PrimitiveTypeToIrType(TUPLE, ir_builder)),
|
PrimitiveTypeToIrType(TUPLE, module)),
|
||||||
ir_builder->CreateInBoundsGEP(
|
ir_builder->CreateInBoundsGEP(
|
||||||
tuple.GetBasePointer(),
|
tuple.GetBasePointer(),
|
||||||
{ir_builder->getInt64(0), ir_builder->getInt64(i)}));
|
{ir_builder->getInt64(0), ir_builder->getInt64(i)}));
|
||||||
@ -85,7 +86,8 @@ void EmitTuple(IrArray tuple,
|
|||||||
|
|
||||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||||
int alignment, llvm::Value* operand,
|
int alignment, llvm::Value* operand,
|
||||||
llvm::IRBuilder<>* ir_builder) {
|
llvm::IRBuilder<>* ir_builder,
|
||||||
|
llvm::Module* module) {
|
||||||
llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
|
llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
|
||||||
operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
|
operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
|
||||||
llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
|
llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
|
||||||
@ -98,7 +100,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
|||||||
}
|
}
|
||||||
SetAlignmentMetadataForLoad(src_buffer, alignment);
|
SetAlignmentMetadataForLoad(src_buffer, alignment);
|
||||||
|
|
||||||
llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder);
|
llvm::Type* element_type = ShapeToIrType(target_shape, module);
|
||||||
llvm::Value* ret_val =
|
llvm::Value* ret_val =
|
||||||
ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
|
ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
|
||||||
return ret_val;
|
return ret_val;
|
||||||
|
@ -60,13 +60,14 @@ namespace llvm_ir {
|
|||||||
// tuple_on_true or tuple_on_false:
|
// tuple_on_true or tuple_on_false:
|
||||||
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
|
// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
|
||||||
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
|
||||||
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder);
|
llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
|
||||||
|
llvm::Module* module);
|
||||||
|
|
||||||
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
||||||
// the output buffer of its corresponding operand.
|
// the output buffer of its corresponding operand.
|
||||||
void EmitTuple(IrArray tuple,
|
void EmitTuple(IrArray tuple,
|
||||||
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
|
||||||
llvm::IRBuilder<>* ir_builder);
|
llvm::IRBuilder<>* ir_builder, llvm::Module* module);
|
||||||
|
|
||||||
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
// A tuple is an array of pointers, one for each operand. Each pointer points to
|
||||||
// the output buffer of its corresponding operand. A GetTupleElement instruction
|
// the output buffer of its corresponding operand. A GetTupleElement instruction
|
||||||
@ -74,7 +75,8 @@ void EmitTuple(IrArray tuple,
|
|||||||
// Returns an llvm value representing a pointer to the tuple element buffer.
|
// Returns an llvm value representing a pointer to the tuple element buffer.
|
||||||
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
|
||||||
int alignment, llvm::Value* operand,
|
int alignment, llvm::Value* operand,
|
||||||
llvm::IRBuilder<>* ir_builder);
|
llvm::IRBuilder<>* ir_builder,
|
||||||
|
llvm::Module* module);
|
||||||
} // namespace llvm_ir
|
} // namespace llvm_ir
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
@ -53,6 +53,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
|
|||||||
return UNOP_EXP;
|
return UNOP_EXP;
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
return UNOP_FLOOR;
|
return UNOP_FLOOR;
|
||||||
|
case HloOpcode::kImag:
|
||||||
|
return UNOP_IMAG;
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
return UNOP_IS_FINITE;
|
return UNOP_IS_FINITE;
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
@ -61,6 +63,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
|
|||||||
return UNOP_NOT;
|
return UNOP_NOT;
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
return UNOP_NEGATE;
|
return UNOP_NEGATE;
|
||||||
|
case HloOpcode::kReal:
|
||||||
|
return UNOP_REAL;
|
||||||
case HloOpcode::kRoundNearestAfz:
|
case HloOpcode::kRoundNearestAfz:
|
||||||
return UNOP_ROUND_NEAREST_AFZ;
|
return UNOP_ROUND_NEAREST_AFZ;
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
@ -81,6 +85,10 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
|
|||||||
// opcode.
|
// opcode.
|
||||||
BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
|
BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
|
||||||
switch (opcode) {
|
switch (opcode) {
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
return BINOP_ATAN2;
|
||||||
|
case HloOpcode::kComplex:
|
||||||
|
return BINOP_COMPLEX;
|
||||||
case HloOpcode::kDot:
|
case HloOpcode::kDot:
|
||||||
return BINOP_DOT;
|
return BINOP_DOT;
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
@ -307,19 +315,41 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
|||||||
switch (operation) {
|
switch (operation) {
|
||||||
case UNOP_FLOOR:
|
case UNOP_FLOOR:
|
||||||
case UNOP_CEIL:
|
case UNOP_CEIL:
|
||||||
|
if (!ShapeUtil::ElementIsFloating(arg)) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"expected element type in shape to be floating for floor/ceil "
|
||||||
|
"operation; got %s",
|
||||||
|
PrimitiveType_Name(arg.element_type()).c_str());
|
||||||
|
}
|
||||||
|
return arg;
|
||||||
case UNOP_COS:
|
case UNOP_COS:
|
||||||
case UNOP_SIN:
|
case UNOP_SIN:
|
||||||
case UNOP_EXP:
|
case UNOP_EXP:
|
||||||
case UNOP_LOG:
|
case UNOP_LOG:
|
||||||
case UNOP_TANH:
|
case UNOP_TANH:
|
||||||
if (!ShapeUtil::ElementIsFloating(arg)) {
|
if (!ShapeUtil::ElementIsFloating(arg) &&
|
||||||
|
!ShapeUtil::ElementIsComplex(arg)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"expected element type in shape to be floating for exp/log/tanh "
|
"expected element type in shape to be floating or complex for "
|
||||||
"operation; got %s",
|
"sin/cos/exp/log/tanh operation; got %s",
|
||||||
PrimitiveType_Name(arg.element_type()).c_str());
|
PrimitiveType_Name(arg.element_type()).c_str());
|
||||||
}
|
}
|
||||||
return arg;
|
return arg;
|
||||||
|
case UNOP_REAL:
|
||||||
|
case UNOP_IMAG:
|
||||||
|
if (!ShapeUtil::ElementIsComplex(arg)) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"expected element type in shape to be complex for real/imag "
|
||||||
|
"operation; got %s",
|
||||||
|
PrimitiveType_Name(arg.element_type()).c_str());
|
||||||
|
}
|
||||||
|
return ShapeUtil::ChangeElementType(arg, F32);
|
||||||
case UNOP_ABS:
|
case UNOP_ABS:
|
||||||
|
if (ShapeUtil::ElementIsComplex(arg)) {
|
||||||
|
return ShapeUtil::ChangeElementType(
|
||||||
|
arg, primitive_util::ComplexComponentType(arg.element_type()));
|
||||||
|
}
|
||||||
|
return arg;
|
||||||
case UNOP_NEGATE:
|
case UNOP_NEGATE:
|
||||||
case UNOP_ROUND_NEAREST_AFZ:
|
case UNOP_ROUND_NEAREST_AFZ:
|
||||||
case UNOP_SIGN:
|
case UNOP_SIGN:
|
||||||
@ -751,6 +781,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||||||
case BINOP_MIN:
|
case BINOP_MIN:
|
||||||
case BINOP_SUB:
|
case BINOP_SUB:
|
||||||
case BINOP_ADD:
|
case BINOP_ADD:
|
||||||
|
case BINOP_ATAN2:
|
||||||
case BINOP_POW:
|
case BINOP_POW:
|
||||||
case BINOP_DIV:
|
case BINOP_DIV:
|
||||||
case BINOP_REM:
|
case BINOP_REM:
|
||||||
@ -761,6 +792,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||||||
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
|
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
|
||||||
broadcast_dimensions);
|
broadcast_dimensions);
|
||||||
|
|
||||||
|
case BINOP_COMPLEX: {
|
||||||
|
if (!ShapeUtil::ElementIsFloating(lhs)) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"expected element type in shape to be floating for complex compose "
|
||||||
|
"operation; got %s",
|
||||||
|
PrimitiveType_Name(lhs.element_type()).c_str());
|
||||||
|
}
|
||||||
|
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
||||||
|
InferElementwiseBinaryOpShape(operation, lhs, rhs,
|
||||||
|
broadcast_dimensions));
|
||||||
|
if (lhs.element_type() == F32) {
|
||||||
|
return ShapeUtil::ChangeElementType(shape, C64);
|
||||||
|
} else {
|
||||||
|
return Unimplemented("complex component type not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
case BINOP_AND:
|
case BINOP_AND:
|
||||||
case BINOP_OR:
|
case BINOP_OR:
|
||||||
if (lhs.element_type() != PRED &&
|
if (lhs.element_type() != PRED &&
|
||||||
|
@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test {
|
|||||||
// Some handy scalar shapes.
|
// Some handy scalar shapes.
|
||||||
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
|
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
|
||||||
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
|
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
|
||||||
|
const Shape f64_ = ShapeUtil::MakeShape(F64, {});
|
||||||
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
|
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
|
||||||
|
|
||||||
// Some handy vector and matrix shapes of F32 type.
|
// Some handy vector and matrix shapes of F32 type.
|
||||||
@ -251,6 +252,44 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) {
|
|||||||
.ok());
|
.ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ShapeInferenceTest, Complex) {
|
||||||
|
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
|
||||||
|
const tensorflow::gtl::ArraySlice<int64>& bcast) {
|
||||||
|
return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
|
||||||
|
lhs, rhs, bcast);
|
||||||
|
};
|
||||||
|
// Inputs must be FP.
|
||||||
|
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
|
||||||
|
ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
|
||||||
|
// Component types must match.
|
||||||
|
ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
|
||||||
|
// Only F32->C64 supported.
|
||||||
|
ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
|
||||||
|
// Validate correct uses.
|
||||||
|
Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
|
||||||
|
|
||||||
|
Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result,
|
||||||
|
complex_shape(vector_64_, matrix_32_64_, {1}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result,
|
||||||
|
complex_shape(matrix_32_64_, vector_64_, {1}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result,
|
||||||
|
complex_shape(matrix_32_64_, matrix_32_64_, {}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
|
||||||
|
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
||||||
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
|
StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
|
||||||
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
|
VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
|
||||||
|
@ -55,6 +55,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
|
|||||||
return HloOpcode::kExp;
|
return HloOpcode::kExp;
|
||||||
case UNOP_FLOOR:
|
case UNOP_FLOOR:
|
||||||
return HloOpcode::kFloor;
|
return HloOpcode::kFloor;
|
||||||
|
case UNOP_IMAG:
|
||||||
|
return HloOpcode::kImag;
|
||||||
case UNOP_IS_FINITE:
|
case UNOP_IS_FINITE:
|
||||||
return HloOpcode::kIsFinite;
|
return HloOpcode::kIsFinite;
|
||||||
case UNOP_LOG:
|
case UNOP_LOG:
|
||||||
@ -63,6 +65,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
|
|||||||
return HloOpcode::kNot;
|
return HloOpcode::kNot;
|
||||||
case UNOP_NEGATE:
|
case UNOP_NEGATE:
|
||||||
return HloOpcode::kNegate;
|
return HloOpcode::kNegate;
|
||||||
|
case UNOP_REAL:
|
||||||
|
return HloOpcode::kReal;
|
||||||
case UNOP_ROUND_NEAREST_AFZ:
|
case UNOP_ROUND_NEAREST_AFZ:
|
||||||
return HloOpcode::kRoundNearestAfz;
|
return HloOpcode::kRoundNearestAfz;
|
||||||
case UNOP_SIGN:
|
case UNOP_SIGN:
|
||||||
@ -80,6 +84,10 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
|
|||||||
|
|
||||||
HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
|
HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
|
||||||
switch (binop) {
|
switch (binop) {
|
||||||
|
case BINOP_ATAN2:
|
||||||
|
return HloOpcode::kAtan2;
|
||||||
|
case BINOP_COMPLEX:
|
||||||
|
return HloOpcode::kComplex;
|
||||||
case BINOP_DOT:
|
case BINOP_DOT:
|
||||||
return HloOpcode::kDot;
|
return HloOpcode::kDot;
|
||||||
case BINOP_MUL:
|
case BINOP_MUL:
|
||||||
|
@ -272,6 +272,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
|||||||
case U16:
|
case U16:
|
||||||
case U32:
|
case U32:
|
||||||
case U64:
|
case U64:
|
||||||
|
case C64:
|
||||||
case TUPLE:
|
case TUPLE:
|
||||||
case OPAQUE:
|
case OPAQUE:
|
||||||
return false;
|
return false;
|
||||||
|
@ -361,8 +361,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
|
|||||||
ComputationBuilder* builder, const Array2D<NativeT>& expected,
|
ComputationBuilder* builder, const Array2D<NativeT>& expected,
|
||||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
||||||
static_assert(std::is_same<NativeT, float>::value ||
|
static_assert(std::is_same<NativeT, float>::value ||
|
||||||
std::is_same<NativeT, double>::value,
|
std::is_same<NativeT, double>::value ||
|
||||||
"Floating point type required when specifying an ErrorSpec");
|
std::is_same<NativeT, complex64>::value,
|
||||||
|
"Float or complex type required when specifying an ErrorSpec");
|
||||||
std::unique_ptr<Literal> expected_literal =
|
std::unique_ptr<Literal> expected_literal =
|
||||||
Literal::CreateR2FromArray2D<NativeT>(expected);
|
Literal::CreateR2FromArray2D<NativeT>(expected);
|
||||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
||||||
@ -384,8 +385,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
|
|||||||
ComputationBuilder* builder, const Array3D<NativeT>& expected,
|
ComputationBuilder* builder, const Array3D<NativeT>& expected,
|
||||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
||||||
static_assert(std::is_same<NativeT, float>::value ||
|
static_assert(std::is_same<NativeT, float>::value ||
|
||||||
std::is_same<NativeT, double>::value,
|
std::is_same<NativeT, double>::value ||
|
||||||
"Floating point type required when specifying an ErrorSpec");
|
std::is_same<NativeT, complex64>::value,
|
||||||
|
"Float or complex type required when specifying an ErrorSpec");
|
||||||
std::unique_ptr<Literal> expected_literal =
|
std::unique_ptr<Literal> expected_literal =
|
||||||
Literal::CreateR3FromArray3D<NativeT>(expected);
|
Literal::CreateR3FromArray3D<NativeT>(expected);
|
||||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
||||||
@ -407,8 +409,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
|
|||||||
ComputationBuilder* builder, const Array4D<NativeT>& expected,
|
ComputationBuilder* builder, const Array4D<NativeT>& expected,
|
||||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
|
||||||
static_assert(std::is_same<NativeT, float>::value ||
|
static_assert(std::is_same<NativeT, float>::value ||
|
||||||
std::is_same<NativeT, double>::value,
|
std::is_same<NativeT, double>::value ||
|
||||||
"Floating point type required when specifying an ErrorSpec");
|
std::is_same<NativeT, complex64>::value,
|
||||||
|
"Float or complex type required when specifying an ErrorSpec");
|
||||||
std::unique_ptr<Literal> expected_literal =
|
std::unique_ptr<Literal> expected_literal =
|
||||||
Literal::CreateR4FromArray4D<NativeT>(expected);
|
Literal::CreateR4FromArray4D<NativeT>(expected);
|
||||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
|
||||||
|
@ -347,7 +347,7 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
|
|||||||
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
|
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
|
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
|
||||||
constexpr bool kLhsRowMajor = true;
|
constexpr bool kLhsRowMajor = true;
|
||||||
constexpr bool kRhsRowMajor = true;
|
constexpr bool kRhsRowMajor = true;
|
||||||
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
|
TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
|
||||||
@ -357,7 +357,11 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
|
|||||||
TestNonsquareMatrixDot<double>();
|
TestNonsquareMatrixDot<double>();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DotOperationTest, ConcurrentMatMul) {
|
XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
|
||||||
|
TestNonsquareMatrixDot<complex64>();
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
|
||||||
auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
|
auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
|
||||||
|
@ -41,8 +41,12 @@ class UnaryOpTest : public ClientLibraryTestBase {
|
|||||||
auto arg = builder.ConstantR1<T>({});
|
auto arg = builder.ConstantR1<T>({});
|
||||||
auto abs = builder.Abs(arg);
|
auto abs = builder.Abs(arg);
|
||||||
|
|
||||||
|
if (primitive_util::NativeToPrimitiveType<T>() == C64) {
|
||||||
|
ComputeAndCompareR1<float>(&builder, {}, {});
|
||||||
|
} else {
|
||||||
ComputeAndCompareR1<T>(&builder, {}, {});
|
ComputeAndCompareR1<T>(&builder, {}, {});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void AbsTestHelper() {
|
void AbsTestHelper() {
|
||||||
@ -80,14 +84,58 @@ int UnaryOpTest::inf<int>() {
|
|||||||
return 2147483647;
|
return 2147483647;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void UnaryOpTest::AbsTestHelper<complex64>() {
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto arg = builder.ConstantR1<complex64>({{-2, 0},
|
||||||
|
{0, 25},
|
||||||
|
{0, 0},
|
||||||
|
{-0.3f, 0.4f},
|
||||||
|
{0, inf<float>()},
|
||||||
|
{-inf<float>(), 0}});
|
||||||
|
auto abs = builder.Abs(arg);
|
||||||
|
|
||||||
|
std::unique_ptr<Literal> expected =
|
||||||
|
Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
|
||||||
|
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void UnaryOpTest::SignTestHelper<complex64>() {
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto arg = builder.ConstantR1<complex64>(
|
||||||
|
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
|
||||||
|
auto sign = builder.Sign(arg);
|
||||||
|
|
||||||
|
std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
|
||||||
|
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
|
||||||
|
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void UnaryOpTest::SignAbsTestHelper<complex64>() {
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
auto arg =
|
||||||
|
builder.ConstantR1<complex64>({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
|
||||||
|
auto sign = builder.Sign(arg);
|
||||||
|
auto abs = builder.Abs(arg);
|
||||||
|
builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg);
|
||||||
|
|
||||||
|
std::unique_ptr<Literal> expected =
|
||||||
|
Literal::CreateR1<complex64>({0, 0, 0, 0});
|
||||||
|
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
|
XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
|
||||||
AbsSize0TestHelper<int>();
|
AbsSize0TestHelper<int>();
|
||||||
AbsSize0TestHelper<float>();
|
AbsSize0TestHelper<float>();
|
||||||
|
AbsSize0TestHelper<complex64>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, AbsTestR1) {
|
XLA_TEST_F(UnaryOpTest, AbsTestR1) {
|
||||||
AbsTestHelper<int>();
|
AbsTestHelper<int>();
|
||||||
AbsTestHelper<float>();
|
AbsTestHelper<float>();
|
||||||
|
AbsTestHelper<complex64>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
|
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
|
||||||
@ -98,34 +146,44 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) {
|
|||||||
auto absf = builder.Abs(argf);
|
auto absf = builder.Abs(argf);
|
||||||
auto argf0 = builder.ConstantR0<float>(-0.0f);
|
auto argf0 = builder.ConstantR0<float>(-0.0f);
|
||||||
auto absf0 = builder.Abs(argf0);
|
auto absf0 = builder.Abs(argf0);
|
||||||
builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
|
auto argc = builder.ConstantR0<complex64>({-0.3f, 0.4f});
|
||||||
absi, PrimitiveType::F32)));
|
auto absc = builder.Abs(argc);
|
||||||
|
builder.Add(builder.Add(absc, absf0),
|
||||||
|
builder.Add(absf, builder.ConvertElementType(absi, F32)));
|
||||||
|
|
||||||
ComputeAndCompareR0<float>(&builder, 8.0f, {});
|
ComputeAndCompareR0<float>(&builder, 8.5f, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, SignTestR0) {
|
XLA_TEST_F(UnaryOpTest, SignTestR0) {
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
auto argi = builder.ConstantR0<int>(-5);
|
auto argi = builder.ConstantR0<int>(-5);
|
||||||
auto absi = builder.Sign(argi);
|
auto sgni = builder.Sign(argi); // -1
|
||||||
auto argf = builder.ConstantR0<float>(-4.0f);
|
auto argf = builder.ConstantR0<float>(-4.0f);
|
||||||
auto absf = builder.Sign(argf);
|
auto sgnf = builder.Sign(argf); // -1
|
||||||
auto argf0 = builder.ConstantR0<float>(-0.0f);
|
auto argf0 = builder.ConstantR0<float>(-0.0f);
|
||||||
auto absf0 = builder.Sign(argf0);
|
auto sgnf0 = builder.Sign(argf0); // 0
|
||||||
builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
|
auto argc = builder.ConstantR0<complex64>({-.3, .4});
|
||||||
absi, PrimitiveType::F32)));
|
auto sgnc = builder.Sign(argc); // (-.6, .8)
|
||||||
|
builder.Add(sgnc, builder.ConvertElementType(
|
||||||
|
builder.Add(builder.Add(sgnf0, sgnf),
|
||||||
|
builder.ConvertElementType(sgni, F32)),
|
||||||
|
C64));
|
||||||
|
|
||||||
ComputeAndCompareR0<float>(&builder, -2.0f, {});
|
std::unique_ptr<Literal> expected =
|
||||||
|
Literal::CreateR0<complex64>({-2.6f, 0.8f});
|
||||||
|
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, SignTestR1) {
|
XLA_TEST_F(UnaryOpTest, SignTestR1) {
|
||||||
SignTestHelper<int>();
|
SignTestHelper<int>();
|
||||||
SignTestHelper<float>();
|
SignTestHelper<float>();
|
||||||
|
SignTestHelper<complex64>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
|
XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
|
||||||
SignAbsTestHelper<int>();
|
SignAbsTestHelper<int>();
|
||||||
SignAbsTestHelper<float>();
|
SignAbsTestHelper<float>();
|
||||||
|
SignAbsTestHelper<complex64>();
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
|
XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
|
||||||
|
@ -235,11 +235,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
|||||||
case HloOpcode::kCopy:
|
case HloOpcode::kCopy:
|
||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
case HloOpcode::kSin:
|
case HloOpcode::kSin:
|
||||||
case HloOpcode::kSort:
|
case HloOpcode::kSort:
|
||||||
@ -256,6 +258,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
|||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
|
case HloOpcode::kAtan2:
|
||||||
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kEq:
|
||||||
case HloOpcode::kGe:
|
case HloOpcode::kGe:
|
||||||
case HloOpcode::kGt:
|
case HloOpcode::kGt:
|
||||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_TYPES_H_
|
#define TENSORFLOW_COMPILER_XLA_TYPES_H_
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -35,7 +37,7 @@ using ::tensorflow::uint16;
|
|||||||
using ::tensorflow::uint32;
|
using ::tensorflow::uint32;
|
||||||
using ::tensorflow::uint64;
|
using ::tensorflow::uint64;
|
||||||
|
|
||||||
typedef std::complex<float> complex64;
|
using complex64 = std::complex<float>;
|
||||||
|
|
||||||
using ::Eigen::half;
|
using ::Eigen::half;
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ enum PrimitiveType {
|
|||||||
F64 = 12;
|
F64 = 12;
|
||||||
|
|
||||||
// Complex values of fixed width.
|
// Complex values of fixed width.
|
||||||
C64 = 15;
|
C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
|
||||||
|
|
||||||
// A tuple is a polymorphic sequence; e.g. a shape that holds different
|
// A tuple is a polymorphic sequence; e.g. a shape that holds different
|
||||||
// sub-shapes. They are used for things like returning multiple values from a
|
// sub-shapes. They are used for things like returning multiple values from a
|
||||||
@ -667,6 +667,12 @@ enum UnaryOperation {
|
|||||||
// Elementwise, rounds x to nearest integral value, rounding half-way cases
|
// Elementwise, rounds x to nearest integral value, rounding half-way cases
|
||||||
// away from zero.
|
// away from zero.
|
||||||
UNOP_ROUND_NEAREST_AFZ = 14;
|
UNOP_ROUND_NEAREST_AFZ = 14;
|
||||||
|
|
||||||
|
// Elementwise, extract real component of complex x.
|
||||||
|
UNOP_REAL = 15;
|
||||||
|
|
||||||
|
// Elementwise, extract real component of complex x.
|
||||||
|
UNOP_IMAG = 16;
|
||||||
}
|
}
|
||||||
|
|
||||||
message UnaryOpRequest {
|
message UnaryOpRequest {
|
||||||
@ -721,6 +727,12 @@ enum BinaryOperation {
|
|||||||
BINOP_SHIFT_LEFT = 20;
|
BINOP_SHIFT_LEFT = 20;
|
||||||
BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
|
BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
|
||||||
BINOP_SHIFT_RIGHT_LOGICAL = 22;
|
BINOP_SHIFT_RIGHT_LOGICAL = 22;
|
||||||
|
|
||||||
|
// Complex from real, imag.
|
||||||
|
BINOP_COMPLEX = 23;
|
||||||
|
|
||||||
|
// Computes the 4-quadrant arctangent of the y, x input arguments.
|
||||||
|
BINOP_ATAN2 = 24;
|
||||||
}
|
}
|
||||||
|
|
||||||
message BinaryOpRequest {
|
message BinaryOpRequest {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user