From f226eb3717a0df815579178f4393d4e68cbe08fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2017 14:42:57 -0700 Subject: [PATCH] [XLA] Adds a C64 type to XLA, with actual compilation support coming soon. PiperOrigin-RevId: 173172916 --- tensorflow/compiler/tf2xla/type_util.cc | 3 + tensorflow/compiler/xla/literal_util.cc | 101 +++++++++++++++++- tensorflow/compiler/xla/literal_util.h | 26 +++++ tensorflow/compiler/xla/literal_util_test.cc | 85 ++++++++++++++- tensorflow/compiler/xla/primitive_util.cc | 18 ++++ tensorflow/compiler/xla/primitive_util.h | 15 +++ .../compiler/xla/service/hlo_evaluator.cc | 3 + tensorflow/compiler/xla/shape_util.cc | 6 ++ tensorflow/compiler/xla/shape_util.h | 3 + tensorflow/compiler/xla/shape_util_test.cc | 4 + .../xla/tests/client_library_test_base.cc | 6 +- .../compiler/xla/tests/literal_test_util.cc | 52 ++++++++- tensorflow/compiler/xla/types.h | 2 + tensorflow/compiler/xla/xla_data.proto | 4 + 14 files changed, 317 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c6984887766..1efbe0ffb17 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -58,6 +58,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_DOUBLE: *type = xla::F64; return Status::OK(); + case tensorflow::DT_COMPLEX64: + *type = xla::C64; + return Status::OK(); case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 79e40c12625..413b85e3ba1 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -173,6 +173,8 @@ Status Literal::Copy(const Literal& src_literal, return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: return CopyRange(src_literal, src_base, dest_base, copy_size); + case C64: + return CopyRange(src_literal, src_base, dest_base, copy_size); case PRED: return CopyRange(src_literal, src_base, dest_base, copy_size); default: @@ -522,6 +524,10 @@ string Literal::GetAsString( return tensorflow::strings::StrCat(Get(multi_index)); case F64: return tensorflow::strings::StrCat(Get(multi_index)); + case C64: { + complex64 c = Get(multi_index); + return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")"); + } case F16: return tensorflow::strings::StrCat(Get(multi_index)); default: @@ -716,6 +722,8 @@ void* Literal::MutableInternalData() { return reinterpret_cast(f32s_.data()); case F64: return reinterpret_cast(f64s_.data()); + case C64: + return reinterpret_cast(c64s_.data()); case F16: return reinterpret_cast(f16s_.data()); default: @@ -754,6 +762,9 @@ void Literal::Reserve(int64 num_elements) { case F64: Resize(num_elements, 0); break; + case C64: + Resize(num_elements, 0); + break; case F16: Resize(num_elements, static_cast(0.0f)); break; @@ -790,6 +801,9 @@ tensorflow::Status Literal::ValidateLiteral() const { case F64: actual = f64s_size(); break; + case C64: + actual = c64s_size(); + break; case F16: actual = f16s().size() / sizeof(half); break; @@ -843,6 +857,26 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { return result_literal; } +template +std::unique_ptr ConvertToC64(const Literal& src_literal) { + auto result_literal = MakeUnique(); + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = src_literal.shape(); + result_shape->set_element_type(C64); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative::type; + tensorflow::gtl::ArraySlice src_data = + src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast(src_data[i]), 0); + } + return result_literal; +} + template std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -870,6 +904,8 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) #undef CONVERT_IF_TYPES_MATCH + case C64: + return ConvertToC64(src_literal); // Other types are not yet supported. default: return InvalidArgument( @@ -966,6 +1002,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case C64: + return EqualElements(*this, other, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: Literal::Equal for type " << PrimitiveType_Name(shape().element_type()); @@ -1065,6 +1103,12 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { values->size()); } +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_c64s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // TODO - there is an endianess problem here. fix it, or wait for uint16 @@ -1144,6 +1188,13 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const { + CHECK_EQ(shape().element_type(), C64); + return c64s(); +} + template static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -1211,6 +1262,15 @@ bool Literal::IsAllFloat(float value) const { } } +bool Literal::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue(*this, value); + default: + return false; + } +} + bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { switch (shape().element_type()) { case U8: @@ -1229,6 +1289,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == 0.0f; case F64: return Get(indices) == 0.0; + case C64: + return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case PRED: @@ -1298,12 +1360,27 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, complex64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_c64s()->resize(num_elements, value); +} + template -static void CopyToRepeatedField(RepeatedFieldT* dest, - const std::vector& src) { +void CopyToRepeatedField(RepeatedFieldT* dest, + const std::vector& src) { *dest = RepeatedFieldT(src.begin(), src.end()); } +template <> +void CopyToRepeatedField, complex64>( + tensorflow::protobuf::RepeatedField* dest, + const std::vector& src) { + *dest = tensorflow::protobuf::RepeatedField( + reinterpret_cast(src.data()), + reinterpret_cast(src.data()) + src.size() * 2); +} + LiteralProto Literal::ToProto() const { LiteralProto proto; proto.Clear(); @@ -1338,6 +1415,9 @@ LiteralProto Literal::ToProto() const { case F64: CopyToRepeatedField(proto.mutable_f64s(), f64s()); break; + case C64: + CopyToRepeatedField(proto.mutable_c64s(), c64s()); + break; case TUPLE: for (const auto& tuple : tuple_literals()) { *proto.add_tuple_literals() = tuple.ToProto(); @@ -1351,11 +1431,21 @@ LiteralProto Literal::ToProto() const { } template -static void CopyFromRepeatedField(std::vector* dest, - const RepeatedFieldT& src) { +void CopyFromRepeatedField(std::vector* dest, + const RepeatedFieldT& src) { *dest = std::vector(src.begin(), src.end()); } +template <> +void CopyFromRepeatedField, + complex64>( + std::vector* dest, + const tensorflow::protobuf::RepeatedField& src) { + *dest = std::vector( + reinterpret_cast(src.data()), + reinterpret_cast(src.data()) + src.size() / 2); +} + void Literal::CopyFromProto(const LiteralProto& literal_proto) { if (!literal_proto.has_shape()) { return; @@ -1394,6 +1484,9 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { case F64: CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); break; + case C64: + CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s()); + break; case TUPLE: for (const auto& proto : literal_proto.tuple_literals()) { mutable_tuple_literals()->push_back(Literal(proto)); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 4063cb05a91..a1e288829f2 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -159,6 +159,10 @@ class Literal { const std::vector& f64s() const { return f64s_; } std::vector* mutable_f64s() { return &f64s_; } + int c64s_size() const { return c64s().size(); } + const std::vector& c64s() const { return c64s_; } + std::vector* mutable_c64s() { return &c64s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -560,6 +564,17 @@ class Literal { // e.g. -0.5. bool IsAllFloat(float value) const; + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + bool IsAllComplex(complex64 value) const; + // Returns whether this literal is zero at the specified index. This literal // must be an array. bool IsZero(tensorflow::gtl::ArraySlice indices) const; @@ -610,6 +625,7 @@ class Literal { std::vector f16s_; std::vector f32s_; std::vector f64s_; + std::vector c64s_; std::vector tuple_literals_; }; @@ -658,6 +674,10 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const; + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -694,6 +714,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> void Literal::Resize(int64 num_elements, bool value); @@ -724,6 +747,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, complex64 value); + template /* static */ std::unique_ptr Literal::CreateR0(NativeT value) { auto literal = MakeUnique(); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index e7dedd08218..a9af4849e21 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f16_lit = Literal::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); + + auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); + ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) { EXPECT_NE(*tuple1, *different_tuple); } +TEST_F(LiteralUtilTest, C64Equality) { + // Test equality with tuples. + auto vector = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto vector_clone = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(*vector, *vector_clone); + + auto vector_reversed = Literal::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(*vector, *vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = Literal::CreateR0(0.0); auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + complex64 c8_9 = {8, 9}; + EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(Literal::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) @@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) { Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } +TEST_F(LiteralUtilTest, IsAllComplex) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR0(false)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c7_9}}) + ->IsAllComplex({8.0f, 9.0f})); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0(0.0f); auto scalar_one = Literal::CreateR0(1.0f); @@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) { EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); + + auto complex_zero = Literal::CreateR0(0.0f); + auto complex_nonzero = Literal::CreateR0(0.5f); + EXPECT_TRUE(complex_zero->IsZero({})); + EXPECT_FALSE(complex_nonzero->IsZero({})); } template class LiteralUtilTestTemplated : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = ::testing::Types; TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { @@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) { EXPECT_EQ(output, *expected); } -TEST_F(LiteralUtilTest, PopulateR2U64) { +TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output; output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateR1C64) { + Literal output; + output.PopulateR1({{77, 88}}); + auto expected = Literal::CreateR1({{77, 88}}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateR2C64) { + Literal output; + output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + auto expected = + Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { + Literal output; + output.PopulateWithValue({4, 2}, {2, 2}); + auto expected = + Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); @@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); + auto c64 = Literal::CreateR4WithLayout({{ + {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, + {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, + {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, + }}, layout_r4_dim0major_); // clang-format on std::unique_ptr conv; @@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = u32->Convert(F16).ConsumeValueOrDie(); EXPECT_EQ(*conv, *f16); + conv = s32->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + + conv = f16->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(F32).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(S32).status().code(), + tensorflow::error::INVALID_ARGUMENT); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e4e37177a2d..2113b5e06f3 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -83,10 +83,17 @@ PrimitiveType NativeToPrimitiveType() { return F16; } +template <> +PrimitiveType NativeToPrimitiveType() { + return C64; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } +bool IsComplexType(PrimitiveType type) { return type == C64; } + bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; } @@ -121,6 +128,7 @@ int BitWidth(PrimitiveType type) { case U64: case S64: case F64: + case C64: return 64; case TUPLE: @@ -134,5 +142,15 @@ int BitWidth(PrimitiveType type) { } } +PrimitiveType ComplexComponentType(PrimitiveType complex_type) { + switch (complex_type) { + case C64: + return F32; + default: + LOG(FATAL) << "Primitive type is not complex: " + << PrimitiveType_Name(complex_type); + } +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 162a11c7d29..a49c8b86fcf 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -78,8 +78,14 @@ PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +// Complex +template <> +PrimitiveType NativeToPrimitiveType(); + bool IsFloatingPointType(PrimitiveType type); +bool IsComplexType(PrimitiveType type); + bool IsSignedIntegralType(PrimitiveType type); bool IsUnsignedIntegralType(PrimitiveType type); @@ -89,6 +95,10 @@ bool IsIntegralType(PrimitiveType type); // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); +// Returns the real, imag component type underlying the given complex type. +// LOG(FATAL)'s if complex_type is not complex. +PrimitiveType ComplexComponentType(PrimitiveType complex_type); + // Returns the native type (eg, float) corresponding to the given template // parameter XLA primitive type (eg, F32). template @@ -157,6 +167,11 @@ struct PrimitiveTypeToNative { using type = half; }; +// Complex +template <> +struct PrimitiveTypeToNative { + using type = complex64; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e8f88427dae..fa6a8f3d53d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1265,6 +1265,9 @@ HloEvaluator::HloEvaluator() { }); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: C64."); + }); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("unhandled primitive type: TUPLE."); }); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index af583bed625..fa4f71414dd 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -281,6 +281,10 @@ StatusOr MakeShapeWithLayoutInternal( } } +/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) { + return primitive_util::IsComplexType(shape.element_type()); +} + /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) { return primitive_util::IsFloatingPointType(shape.element_type()); } @@ -592,6 +596,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(float); case F64: return sizeof(double); + case C64: + return sizeof(complex64); default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index c5800acaf11..8f8d4a73c9e 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -291,6 +291,9 @@ class ShapeUtil { // Returns whether the element type of the shape is floating point. static bool ElementIsFloating(const Shape& shape); + // Returns whether the element type of the shape is complex. + static bool ElementIsComplex(const Shape& shape); + // Returns whether the element type has the given bit width. static bool ElementHasBitWidth(const Shape& shape, int bits); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 79945b9c772..0ba542ad1be 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -218,6 +218,10 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20}))); + + EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); + EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); + EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a60d3e50bd4..065bce7e314 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -254,7 +254,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - if (ShapeUtil::ElementIsFloating(expected.shape())) { + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } else { TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || @@ -282,7 +283,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape())); + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); auto expect_near = [&](const Literal& actual, const string& error_message) { LiteralTestUtil::ExpectNear(expected, actual, error, error_message); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 2876a79dd8b..95a52ecd2f5 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -156,6 +156,15 @@ template <> ::testing::AssertionResult CompareEqual(double lhs, double rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } +template <> +::testing::AssertionResult CompareEqual(complex64 lhs, + complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} // A recursive function which iterates through every index of expected and // actual literal and compares their values elementwise. Returns true if all @@ -235,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case F64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case C64: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case TUPLE: { bool tuple_match = true; for (int i = 0; i < actual.tuple_literals_size(); ++i) { @@ -325,6 +337,9 @@ class NearComparator { case F64: ExpectLiteralsNear(expected, actual, 0); break; + case C64: + ExpectLiteralsNear(expected, actual, 0); + break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " << PrimitiveType_Name(expected.shape().element_type()) @@ -365,6 +380,19 @@ class NearComparator { } private: + template + bool NanMismatch(NativeT lhs, NativeT rhs) { + return std::isnan(lhs) != std::isnan(rhs); + } + + template + void ExpectNear(NativeT expected, NativeT actual, + const ::testing::Message& message) { + EXPECT_NEAR(expected, actual, error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; + } + // EXPECTs that the two given scalar values are within the error bound. Keeps // track of how many mismatches have occurred to keep the size of the output // manageable. @@ -390,7 +418,7 @@ class NearComparator { "index %s abs_diff %f rel_err %f", LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff, rel_err); - bool nan_mismatch = std::isnan(actual) != std::isnan(expected); + bool nan_mismatch = NanMismatch(expected, actual); bool mismatch = (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); if (mismatch) { @@ -398,11 +426,12 @@ class NearComparator { abs_expected_miscompare_sum_ += std::abs(expected); const int64 kMaxFailures = 2; if (num_miscompares_ < kMaxFailures) { - EXPECT_NEAR(expected, actual, error_.abs) - << "mismatch at index " + ::testing::Message msg; + msg << "mismatch at index " << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff " << abs_diff << " rel err " << rel_err << " failure #" << num_miscompares_; + ExpectNear(expected, actual, msg); } else if (num_miscompares_ == kMaxFailures) { LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; @@ -470,6 +499,23 @@ class NearComparator { std::vector max_abs_multi_index_; }; +template <> +bool NearComparator::NanMismatch(complex64 lhs, complex64 rhs) { + return std::isnan(lhs.real()) != std::isnan(rhs.real()) || + std::isnan(lhs.imag()) != std::isnan(rhs.imag()); +} + +template <> +void NearComparator::ExpectNear(complex64 expected, complex64 actual, + const ::testing::Message& message) { + EXPECT_NEAR(expected.real(), actual.real(), error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; + EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index ea8b4b7b989..3d78466107a 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -35,6 +35,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +typedef std::complex complex64; + using ::Eigen::half; } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index eae284afb76..7ad61fab81d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -48,6 +48,9 @@ enum PrimitiveType { F32 = 11; F64 = 12; + // Complex values of fixed width. + C64 = 15; + // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a // computation; e.g. a computation that returns weights and biases may have a @@ -305,6 +308,7 @@ message LiteralProto { repeated uint64 u64s = 7; repeated float f32s = 8; repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; bytes f16s = 11; // Note: the F16s are encoded in little endian byte order }