[XLA] Adds a C64 type to XLA, with actual compilation support coming soon.

PiperOrigin-RevId: 173172916
A. Unique TensorFlower 2017-10-23 14:42:57 -07:00 committed by TensorFlower Gardener
parent 4f127e9019
commit f226eb3717
14 changed files with 317 additions and 11 deletions
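
Editor's illustration (not part of the commit): a minimal sketch of how the complex64 literal support introduced by this change could be exercised. The header paths, the demo function, and the LOG/CHECK calls are assumptions; the Literal APIs used (CreateR0/CreateR1, Get, ToString, IsAllComplex, Convert) are the ones added or extended in the diffs and tests below.

// Hypothetical usage sketch, not part of this commit.
#include "tensorflow/compiler/xla/literal_util.h"  // path assumed
#include "tensorflow/compiler/xla/types.h"         // path assumed

void Complex64LiteralDemo() {
  using xla::complex64;
  using xla::Literal;

  // Rank-0 complex literal; ToString() renders it as "(3.14, 2.78)".
  auto scalar = Literal::CreateR0<complex64>({3.14f, 2.78f});
  LOG(INFO) << scalar->ToString();

  // Rank-1 literal holding two complex values.
  auto vec = Literal::CreateR1<complex64>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  complex64 first = vec->Get<complex64>({0});
  CHECK(first == complex64(1.0f, 2.0f));

  // IsAllComplex is true only if every element equals the given value.
  CHECK(!vec->IsAllComplex({1.0f, 2.0f}));

  // Existing numeric literals convert to C64 with a zero imaginary part.
  auto ints = Literal::CreateR1<xla::int32>({1, 2, 3});
  auto as_c64 = ints->Convert(xla::C64).ConsumeValueOrDie();
  CHECK(as_c64->Get<complex64>({0}) == complex64(1.0f, 0.0f));
}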

View File

@ -58,6 +58,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_DOUBLE:
*type = xla::F64;
return Status::OK();
case tensorflow::DT_COMPLEX64:
*type = xla::C64;
return Status::OK();
case tensorflow::DT_QUINT8:
*type = xla::U8;
return Status::OK();

View File

@ -173,6 +173,8 @@ Status Literal::Copy(const Literal& src_literal,
return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
case F64:
return CopyRange<double>(src_literal, src_base, dest_base, copy_size);
case C64:
return CopyRange<complex64>(src_literal, src_base, dest_base, copy_size);
case PRED:
return CopyRange<bool>(src_literal, src_base, dest_base, copy_size);
default:
@ -522,6 +524,10 @@ string Literal::GetAsString(
return tensorflow::strings::StrCat(Get<float>(multi_index));
case F64:
return tensorflow::strings::StrCat(Get<double>(multi_index));
case C64: {
complex64 c = Get<complex64>(multi_index);
return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")");
}
case F16:
return tensorflow::strings::StrCat(Get<half>(multi_index));
default:
@ -716,6 +722,8 @@ void* Literal::MutableInternalData() {
return reinterpret_cast<void*>(f32s_.data());
case F64:
return reinterpret_cast<void*>(f64s_.data());
case C64:
return reinterpret_cast<void*>(c64s_.data());
case F16:
return reinterpret_cast<void*>(f16s_.data());
default:
@ -754,6 +762,9 @@ void Literal::Reserve(int64 num_elements) {
case F64:
Resize<double>(num_elements, 0);
break;
case C64:
Resize<complex64>(num_elements, 0);
break;
case F16:
Resize<half>(num_elements, static_cast<half>(0.0f));
break;
@ -790,6 +801,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
case F64:
actual = f64s_size();
break;
case C64:
actual = c64s_size();
break;
case F16:
actual = f16s().size() / sizeof(half);
break;
@ -843,6 +857,26 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
return result_literal;
}
template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
auto result_literal = MakeUnique<Literal>();
Shape* result_shape = result_literal->mutable_shape();
*result_shape = src_literal.shape();
result_shape->set_element_type(C64);
result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
src_literal.GetArraySlice<NativeSrcT>();
tensorflow::gtl::MutableArraySlice<complex64> dest_data =
result_literal->GetMutableArraySlice<complex64>();
int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
}
return result_literal;
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
@ -870,6 +904,8 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
case C64:
return ConvertToC64<primitive_src_type>(src_literal);
// Other types are not yet supported.
default:
return InvalidArgument(
@ -966,6 +1002,8 @@ bool Literal::operator==(const Literal& other) const {
return EqualElements<double>(*this, other, 0, &multi_index);
case F16:
return EqualElements<half>(*this, other, 0, &multi_index);
case C64:
return EqualElements<complex64>(*this, other, 0, &multi_index);
default:
LOG(FATAL) << "Unimplemented: Literal::Equal for type "
<< PrimitiveType_Name(shape().element_type());
@ -1065,6 +1103,12 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice() {
values->size());
}
template <>
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
auto values = mutable_c64s();
return {values->data(), values->size()};
}
template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
// TODO - there is an endianness problem here. fix it, or wait for uint16
@ -1144,6 +1188,13 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
f16s().size() / sizeof(half));
}
template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const {
CHECK_EQ(shape().element_type(), C64);
return c64s();
}
template <typename NativeT>
static bool AllElementsEqualValue(const Literal& literal, NativeT value) {
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
@ -1211,6 +1262,15 @@ bool Literal::IsAllFloat(float value) const {
}
}
bool Literal::IsAllComplex(complex64 value) const {
switch (shape().element_type()) {
case C64:
return AllElementsEqualValue<complex64>(*this, value);
default:
return false;
}
}
bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
switch (shape().element_type()) {
case U8:
@ -1229,6 +1289,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
return Get<float>(indices) == 0.0f;
case F64:
return Get<double>(indices) == 0.0;
case C64:
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case F16:
return Get<half>(indices) == static_cast<half>(0.0f);
case PRED:
@ -1298,12 +1360,27 @@ void Literal::Resize<half>(int64 num_elements, half value) {
mutable_f16s()->resize(num_elements, value);
}
template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
mutable_c64s()->resize(num_elements, value);
}
template <typename RepeatedFieldT, typename NativeT>
-static void CopyToRepeatedField(RepeatedFieldT* dest,
-                                const std::vector<NativeT>& src) {
+void CopyToRepeatedField(RepeatedFieldT* dest,
+                         const std::vector<NativeT>& src) {
*dest = RepeatedFieldT(src.begin(), src.end());
}
template <>
void CopyToRepeatedField<tensorflow::protobuf::RepeatedField<float>, complex64>(
tensorflow::protobuf::RepeatedField<float>* dest,
const std::vector<complex64>& src) {
*dest = tensorflow::protobuf::RepeatedField<float>(
reinterpret_cast<const float*>(src.data()),
reinterpret_cast<const float*>(src.data()) + src.size() * 2);
}
LiteralProto Literal::ToProto() const {
LiteralProto proto;
proto.Clear();
@ -1338,6 +1415,9 @@ LiteralProto Literal::ToProto() const {
case F64:
CopyToRepeatedField(proto.mutable_f64s(), f64s());
break;
case C64:
CopyToRepeatedField(proto.mutable_c64s(), c64s());
break;
case TUPLE:
for (const auto& tuple : tuple_literals()) {
*proto.add_tuple_literals() = tuple.ToProto();
@ -1351,11 +1431,21 @@ LiteralProto Literal::ToProto() const {
}
template <typename RepeatedFieldT, typename NativeT>
-static void CopyFromRepeatedField(std::vector<NativeT>* dest,
-                                  const RepeatedFieldT& src) {
+void CopyFromRepeatedField(std::vector<NativeT>* dest,
+                           const RepeatedFieldT& src) {
*dest = std::vector<NativeT>(src.begin(), src.end());
}
template <>
void CopyFromRepeatedField<tensorflow::protobuf::RepeatedField<float>,
complex64>(
std::vector<complex64>* dest,
const tensorflow::protobuf::RepeatedField<float>& src) {
*dest = std::vector<complex64>(
reinterpret_cast<const complex64*>(src.data()),
reinterpret_cast<const complex64*>(src.data()) + src.size() / 2);
}
void Literal::CopyFromProto(const LiteralProto& literal_proto) {
if (!literal_proto.has_shape()) {
return;
@ -1394,6 +1484,9 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
case F64:
CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
break;
case C64:
CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s());
break;
case TUPLE:
for (const auto& proto : literal_proto.tuple_literals()) {
mutable_tuple_literals()->push_back(Literal(proto));
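
Aside (not part of the commit): the CopyToRepeatedField / CopyFromRepeatedField specializations above rely on std::complex<float> being laid out as two contiguous floats, so a vector of complex64 can be reinterpreted as a float array of twice the length. A rough round-trip illustration with hypothetical values:

// Illustration only: {(1,2), (3,4)} lands in the proto's c64s field as the
// interleaved floats 1, 2, 3, 4.
auto literal = Literal::CreateR1<complex64>({{1.0f, 2.0f}, {3.0f, 4.0f}});
LiteralProto proto = literal->ToProto();  // proto.c64s(): [1, 2, 3, 4]
Literal restored(proto);                  // CopyFromProto rebuilds the complex pairs
CHECK(restored == *literal);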

View File

@ -159,6 +159,10 @@ class Literal {
const std::vector<double>& f64s() const { return f64s_; }
std::vector<double>* mutable_f64s() { return &f64s_; }
int c64s_size() const { return c64s().size(); }
const std::vector<complex64>& c64s() const { return c64s_; }
std::vector<complex64>* mutable_c64s() { return &c64s_; }
int tuple_literals_size() const { return tuple_literals().size(); }
const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
Literal* add_tuple_literals() {
@ -560,6 +564,17 @@ class Literal {
// e.g. -0.5.
bool IsAllFloat(float value) const;
// Like IsAll(const Literal&, int8), except we check whether the literal is
// equal to a particular complex number.
//
// If the literal is not a complex value, this always returns false.
//
// This casts value to the type of literal, then compares using ==. The usual
// admonishments about floating-point equality checks apply. We expect you to
// use this to check for complex values that can be expressed precisely as
// float pairs e.g. (-0.5, 1.0).
bool IsAllComplex(complex64 value) const;
// Returns whether this literal is zero at the specified index. This literal
// must be an array.
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
@ -610,6 +625,7 @@ class Literal {
std::vector<half> f16s_;
std::vector<float> f32s_;
std::vector<double> f64s_;
std::vector<complex64> c64s_;
std::vector<Literal> tuple_literals_;
};
@ -658,6 +674,10 @@ tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
template <>
tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
template <>
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
const;
template <>
tensorflow::gtl::MutableArraySlice<bool> Literal::GetMutableArraySlice();
@ -694,6 +714,9 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
template <>
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
template <>
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
template <>
void Literal::Resize<bool>(int64 num_elements, bool value);
@ -724,6 +747,9 @@ void Literal::Resize<double>(int64 num_elements, double value);
template <>
void Literal::Resize<half>(int64 num_elements, half value);
template <>
void Literal::Resize<complex64>(int64 num_elements, complex64 value);
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
auto literal = MakeUnique<Literal>();

View File

@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f));
ASSERT_EQ("0.5", f16_lit->ToString());
auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) {
EXPECT_NE(*tuple1, *different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
// Test equality of complex literals.
auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
// A second vector with the same elements.
auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
EXPECT_EQ(*vector, *vector_clone);
auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
EXPECT_NE(*vector, *vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = Literal::CreateR0<float>(0.0);
auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
complex64 c8_9 = {8, 9};
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(Literal::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) {
Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0));
EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0));
EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0));
EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0));
EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0));
EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})
->IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})
->IsAllComplex({8.0f, 9.0f}));
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}})
->IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = Literal::CreateR0<float>(0.0f);
auto scalar_one = Literal::CreateR0<float>(1.0f);
@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) {
EXPECT_TRUE(array->IsZero({0, 2}));
EXPECT_TRUE(array->IsZero({1, 1}));
EXPECT_FALSE(array->IsZero({1, 2}));
auto complex_zero = Literal::CreateR0<complex64>(0.0f);
auto complex_nonzero = Literal::CreateR0<complex64>(0.5f);
EXPECT_TRUE(complex_zero->IsZero({}));
EXPECT_FALSE(complex_nonzero->IsZero({}));
}
template <typename T>
class LiteralUtilTestTemplated : public ::testing::Test {};
-using TestedTypes = ::testing::Types<float, int32, uint32>;
+using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) {
EXPECT_EQ(output, *expected);
}
-TEST_F(LiteralUtilTest, PopulateR2U64) {
+TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output;
output.PopulateR1<uint64>({{77, 88}});
auto expected = Literal::CreateR1<uint64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output;
output.PopulateR1<complex64>({{77, 88}});
auto expected = Literal::CreateR1<complex64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
Literal output;
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output;
output.PopulateWithValue<float>(2.5f, {});
@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
Literal output;
output.PopulateWithValue<complex64>({4, 2}, {2, 2});
auto expected =
Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
Literal output;
half h(0.25f);
@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
}}, layout_r4_dim0major_);
auto c64 = Literal::CreateR4WithLayout<complex64>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
std::unique_ptr<Literal> conv;
@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = u32->Convert(F16).ConsumeValueOrDie();
EXPECT_EQ(*conv, *f16);
conv = s32->Convert(C64).ConsumeValueOrDie();
EXPECT_EQ(*conv, *c64);
conv = f16->Convert(C64).ConsumeValueOrDie();
EXPECT_EQ(*conv, *c64);
EXPECT_EQ(s32->Convert(TUPLE).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(s32->Convert(S16).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(s32->Convert(U16).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(c64->Convert(F32).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(c64->Convert(S32).status().code(),
tensorflow::error::INVALID_ARGUMENT);
}
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {

View File

@ -83,10 +83,17 @@ PrimitiveType NativeToPrimitiveType<half>() {
return F16;
}
template <>
PrimitiveType NativeToPrimitiveType<complex64>() {
return C64;
}
bool IsFloatingPointType(PrimitiveType type) {
return type == F16 || type == F32 || type == F64;
}
bool IsComplexType(PrimitiveType type) { return type == C64; }
bool IsSignedIntegralType(PrimitiveType type) {
return type == S8 || type == S16 || type == S32 || type == S64;
}
@ -121,6 +128,7 @@ int BitWidth(PrimitiveType type) {
case U64:
case S64:
case F64:
case C64:
return 64;
case TUPLE:
@ -134,5 +142,15 @@ int BitWidth(PrimitiveType type) {
}
}
PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
switch (complex_type) {
case C64:
return F32;
default:
LOG(FATAL) << "Primitive type is not complex: "
<< PrimitiveType_Name(complex_type);
}
}
} // namespace primitive_util
} // namespace xla

View File

@ -78,8 +78,14 @@ PrimitiveType NativeToPrimitiveType<double>();
template <>
PrimitiveType NativeToPrimitiveType<half>();
// Complex
template <>
PrimitiveType NativeToPrimitiveType<complex64>();
bool IsFloatingPointType(PrimitiveType type);
bool IsComplexType(PrimitiveType type);
bool IsSignedIntegralType(PrimitiveType type);
bool IsUnsignedIntegralType(PrimitiveType type);
@ -89,6 +95,10 @@ bool IsIntegralType(PrimitiveType type);
// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);
// Returns the real, imag component type underlying the given complex type.
// LOG(FATAL)'s if complex_type is not complex.
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
// Returns the native type (eg, float) corresponding to the given template
// parameter XLA primitive type (eg, F32).
template <PrimitiveType>
@ -157,6 +167,11 @@ struct PrimitiveTypeToNative<F16> {
using type = half;
};
// Complex
template <>
struct PrimitiveTypeToNative<C64> {
using type = complex64;
};
} // namespace primitive_util
} // namespace xla
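
A few hypothetical sanity checks (illustration only, not part of the commit) showing how the primitive_util additions above fit together:

#include <type_traits>  // for the static_assert below

// Illustration only: the trait, query, and bit-width plumbing for C64.
static_assert(
    std::is_same<xla::primitive_util::PrimitiveTypeToNative<xla::C64>::type,
                 xla::complex64>::value,
    "C64 maps to complex64");

void CheckC64PrimitiveUtil() {
  CHECK_EQ(xla::primitive_util::NativeToPrimitiveType<xla::complex64>(),
           xla::C64);
  CHECK(xla::primitive_util::IsComplexType(xla::C64));
  CHECK(!xla::primitive_util::IsFloatingPointType(xla::C64));
  CHECK_EQ(xla::primitive_util::ComplexComponentType(xla::C64), xla::F32);
  CHECK_EQ(xla::primitive_util::BitWidth(xla::C64), 64);
}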

View File

@ -1265,6 +1265,9 @@ HloEvaluator::HloEvaluator() {
});
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[C64] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: C64.");
});
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: TUPLE.");
});

View File

@ -281,6 +281,10 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
}
}
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
return primitive_util::IsComplexType(shape.element_type());
}
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
return primitive_util::IsFloatingPointType(shape.element_type());
}
@ -592,6 +596,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return sizeof(float);
case F64:
return sizeof(double);
case C64:
return sizeof(complex64);
default:
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
}

View File

@ -291,6 +291,9 @@ class ShapeUtil {
// Returns whether the element type of the shape is floating point.
static bool ElementIsFloating(const Shape& shape);
// Returns whether the element type of the shape is complex.
static bool ElementIsComplex(const Shape& shape);
// Returns whether the element type has the given bit width.
static bool ElementHasBitWidth(const Shape& shape, int bits);

View File

@ -218,6 +218,10 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64));
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {})));
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20})));
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
}
TEST(ShapeUtilTest, ByteSizeOfWithPadding) {

View File

@ -254,7 +254,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout) {
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
-if (ShapeUtil::ElementIsFloating(expected.shape())) {
+if (ShapeUtil::ElementIsFloating(expected.shape()) ||
+    ShapeUtil::ElementIsComplex(expected.shape())) {
LOG(WARNING) << "performing exact comparison of floating point numbers";
} else {
TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
@ -282,7 +283,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
ComputationBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout) {
-TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()));
+TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
+             ShapeUtil::ElementIsComplex(expected.shape()));
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
auto expect_near = [&](const Literal& actual, const string& error_message) {
LiteralTestUtil::ExpectNear(expected, actual, error, error_message);

View File

@ -156,6 +156,15 @@ template <>
::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
}
template <>
::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
complex64 rhs) {
auto res = CompareEqual<float>(lhs.real(), rhs.real());
if (!res) {
return res;
}
return CompareEqual<float>(lhs.imag(), rhs.imag());
}
// A recursive function which iterates through every index of expected and
// actual literal and compares their values elementwise. Returns true if all
@ -235,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
case F64:
match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
break;
case C64:
match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
break;
case TUPLE: {
bool tuple_match = true;
for (int i = 0; i < actual.tuple_literals_size(); ++i) {
@ -325,6 +337,9 @@ class NearComparator {
case F64:
ExpectLiteralsNear<double>(expected, actual, 0);
break;
case C64:
ExpectLiteralsNear<complex64>(expected, actual, 0);
break;
default:
LOG(FATAL) << "Unsupported primitive type in near comparator: "
<< PrimitiveType_Name(expected.shape().element_type())
@ -365,6 +380,19 @@ class NearComparator {
}
private:
template <typename NativeT>
bool NanMismatch(NativeT lhs, NativeT rhs) {
return std::isnan(lhs) != std::isnan(rhs);
}
template <typename NativeT>
void ExpectNear(NativeT expected, NativeT actual,
const ::testing::Message& message) {
EXPECT_NEAR(expected, actual, error_.abs)
<< "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n"
<< message;
}
// EXPECTs that the two given scalar values are within the error bound. Keeps
// track of how many mismatches have occurred to keep the size of the output
// manageable.
@ -390,7 +418,7 @@ class NearComparator {
"index %s abs_diff %f rel_err %f",
LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
rel_err);
-bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
+bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
bool mismatch =
(nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
if (mismatch) {
@ -398,11 +426,12 @@ class NearComparator {
abs_expected_miscompare_sum_ += std::abs(expected);
const int64 kMaxFailures = 2;
if (num_miscompares_ < kMaxFailures) {
-EXPECT_NEAR(expected, actual, error_.abs)
-    << "mismatch at index "
+::testing::Message msg;
+msg << "mismatch at index "
<< LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
<< abs_diff << " rel err " << rel_err << " failure #"
<< num_miscompares_;
+ExpectNear<NativeT>(expected, actual, msg);
} else if (num_miscompares_ == kMaxFailures) {
LOG(ERROR)
<< "reached max 'loud' failure count; silently proceeding...";
@ -470,6 +499,23 @@ class NearComparator {
std::vector<int64> max_abs_multi_index_;
};
template <>
bool NearComparator::NanMismatch<complex64>(complex64 lhs, complex64 rhs) {
return std::isnan(lhs.real()) != std::isnan(rhs.real()) ||
std::isnan(lhs.imag()) != std::isnan(rhs.imag());
}
template <>
void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
const ::testing::Message& message) {
EXPECT_NEAR(expected.real(), actual.real(), error_.abs)
<< "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n"
<< message;
EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs)
<< "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n"
<< message;
}
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
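
Editor's aside (hypothetical, not part of the commit): with the comparator changes above, near-comparison of complex literals flows through the NanMismatch and ExpectNear specializations, so a test can compare complex results the same way it compares floats. A sketch using the four-argument ExpectNear shown earlier and an assumed ErrorSpec of 1e-4:

// Illustration only; each element differs by well under 1e-4, so this passes.
auto expected = Literal::CreateR1<complex64>({{1.0f, 2.0f}, {3.0f, 4.0f}});
auto actual =
    Literal::CreateR1<complex64>({{1.0f + 5e-5f, 2.0f}, {3.0f, 4.0f - 5e-5f}});
LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4),
                            "c64 near check");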

View File

@ -35,6 +35,8 @@ using ::tensorflow::uint16;
using ::tensorflow::uint32;
using ::tensorflow::uint64;
typedef std::complex<float> complex64;
using ::Eigen::half;
} // namespace xla

View File

@ -48,6 +48,9 @@ enum PrimitiveType {
F32 = 11;
F64 = 12;
// Complex values of fixed width.
C64 = 15;
// A tuple is a polymorphic sequence; e.g. a shape that holds different
// sub-shapes. They are used for things like returning multiple values from a
// computation; e.g. a computation that returns weights and biases may have a
@ -305,6 +308,7 @@ message LiteralProto {
repeated uint64 u64s = 7;
repeated float f32s = 8;
repeated double f64s = 9;
repeated float c64s = 12; // Stored as interleaved real, imag floats.
repeated LiteralProto tuple_literals = 10;
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
}