From 937ff1b4cf1d954806e075d66198429a6d2312be Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 11 Jan 2019 13:05:56 -0800 Subject: [PATCH] [XLA] Add complex128 support. Minimally tested at the moment (tested via the TF tests). PiperOrigin-RevId: 228931680 --- .../jit/mark_for_compilation_pass_test.cc | 10 +- tensorflow/compiler/jit/xla_cpu_device.cc | 4 +- .../compiler/jit/xla_interpreter_device.cc | 4 +- tensorflow/compiler/tests/binary_ops_test.py | 2 +- tensorflow/compiler/tests/build_defs.bzl | 2 +- tensorflow/compiler/tests/unary_ops_test.py | 2 +- tensorflow/compiler/tf2xla/kernels/cast_op.cc | 4 +- .../compiler/tf2xla/kernels/const_op.cc | 11 +++ .../compiler/tf2xla/kernels/matmul_op.cc | 4 +- tensorflow/compiler/tf2xla/lib/util.cc | 6 ++ tensorflow/compiler/tf2xla/type_util.cc | 3 + tensorflow/compiler/tf2xla/xla_op_registry.h | 9 +- .../compiler/xla/client/lib/constants.h | 2 + tensorflow/compiler/xla/client/lib/math.cc | 4 +- tensorflow/compiler/xla/literal.cc | 98 +++++++++++++++---- tensorflow/compiler/xla/literal_comparison.cc | 58 +++++++++-- tensorflow/compiler/xla/literal_test.cc | 56 ++++++++++- tensorflow/compiler/xla/literal_util.cc | 14 +++ tensorflow/compiler/xla/primitive_util.cc | 7 +- tensorflow/compiler/xla/primitive_util.h | 10 ++ .../compiler/xla/python/numpy_bridge.cc | 11 +++ tensorflow/compiler/xla/python/xla_client.py | 1 + tensorflow/compiler/xla/service/BUILD | 2 + .../xla/service/algebraic_simplifier.cc | 3 + .../xla/service/cpu/dot_op_emitter.cc | 3 +- .../compiler/xla/service/cpu/ir_emitter.cc | 12 ++- .../compiler/xla/service/gpu/gemm_thunk.cc | 8 ++ .../xla/service/gpu/ir_emission_utils.cc | 3 +- .../compiler/xla/service/hlo_evaluator.cc | 97 ++++++++++++++++-- .../xla/service/hlo_evaluator_typed_visitor.h | 15 ++- .../hlo_evaluator_typed_visitor_complex128.cc | 22 +++++ .../compiler/xla/service/hlo_parser_test.cc | 11 +++ .../compiler/xla/service/llvm_ir/llvm_util.cc | 11 ++- .../compiler/xla/service/shape_inference.cc | 4 +- .../xla/service/shape_inference_test.cc | 8 +- tensorflow/compiler/xla/shape_util.cc | 3 + .../xla/tests/client_library_test_base.h | 15 ++- tensorflow/compiler/xla/types.h | 1 + tensorflow/compiler/xla/xla_data.proto | 6 +- 39 files changed, 462 insertions(+), 84 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index bf2c5508ea9..c2b6250f738 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -151,7 +151,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, Complex128Unsupported) { +TEST(XlaCompilationTest, StringUnsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -159,10 +159,10 @@ TEST(XlaCompilationTest, Complex128Unsupported) { Node* a = ops::SourceOp( "Const", builder.opts() .WithName("A") - .WithAttr("dtype", DT_COMPLEX128) - .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); - Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); - ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + .WithAttr("dtype", DT_STRING) + .WithAttr("value", Tensor(DT_STRING, TensorShape()))); + Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B")); + ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C")); 
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e9770647e7b..94dc61d55fb 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -83,9 +83,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { +constexpr std::array kAllXlaCpuTypes = { {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4007309ed1c..e1a58240615 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -26,9 +26,9 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { +constexpr std::array kExecAllTypes = { {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_BOOL, DT_BFLOAT16}}; + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9a5423c1b2a..a3651b4b0de 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -400,7 +400,7 @@ class BinaryOpsTest(xla_test.XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._testBinary( math_ops.complex, np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 447a7de2cb6..be9766c4ef4 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -64,7 +64,7 @@ def tf_xla_py_test( if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128", ] elif backend == "gpu": backend_args += [ diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 95c9e7ffd46..3c2875ba477 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -647,7 +647,7 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 8cc2479dd55..ce8131eeb4c 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -79,8 +79,8 @@ class BitcastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else { - // The only complex type in XLA is C64, so error out if the bitcast has a - // complex source or destination type and the bitcast is not trivial. + // Error out if the bitcast has a complex source or destination type and + // the bitcast is not trivial. OP_REQUIRES(ctx, !xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_), diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index dff8af80022..ff6c54e47c6 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -83,6 +83,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX128: + if (proto_.dcomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0<xla::complex128>( + b, xla::complex128(proto_.dcomplex_val(0), + proto_.dcomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6440770c298..f36e0025250 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -24,8 +24,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array<DataType, 5> kMatmulTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array<DataType, 6> kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}}; class MatMulOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c0bd172d17c..06eda416118 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -54,6 +54,9 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: return xla::ConstantR0<xla::complex64>(builder, value); break; + case xla::C128: + return xla::ConstantR0<xla::complex128>(builder, value); + break; default: LOG(FATAL) << "unhandled element type " << type; } @@ -90,6 +93,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: literal = xla::LiteralUtil::CreateR0<xla::complex64>(value); break; + case xla::C128: + literal = xla::LiteralUtil::CreateR0<xla::complex128>(value); + break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index d00b1376620..732f957d732 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -69,6 +69,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); + case tensorflow::DT_COMPLEX128: + *type = xla::C128; + return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 0bdd4a10854..ce3b6b298c6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,13 +47,14 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array<DataType, 4> kFloatTypes = { {DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array<DataType, 11> kNumericTypes = { +constexpr std::array<DataType, 12> kNumericTypes = { {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array<DataType, 14> kCpuAllTypes = { +constexpr std::array<DataType, 15> kCpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, - DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL}}; constexpr std::array kGpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 81624614c1e..a38282e8dbd 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -56,6 +56,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { return ConstantR0<double>(builder, static_cast<double>(value)); case C64: return ConstantR0<complex64>(builder, static_cast<complex64>(value)); + case C128: + return ConstantR0<complex128>(builder, static_cast<complex128>(value)); case U8: return ConstantR0<uint8>(builder, static_cast<uint8>(value)); case U32: diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 36fdda39b41..3d0e3a2b93f 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -323,7 +324,8 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == C64 && conjugate; + auto perform_conj = + primitive_util::IsComplexType(shape.element_type()) && conjugate; return perform_conj ? Conj(x) : x; }); } diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 81874fa405c..8600e8752cf 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -29,10 +29,12 @@ limitations under the License.
#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -411,6 +413,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { COPY_ELEMENTS(F32, float); COPY_ELEMENTS(F64, double); COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(C128, complex128); COPY_ELEMENTS(PRED, bool); #undef COPY_ELEMENTS default: @@ -548,6 +551,9 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, case C64: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); + case C128: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case PRED: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); @@ -766,6 +772,8 @@ Literal LiteralBase::Slice(absl::Span start_indices, return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); + case C128: + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -814,6 +822,10 @@ string LiteralBase::GetAsString(absl::Span multi_index, complex64 c = Get(multi_index, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); } @@ -868,6 +880,11 @@ string LiteralBase::GetSparseElementAsString( GetSparseElement(sparse_element_number, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << "Invalid element type for sparse arrays: " << PrimitiveType_Name(subshape.element_type()); @@ -996,6 +1013,9 @@ void LiteralBase::Piece::SortSparseElements() { case C64: SortSparseElementsInternal(); break; + case C128: + SortSparseElementsInternal(); + break; case F16: SortSparseElementsInternal(); break; @@ -1230,7 +1250,24 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, } template -Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { +typename std::enable_if<(std::is_same::value) && + (std::is_same::value || + std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return NativeDestT(static_cast(src)); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(!std::is_same::value) || + (!std::is_same::value && + !std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1274,22 +1311,6 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } -template -Literal 
ConvertToC64(const LiteralBase& src_literal) { - CHECK(src_literal.shape().IsArray()); - Literal result_literal( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal.data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - template Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -1332,10 +1353,15 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: - if (!bitcast) { - return ConvertToC64(src_literal); + if (bitcast) { + break; } - break; + return ConvertIfTypesMatch(src_literal, false); + case C128: + if (bitcast) { + break; + } + return ConvertIfTypesMatch(src_literal, false); // Other types are not yet supported. default: break; @@ -1485,6 +1511,8 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case C64: return EqualElementsInternal(other, &multi_index); + case C128: + return EqualElementsInternal(other, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); @@ -1628,6 +1656,9 @@ bool LiteralBase::IsAllComplex(complex64 value) const { case C64: return AllElementsEqualValue(root_piece().data(), value); + case C128: + return AllElementsEqualValue(root_piece().data(), + value); default: return false; } @@ -1707,6 +1738,11 @@ bool LiteralBase::IsAllFirst() const { auto data = piece.data(); return AllElementsEqualValue(data, data[0]); } + + case C128: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } default: return false; } @@ -1756,6 +1792,8 @@ bool LiteralBase::IsR1Iota() const { return Get({idx}) == static_cast(idx); case C64: return Get({idx}) == complex64(idx, 0.0f); + case C128: + return Get({idx}) == complex128(idx, 0.0f); case PRED: return Get({idx}) == idx; // token, opaque, tuple, etc. are all not iota. @@ -1799,6 +1837,8 @@ bool LiteralBase::IsZero(absl::Span indices) const { return Get(indices) == 0.0; case C64: return Get(indices) == complex64(0.0f, 0.0f); + case C128: + return Get(indices) == complex128(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case BF16: @@ -1886,6 +1926,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { proto->add_c64s(value.imag()); } break; + case C128: + for (complex128 value : data()) { + proto->add_c128s(value.real()); + proto->add_c128s(value.imag()); + } + break; case TUPLE: case TOKEN: // Nothing to do but assign the shape which is done above. 
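// A minimal standalone sketch, not the literal.cc implementation: the C64 and
// C128 cases added to ConvertIfDestTypeMatches above retire the old
// ConvertToC64 helper and route real-valued sources through
// ConvertIfTypesMatch, whose converter zero-fills the imaginary component.
// ConvertRealToComplex128 and the std::vector container are illustrative only.
#include <complex>
#include <vector>

using complex128 = std::complex<double>;  // mirrors the alias added in xla/types.h

template <typename NativeSrcT>
std::vector<complex128> ConvertRealToComplex128(const std::vector<NativeSrcT>& src) {
  std::vector<complex128> dest;
  dest.reserve(src.size());
  for (NativeSrcT v : src) {
    // Real part carries the source value; imaginary part is zero, so an s32
    // literal {1, 2} becomes the c128 literal {(1, 0), (2, 0)}.
    dest.push_back(complex128(static_cast<double>(v), 0.0));
  }
  return dest;
}
// The reverse direction is deliberately absent: complex-to-real conversions
// such as c128->f32 and c128->s32 remain UNIMPLEMENTED, as the literal_test.cc
// expectations below verify.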
@@ -2018,7 +2064,17 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { for (int64 i = 0; i < complex_data.size(); ++i) { complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; } - } break; + break; + } + case C128: { + auto complex_data = data(); + TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = + complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)}; + } + break; + } case TUPLE: return InvalidArgument("Should not be called on tuple shapes: %s", ShapeUtil::HumanString(subshape())); diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 258bc966b1a..69efa06d39a 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -90,6 +90,12 @@ bool CompareEqual(complex64 lhs, complex64 rhs, return CompareEqual(lhs.real(), rhs.real(), multi_index) && CompareEqual(lhs.imag(), rhs.imag(), multi_index); } +template <> +bool CompareEqual(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} template Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, @@ -143,6 +149,14 @@ Status MakeErrorStatus(complex64 lhs, complex64 rhs, } return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } +template <> +Status MakeErrorStatus(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); + } + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); +} // A recursive function which iterates through every index of expected and // actual literal and compares their values elementwise. Returns true if all @@ -197,13 +211,6 @@ bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { } } -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); -} - template <> bool NanMismatch(half expected, half actual, bool relaxed_nans) { return NanMismatch(static_cast(expected), @@ -232,6 +239,11 @@ string FpValueToString(complex64 value) { return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } +template <> +string FpValueToString(complex128 value) { + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); +} + // Returns the absolute value of the given floating point value. This function // is used instead of std::abs directly in order to allow type-dependent // implementations for NearComparator. @@ -434,7 +446,7 @@ class NearComparator { mismatches_.data()[linear_index] = true; } - // For complex64 types, we compare real and imaginary parts individually. + // For complex types, we compare real and imaginary parts individually. 
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { bool mismatch = false; CompareValues(expected.real(), actual.real(), linear_index); @@ -457,6 +469,29 @@ class NearComparator { mismatches_.data()[linear_index] = mismatch; } + void CompareValues(complex128 expected, complex128 actual, + int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. @@ -665,6 +700,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { case C64: result = Equal(expected, actual, index, 0); break; + case C128: + result = Equal(expected, actual, index, 0); + break; case TUPLE: { for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { result.Update(EqualHelper(LiteralSlice(expected, {i}), @@ -749,6 +787,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, return NearComparator::Compare( expected, actual, error, detailed_message, miscompare_callback); break; + case C128: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " << PrimitiveType_Name(expected.shape().element_type()) diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index a4088c6d119..e67bb5c32ff 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -118,6 +118,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); + auto c128_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); + EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString()); + auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); @@ -469,6 +472,21 @@ TEST_F(LiteralUtilTest, C64Equality) { EXPECT_NE(vector, vector_reversed); } +TEST_F(LiteralUtilTest, C128Equality) { + // Test equality with tuples. + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. 
+ auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(vector, vector_clone); + + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(vector, vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -836,6 +854,13 @@ TEST_F(LiteralUtilTest, PopulateR1C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateR1C128) { + Literal output(ShapeUtil::MakeShape(C128, {1})); + output.PopulateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); @@ -897,6 +922,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C128) { + Literal output(ShapeUtil::MakeShape(C128, {2, 2})); + output.PopulateWithValue({4, 2}); + auto expected = + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); @@ -1308,7 +1341,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - // clang-format on + auto c128 = LiteralUtil::CreateR4WithLayout({{ + {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, + {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, + {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, + }}, layout_r4_dim0major_); // clang-format on Literal conv; conv = s8.Convert(U16).ConsumeValueOrDie(); @@ -1374,10 +1411,20 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = s32.Convert(U16).ConsumeValueOrDie(); EXPECT_EQ(conv, u16); + conv = s32.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + + conv = f16.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(F32).status().code(), + tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(S32).status().code(), + tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1739,7 +1786,8 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}), + ShapeUtil::MakeShape(C128, {})})); EXPECT_EQ(tuple.Get({}, {0}), 0.0); EXPECT_EQ(tuple.Get({0}, {1}), false); @@ -1747,6 +1795,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); + EXPECT_EQ(tuple.Get({}, {4}), complex128(0.0, 0.0)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1756,6 +1805,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 
56, 127, 255}); auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_c128 = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = @@ -1776,6 +1827,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(tuple, to_from_proto(tuple)); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 5881e43e969..26b029c8d0c 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -130,6 +130,8 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(0); case C64: return LiteralUtil::CreateR0(0); + case C128: + return LiteralUtil::CreateR0(0); case PRED: return LiteralUtil::CreateR0(false); case TUPLE: @@ -165,6 +167,8 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(1); case C64: return LiteralUtil::CreateR0(1); + case C128: + return LiteralUtil::CreateR0(1); case PRED: return LiteralUtil::CreateR0(true); case S16: @@ -201,6 +205,8 @@ Literal ConvertType(LiteralSlice literal) { -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; + case C128: + LOG(FATAL) << "C128 element type has no minimum value"; case PRED: return LiteralUtil::CreateR0(false); case S16: @@ -345,6 +351,10 @@ Literal ConvertType(LiteralSlice literal) { new_literal.Set(to_multi_index, literal.Get(from_multi_index)); break; + case C128: + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); + break; default: LOG(FATAL) << "Unhandled primitive element type: " << PrimitiveType_Name(literal.shape().element_type()); @@ -393,6 +403,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: return LiteralUtil::CreateR0(literal.GetFirstElement()); + + case C128: + return LiteralUtil::CreateR0( + literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 00ad01fc407..04b40efbed5 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -27,7 +27,7 @@ bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } -bool IsComplexType(PrimitiveType type) { return type == C64; } +bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; } bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; @@ -67,6 +67,9 @@ int BitWidth(PrimitiveType type) { case C64: return 64; + case C128: + return 128; + case TUPLE: LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; @@ -82,6 +85,8 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: return F32; + case C128: + return F64; default: LOG(FATAL) << "Primitive type is not complex: " << PrimitiveType_Name(complex_type); diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 
70603b6fed1..739a7c409ce 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -126,6 +126,11 @@ inline PrimitiveType NativeToPrimitiveType() { return C64; } +template <> +inline PrimitiveType NativeToPrimitiveType() { + return C128; +} + bool IsFloatingPointType(PrimitiveType type); bool IsComplexType(PrimitiveType type); @@ -225,6 +230,11 @@ struct PrimitiveTypeToNative { using type = complex64; }; +template <> +struct PrimitiveTypeToNative { + using type = complex128; +}; + // Returns the lower-case name of the given primitive type. const string& LowercasePrimitiveTypeName(PrimitiveType s); diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 9c905508255..52c5c621f72 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -54,6 +54,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT64; case C64: return NPY_COMPLEX64; + case C128: + return NPY_COMPLEX128; case TUPLE: return NPY_OBJECT; default: @@ -89,6 +91,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F64; case NPY_COMPLEX64: return C64; + case NPY_COMPLEX128: + return C128; case NPY_OBJECT: return TUPLE; default: @@ -111,6 +115,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT32: case NPY_FLOAT64: case NPY_COMPLEX64: + case NPY_COMPLEX128: case NPY_OBJECT: return true; default: @@ -430,6 +435,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_COMPLEX64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX128: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -470,6 +478,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_COMPLEX64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX128: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 4e71121c097..1684cb20e6d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -199,6 +199,7 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { xla_data_pb2.F32: np.dtype('float32'), xla_data_pb2.F64: np.dtype('float64'), xla_data_pb2.C64: np.dtype('complex64'), + xla_data_pb2.C128: np.dtype('complex128'), xla_data_pb2.TUPLE: np.dtype(np.object), } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ff14c500e31..b324a810ce3 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -223,6 +223,7 @@ cc_library( "hlo_evaluator_typed_visitor.h", "hlo_evaluator_typed_visitor_bfloat16.cc", "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex128.cc", "hlo_evaluator_typed_visitor_complex64.cc", "hlo_evaluator_typed_visitor_double.cc", "hlo_evaluator_typed_visitor_float.cc", @@ -259,6 +260,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", diff --git 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 5ac746c9f3f..4d6042b0d48 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -877,6 +877,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { case C64: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; + case C128: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; default: return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index dba28aa51a2..48510181bd0 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -856,7 +856,8 @@ Status EmitNonBatchDotOperation( const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type || + C128 == type); DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), target_array, lhs_array, rhs_array, addend_array, executable_run_options_value, b, hlo_module_config, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index b27a2325578..f8a997045a6 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -239,10 +239,12 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); DCHECK_GE(byte_size, 0); - // Largest scalar is a complex64 so we don't need to worry about the + // Largest scalar is a complex128 so we don't need to worry about the // int64->int truncation here. - DCHECK_LE(byte_size, 8); - return byte_size; + DCHECK_LE(byte_size, 16); + + // Allocations may be 8-byte aligned if part of a small block. + return std::min(8LL, byte_size); } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -942,7 +944,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64, C64})); + /*supported_types=*/{F16, F32, F64, C64, C128})); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_contracting_dimensions_size() != 1) { @@ -1114,7 +1116,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, C64})); + /*supported_types=*/{F16, F32, C64, C128})); // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. 
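// A minimal standalone sketch of the dispatch pattern gemm_thunk.cc extends
// below: a PrimitiveType switch picks a kernel instantiation whose element
// type matches, with C128 lowering to std::complex<double> and BLAS
// computation type kComplexF64, parallel to the existing C64 /
// std::complex<float> / kComplexF32 path. DoGemm and DispatchGemm are
// illustrative stand-ins, not the real thunk helpers or their signatures.
#include <complex>
#include <cstdio>

namespace gemm_sketch {

enum PrimitiveType { F16, F32, F64, C64, C128 };

template <typename ElementType>
void DoGemm() {
  // Stand-in body: prints the element width; the real helper would run a
  // BLAS GEMM with this element type.
  std::printf("gemm with %zu-byte elements\n", sizeof(ElementType));
}

void DispatchGemm(PrimitiveType type) {
  switch (type) {
    case F32: return DoGemm<float>();
    case F64: return DoGemm<double>();
    case C64: return DoGemm<std::complex<float>>();    // kComplexF32
    case C128: return DoGemm<std::complex<double>>();  // kComplexF64
    default: std::printf("Unsupported type.\n");  // real code uses LOG(FATAL)
  }
}

}  // namespace gemm_sketch
// DispatchGemm(C128) selects the std::complex<double> instantiation, which is
// what the new C128 cases in GetGemmFn, GetGemmWithAlgorithmFn, and
// GetGemmAutotuneFn below do with function pointers.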
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index b8fbe7d2bcb..86c9bc6a345 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -206,6 +206,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case C64: return &DoGemm>; + case C128: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -221,6 +223,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case C64: return &DoGemmWithAlgorithm>; + case C128: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -235,6 +239,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case C64: return &DoGemmAutotune>; + case C128: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -255,6 +261,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF64; case C64: return se::blas::ComputationType::kComplexF32; + case C128: + return se::blas::ComputationType::kComplexF64; default: LOG(FATAL) << "Unsupported type."; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index caccb188997..82bdd677d96 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -54,7 +54,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64 || output_primitive_type == C64); + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128); return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && IsRank2(rhs_shape, batch_dimensions_size) && IsRank2(output_shape, batch_dimensions_size) && diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 407615f8429..927e751d38d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -136,6 +136,37 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, return std::move(result); } +template <> +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el != rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + + return std::move(result); +} + } // namespace // Note that unsupported types by the typed visitor does not necessarily imply @@ -170,6 +201,8 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) absl::make_unique>(this); typed_visitors_[C64] = absl::make_unique>(this); + typed_visitors_[C128] = + absl::make_unique>(this); // Most of the evaluator computations we 
use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all @@ -500,6 +533,13 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); break; } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex128 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } case F16: { auto result_or = ElementWiseUnaryOpImpl( real, [](Eigen::half elem_operand) { return elem_operand; }, @@ -530,11 +570,29 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { } Status HloEvaluator::HandleImag(HloInstruction* imag) { - auto result_or = ElementWiseUnaryOpImpl( - imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, - GetEvaluatedLiteralFor(imag->operand(0))); + auto operand = imag->operand(0); + switch (operand->shape().element_type()) { + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex128 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } - TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); return Status::OK(); } @@ -544,11 +602,27 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) { TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape())); Literal result(complex->shape()); - TF_RETURN_IF_ERROR( - result.Populate([&](absl::Span multi_index) { - return std::complex(real.Get(multi_index), - imag.Get(multi_index)); - })); + switch (complex->shape().element_type()) { + case C64: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + case C128: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + default: + LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: " + << PrimitiveType_Name(complex->shape().element_type()); + } evaluated_[complex] = std::move(result); return Status::OK(); @@ -647,6 +721,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; + case C128: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; default: LOG(FATAL) << "HandleCompare: unknown primitive type: " << PrimitiveType_Name(lhs->shape().element_type()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 698b1773104..f3f400e46f3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -39,9 +40,8 @@ namespace xla { // Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is // a "private" header that's not exposed outside of hlo_evaluator.cc. template -using is_complex_t = std::is_same; -template -using is_complex64_t = std::is_same; +using is_complex_t = + absl::disjunction, std::is_same>; // It's UB to use std::sort with std::less, because of NaNs. Define // "safe" less functions which are actually strict weak orders. -NaN and NaN @@ -212,7 +212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(abs->operand(0)); @@ -231,6 +231,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // specifying the ElementwiseT explicitly as C64 is needed below. if (abs->operand(0)->shape().element_type() == C64) { return HandleAbs(abs); + } else if (abs->operand(0)->shape().element_type() == C128) { + return HandleAbs(abs); } return HandleAbs(abs); } @@ -1616,6 +1618,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case C128: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } default: LOG(FATAL) << "HandleMap: unhandled primitive type for " "input operand: " @@ -3040,6 +3046,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc new file mode 100644 index 00000000000..1f48140ee4f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 76b8a5bc117..74ef8a2fbaf 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -551,6 +551,17 @@ ENTRY %Transpose.v2 () -> s32[1,2,3] { ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } +)" +}, +{ +"TransposeC128", +R"(HloModule TransposeC128_module + +ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] { + %input = c128[1,2,3]{2,1,0} parameter(0) + ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2} +} + )" }, // Dynamic slice diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 83af5d25432..807296329c0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -188,7 +188,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } return cplx_t; } - // A Tuple contains an array of pointers. Use i8*. + case C128: { + auto cplx_t = module->getTypeByName("complex128"); + if (cplx_t == nullptr) { + return llvm::StructType::create( + {llvm::Type::getDoubleTy(module->getContext()), + llvm::Type::getDoubleTy(module->getContext())}, + "complex128", /*isPacked=*/true); + } + return cplx_t; + } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 1d3f84af955..c2e014ad423 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -906,6 +906,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); + } else if (lhs.element_type() == F64 && rhs.element_type() == F64) { + return ShapeUtil::ChangeElementType(shape, C128); } else { return Unimplemented("Complex component type is not implemented."); } @@ -1733,7 +1735,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s.", + return InvalidArgument("%s requires complex input type, found %s.", FftType_Name(fft_type), PrimitiveType_Name(in.element_type())); } diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index bb00bfd2eee..26120a06b82 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. 
const Shape s32_ = ShapeUtil::MakeShape(S32, {}); + const Shape f16_ = ShapeUtil::MakeShape(F16, {}); const Shape f32_ = ShapeUtil::MakeShape(F32, {}); const Shape f64_ = ShapeUtil::MakeShape(F64, {}); const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); @@ -260,8 +261,8 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); // Component types must match. ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); - // Only F32->C64 supported. - ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); + // Only F32->C64 and F64->C128 supported. + ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok()); // Validate correct uses. Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); @@ -285,6 +286,9 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {})); ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {}))); } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 235b065585c..9821790cb34 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -378,6 +378,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case U32: case U64: case C64: + case C128: case TUPLE: case OPAQUE: case TOKEN: @@ -639,6 +640,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return sizeof(double); case C64: return sizeof(complex64); + case C128: + return sizeof(complex128); case TOKEN: // Tokens require no space. 
return 0; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 65a23dd8835..3f65ed7fce4 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -431,7 +431,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -455,7 +456,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -480,7 +482,8 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); @@ -506,7 +509,8 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); @@ -532,7 +536,8 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index b645acb700b..daf678f6901 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -41,6 +41,7 @@ using ::tensorflow::uint32; using ::tensorflow::uint64; using complex64 = std::complex; +using complex128 = std::complex; using ::Eigen::half; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index d029669cbea..a64e2f5df5c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -56,6 +56,7 @@ enum PrimitiveType { // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a @@ -75,7 +76,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 18 + // Next = 19 } // Describes the padding configuration for Pad operation. The padding amount on @@ -367,6 +368,7 @@ message LiteralProto { repeated float f32s = 8; repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. 
repeated LiteralProto tuple_literals = 10; // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; @@ -374,7 +376,7 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 18 + // Next = 19 } message WindowDimension {
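// A minimal standalone sketch of the interleaved layout behind the new
// `repeated double c128s` field above: element i is stored as the pair
// (c128s[2*i], c128s[2*i + 1]) = (real, imag), the c64s scheme widened to
// doubles, matching what Piece::WriteToProto and Piece::CopyFromProto do in
// literal.cc. The std::vector below stands in for the proto's repeated field;
// EncodeC128s and DecodeC128s are illustrative names.
#include <complex>
#include <cstddef>
#include <vector>

using complex128 = std::complex<double>;

std::vector<double> EncodeC128s(const std::vector<complex128>& values) {
  std::vector<double> c128s;
  c128s.reserve(values.size() * 2);
  for (const complex128& v : values) {
    c128s.push_back(v.real());  // even index: real part
    c128s.push_back(v.imag());  // odd index: imaginary part
  }
  return c128s;
}

std::vector<complex128> DecodeC128s(const std::vector<double>& c128s) {
  // Mirrors the TF_RET_CHECK in CopyFromProto: a well-formed field holds
  // exactly two doubles per element.
  std::vector<complex128> values;
  values.reserve(c128s.size() / 2);
  for (std::size_t i = 0; i + 1 < c128s.size(); i += 2) {
    values.push_back(complex128(c128s[i], c128s[i + 1]));
  }
  return values;
}
// Round trip: {(1,2), (3,4)} encodes to [1, 2, 3, 4] and decodes back
// losslessly, which is what the ProtoRoundTrip expectation for vector_c128
// in literal_test.cc exercises.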