[XLA] Adds a C64 type to XLA, with actual compilation support coming soon.
PiperOrigin-RevId: 173172916
This commit is contained in:
parent
4f127e9019
commit
f226eb3717
@ -58,6 +58,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
|
||||
case tensorflow::DT_DOUBLE:
|
||||
*type = xla::F64;
|
||||
return Status::OK();
|
||||
case tensorflow::DT_COMPLEX64:
|
||||
*type = xla::C64;
|
||||
return Status::OK();
|
||||
case tensorflow::DT_QUINT8:
|
||||
*type = xla::U8;
|
||||
return Status::OK();
|
||||
|
@ -173,6 +173,8 @@ Status Literal::Copy(const Literal& src_literal,
|
||||
return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
|
||||
case F64:
|
||||
return CopyRange<double>(src_literal, src_base, dest_base, copy_size);
|
||||
case C64:
|
||||
return CopyRange<complex64>(src_literal, src_base, dest_base, copy_size);
|
||||
case PRED:
|
||||
return CopyRange<bool>(src_literal, src_base, dest_base, copy_size);
|
||||
default:
|
||||
@ -522,6 +524,10 @@ string Literal::GetAsString(
|
||||
return tensorflow::strings::StrCat(Get<float>(multi_index));
|
||||
case F64:
|
||||
return tensorflow::strings::StrCat(Get<double>(multi_index));
|
||||
case C64: {
|
||||
complex64 c = Get<complex64>(multi_index);
|
||||
return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
case F16:
|
||||
return tensorflow::strings::StrCat(Get<half>(multi_index));
|
||||
default:
|
||||
@ -716,6 +722,8 @@ void* Literal::MutableInternalData() {
|
||||
return reinterpret_cast<void*>(f32s_.data());
|
||||
case F64:
|
||||
return reinterpret_cast<void*>(f64s_.data());
|
||||
case C64:
|
||||
return reinterpret_cast<void*>(c64s_.data());
|
||||
case F16:
|
||||
return reinterpret_cast<void*>(f16s_.data());
|
||||
default:
|
||||
@ -754,6 +762,9 @@ void Literal::Reserve(int64 num_elements) {
|
||||
case F64:
|
||||
Resize<double>(num_elements, 0);
|
||||
break;
|
||||
case C64:
|
||||
Resize<complex64>(num_elements, 0);
|
||||
break;
|
||||
case F16:
|
||||
Resize<half>(num_elements, static_cast<half>(0.0f));
|
||||
break;
|
||||
@ -790,6 +801,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
|
||||
case F64:
|
||||
actual = f64s_size();
|
||||
break;
|
||||
case C64:
|
||||
actual = c64s_size();
|
||||
break;
|
||||
case F16:
|
||||
actual = f16s().size() / sizeof(half);
|
||||
break;
|
||||
@ -843,6 +857,26 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
|
||||
return result_literal;
|
||||
}
|
||||
|
||||
template <PrimitiveType primitive_src_type>
|
||||
std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
|
||||
auto result_literal = MakeUnique<Literal>();
|
||||
Shape* result_shape = result_literal->mutable_shape();
|
||||
*result_shape = src_literal.shape();
|
||||
result_shape->set_element_type(C64);
|
||||
result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
|
||||
using NativeSrcT =
|
||||
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
|
||||
tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
|
||||
src_literal.GetArraySlice<NativeSrcT>();
|
||||
tensorflow::gtl::MutableArraySlice<complex64> dest_data =
|
||||
result_literal->GetMutableArraySlice<complex64>();
|
||||
int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
|
||||
}
|
||||
return result_literal;
|
||||
}
|
||||
|
||||
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
|
||||
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
|
||||
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
|
||||
@ -870,6 +904,8 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
|
||||
CONVERT_IF_TYPES_MATCH(F32)
|
||||
CONVERT_IF_TYPES_MATCH(F64)
|
||||
#undef CONVERT_IF_TYPES_MATCH
|
||||
case C64:
|
||||
return ConvertToC64<primitive_src_type>(src_literal);
|
||||
// Other types are not yet supported.
|
||||
default:
|
||||
return InvalidArgument(
|
||||
@ -966,6 +1002,8 @@ bool Literal::operator==(const Literal& other) const {
|
||||
return EqualElements<double>(*this, other, 0, &multi_index);
|
||||
case F16:
|
||||
return EqualElements<half>(*this, other, 0, &multi_index);
|
||||
case C64:
|
||||
return EqualElements<complex64>(*this, other, 0, &multi_index);
|
||||
default:
|
||||
LOG(FATAL) << "Unimplemented: Literal::Equal for type "
|
||||
<< PrimitiveType_Name(shape().element_type());
|
||||
@ -1065,6 +1103,12 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice() {
|
||||
values->size());
|
||||
}
|
||||
|
||||
template <>
|
||||
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
|
||||
auto values = mutable_c64s();
|
||||
return {values->data(), values->size()};
|
||||
}
|
||||
|
||||
template <>
|
||||
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
|
||||
// TODO - there is an endianess problem here. fix it, or wait for uint16
|
||||
@ -1144,6 +1188,13 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
|
||||
f16s().size() / sizeof(half));
|
||||
}
|
||||
|
||||
template <>
|
||||
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
|
||||
const {
|
||||
CHECK_EQ(shape().element_type(), C64);
|
||||
return c64s();
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
static bool AllElementsEqualValue(const Literal& literal, NativeT value) {
|
||||
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
|
||||
@ -1211,6 +1262,15 @@ bool Literal::IsAllFloat(float value) const {
|
||||
}
|
||||
}
|
||||
|
||||
bool Literal::IsAllComplex(complex64 value) const {
|
||||
switch (shape().element_type()) {
|
||||
case C64:
|
||||
return AllElementsEqualValue<complex64>(*this, value);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
|
||||
switch (shape().element_type()) {
|
||||
case U8:
|
||||
@ -1229,6 +1289,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
|
||||
return Get<float>(indices) == 0.0f;
|
||||
case F64:
|
||||
return Get<double>(indices) == 0.0;
|
||||
case C64:
|
||||
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
|
||||
case F16:
|
||||
return Get<half>(indices) == static_cast<half>(0.0f);
|
||||
case PRED:
|
||||
@ -1298,12 +1360,27 @@ void Literal::Resize<half>(int64 num_elements, half value) {
|
||||
mutable_f16s()->resize(num_elements, value);
|
||||
}
|
||||
|
||||
template <>
|
||||
void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
|
||||
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
|
||||
mutable_c64s()->resize(num_elements, value);
|
||||
}
|
||||
|
||||
template <typename RepeatedFieldT, typename NativeT>
|
||||
static void CopyToRepeatedField(RepeatedFieldT* dest,
|
||||
const std::vector<NativeT>& src) {
|
||||
void CopyToRepeatedField(RepeatedFieldT* dest,
|
||||
const std::vector<NativeT>& src) {
|
||||
*dest = RepeatedFieldT(src.begin(), src.end());
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyToRepeatedField<tensorflow::protobuf::RepeatedField<float>, complex64>(
|
||||
tensorflow::protobuf::RepeatedField<float>* dest,
|
||||
const std::vector<complex64>& src) {
|
||||
*dest = tensorflow::protobuf::RepeatedField<float>(
|
||||
reinterpret_cast<const float*>(src.data()),
|
||||
reinterpret_cast<const float*>(src.data()) + src.size() * 2);
|
||||
}
|
||||
|
||||
LiteralProto Literal::ToProto() const {
|
||||
LiteralProto proto;
|
||||
proto.Clear();
|
||||
@ -1338,6 +1415,9 @@ LiteralProto Literal::ToProto() const {
|
||||
case F64:
|
||||
CopyToRepeatedField(proto.mutable_f64s(), f64s());
|
||||
break;
|
||||
case C64:
|
||||
CopyToRepeatedField(proto.mutable_c64s(), c64s());
|
||||
break;
|
||||
case TUPLE:
|
||||
for (const auto& tuple : tuple_literals()) {
|
||||
*proto.add_tuple_literals() = tuple.ToProto();
|
||||
@ -1351,11 +1431,21 @@ LiteralProto Literal::ToProto() const {
|
||||
}
|
||||
|
||||
template <typename RepeatedFieldT, typename NativeT>
|
||||
static void CopyFromRepeatedField(std::vector<NativeT>* dest,
|
||||
const RepeatedFieldT& src) {
|
||||
void CopyFromRepeatedField(std::vector<NativeT>* dest,
|
||||
const RepeatedFieldT& src) {
|
||||
*dest = std::vector<NativeT>(src.begin(), src.end());
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyFromRepeatedField<tensorflow::protobuf::RepeatedField<float>,
|
||||
complex64>(
|
||||
std::vector<complex64>* dest,
|
||||
const tensorflow::protobuf::RepeatedField<float>& src) {
|
||||
*dest = std::vector<complex64>(
|
||||
reinterpret_cast<const complex64*>(src.data()),
|
||||
reinterpret_cast<const complex64*>(src.data()) + src.size() / 2);
|
||||
}
|
||||
|
||||
void Literal::CopyFromProto(const LiteralProto& literal_proto) {
|
||||
if (!literal_proto.has_shape()) {
|
||||
return;
|
||||
@ -1394,6 +1484,9 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
|
||||
case F64:
|
||||
CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
|
||||
break;
|
||||
case C64:
|
||||
CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s());
|
||||
break;
|
||||
case TUPLE:
|
||||
for (const auto& proto : literal_proto.tuple_literals()) {
|
||||
mutable_tuple_literals()->push_back(Literal(proto));
|
||||
|
@ -159,6 +159,10 @@ class Literal {
|
||||
const std::vector<double>& f64s() const { return f64s_; }
|
||||
std::vector<double>* mutable_f64s() { return &f64s_; }
|
||||
|
||||
int c64s_size() const { return c64s().size(); }
|
||||
const std::vector<complex64>& c64s() const { return c64s_; }
|
||||
std::vector<complex64>* mutable_c64s() { return &c64s_; }
|
||||
|
||||
int tuple_literals_size() const { return tuple_literals().size(); }
|
||||
const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
|
||||
Literal* add_tuple_literals() {
|
||||
@ -560,6 +564,17 @@ class Literal {
|
||||
// e.g. -0.5.
|
||||
bool IsAllFloat(float value) const;
|
||||
|
||||
// Like IsAll(const Literal&, int8), except we check whether the literal is
|
||||
// equal to a particular complex number.
|
||||
//
|
||||
// If the literal is not a complex value, this always returns false.
|
||||
//
|
||||
// This casts value to the type of literal, then compares using ==. The usual
|
||||
// admonishments about floating-point equality checks apply. We expect you to
|
||||
// use this to check for complex values that can be expressed precisely as
|
||||
// float pairs e.g. (-0.5, 1.0).
|
||||
bool IsAllComplex(complex64 value) const;
|
||||
|
||||
// Returns whether this literal is zero at the specified index. This literal
|
||||
// must be an array.
|
||||
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
|
||||
@ -610,6 +625,7 @@ class Literal {
|
||||
std::vector<half> f16s_;
|
||||
std::vector<float> f32s_;
|
||||
std::vector<double> f64s_;
|
||||
std::vector<complex64> c64s_;
|
||||
std::vector<Literal> tuple_literals_;
|
||||
};
|
||||
|
||||
@ -658,6 +674,10 @@ tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
|
||||
template <>
|
||||
tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
|
||||
|
||||
template <>
|
||||
tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
|
||||
const;
|
||||
|
||||
template <>
|
||||
tensorflow::gtl::MutableArraySlice<bool> Literal::GetMutableArraySlice();
|
||||
|
||||
@ -694,6 +714,9 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
|
||||
template <>
|
||||
tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
|
||||
|
||||
template <>
|
||||
tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
|
||||
|
||||
template <>
|
||||
void Literal::Resize<bool>(int64 num_elements, bool value);
|
||||
|
||||
@ -724,6 +747,9 @@ void Literal::Resize<double>(int64 num_elements, double value);
|
||||
template <>
|
||||
void Literal::Resize<half>(int64 num_elements, half value);
|
||||
|
||||
template <>
|
||||
void Literal::Resize<complex64>(int64 num_elements, complex64 value);
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
|
||||
auto literal = MakeUnique<Literal>();
|
||||
|
@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
|
||||
|
||||
auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f));
|
||||
ASSERT_EQ("0.5", f16_lit->ToString());
|
||||
|
||||
auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
|
||||
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, LiteralVectorToString) {
|
||||
@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) {
|
||||
EXPECT_NE(*tuple1, *different_tuple);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, C64Equality) {
|
||||
// Test equality with tuples.
|
||||
auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
|
||||
// Tuple with the same elements. One element is shared with the original
|
||||
// tuple, the other is a clone of the element in the original tuple.
|
||||
auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
EXPECT_EQ(*vector, *vector_clone);
|
||||
|
||||
auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
|
||||
EXPECT_NE(*vector, *vector_reversed);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, IsAllTuple) {
|
||||
auto element1 = Literal::CreateR0<float>(0.0);
|
||||
auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
|
||||
@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) {
|
||||
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
|
||||
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
|
||||
|
||||
complex64 c8_9 = {8, 9};
|
||||
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
|
||||
|
||||
auto uint64_max = std::numeric_limits<uint64>::max();
|
||||
EXPECT_FALSE(Literal::CreateR2<uint64>(
|
||||
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
|
||||
@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) {
|
||||
Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, IsAllComplex) {
|
||||
// IsAllComplex always returns false when the literal is not complex.
|
||||
EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0));
|
||||
EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0));
|
||||
EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0));
|
||||
EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0));
|
||||
EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0));
|
||||
EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0));
|
||||
|
||||
complex64 c8_9 = {8, 9};
|
||||
complex64 c7_9 = {7, 9};
|
||||
EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})
|
||||
->IsAllComplex({8.0f, 9.0f}));
|
||||
EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})
|
||||
->IsAllComplex({8.0f, 9.0f}));
|
||||
EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}})
|
||||
->IsAllComplex({8.0f, 9.0f}));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, IsZero) {
|
||||
auto scalar_zero = Literal::CreateR0<float>(0.0f);
|
||||
auto scalar_one = Literal::CreateR0<float>(1.0f);
|
||||
@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) {
|
||||
EXPECT_TRUE(array->IsZero({0, 2}));
|
||||
EXPECT_TRUE(array->IsZero({1, 1}));
|
||||
EXPECT_FALSE(array->IsZero({1, 2}));
|
||||
|
||||
auto complex_zero = Literal::CreateR0<complex64>(0.0f);
|
||||
auto complex_nonzero = Literal::CreateR0<complex64>(0.5f);
|
||||
EXPECT_TRUE(complex_zero->IsZero({}));
|
||||
EXPECT_FALSE(complex_nonzero->IsZero({}));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class LiteralUtilTestTemplated : public ::testing::Test {};
|
||||
|
||||
using TestedTypes = ::testing::Types<float, int32, uint32>;
|
||||
using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
|
||||
TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
|
||||
|
||||
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
|
||||
@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) {
|
||||
EXPECT_EQ(output, *expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateR2U64) {
|
||||
TEST_F(LiteralUtilTest, PopulateR1U64) {
|
||||
Literal output;
|
||||
output.PopulateR1<uint64>({{77, 88}});
|
||||
auto expected = Literal::CreateR1<uint64>({{77, 88}});
|
||||
EXPECT_EQ(output, *expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateR1C64) {
|
||||
Literal output;
|
||||
output.PopulateR1<complex64>({{77, 88}});
|
||||
auto expected = Literal::CreateR1<complex64>({{77, 88}});
|
||||
EXPECT_EQ(output, *expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateR2C64) {
|
||||
Literal output;
|
||||
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
|
||||
auto expected =
|
||||
Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
|
||||
EXPECT_EQ(output, *expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
|
||||
Literal output;
|
||||
output.PopulateWithValue<float>(2.5f, {});
|
||||
@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
|
||||
EXPECT_EQ(output, *expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
|
||||
Literal output;
|
||||
output.PopulateWithValue<complex64>({4, 2}, {2, 2});
|
||||
auto expected =
|
||||
Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
|
||||
EXPECT_EQ(output, *expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
|
||||
Literal output;
|
||||
half h(0.25f);
|
||||
@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
|
||||
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
|
||||
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
|
||||
}}, layout_r4_dim0major_);
|
||||
auto c64 = Literal::CreateR4WithLayout<complex64>({{
|
||||
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
|
||||
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
|
||||
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
|
||||
}}, layout_r4_dim0major_);
|
||||
// clang-format on
|
||||
std::unique_ptr<Literal> conv;
|
||||
|
||||
@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
|
||||
conv = u32->Convert(F16).ConsumeValueOrDie();
|
||||
EXPECT_EQ(*conv, *f16);
|
||||
|
||||
conv = s32->Convert(C64).ConsumeValueOrDie();
|
||||
EXPECT_EQ(*conv, *c64);
|
||||
|
||||
conv = f16->Convert(C64).ConsumeValueOrDie();
|
||||
EXPECT_EQ(*conv, *c64);
|
||||
|
||||
EXPECT_EQ(s32->Convert(TUPLE).status().code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s32->Convert(S16).status().code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s32->Convert(U16).status().code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(c64->Convert(F32).status().code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(c64->Convert(S32).status().code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
|
||||
|
@ -83,10 +83,17 @@ PrimitiveType NativeToPrimitiveType<half>() {
|
||||
return F16;
|
||||
}
|
||||
|
||||
template <>
|
||||
PrimitiveType NativeToPrimitiveType<complex64>() {
|
||||
return C64;
|
||||
}
|
||||
|
||||
bool IsFloatingPointType(PrimitiveType type) {
|
||||
return type == F16 || type == F32 || type == F64;
|
||||
}
|
||||
|
||||
bool IsComplexType(PrimitiveType type) { return type == C64; }
|
||||
|
||||
bool IsSignedIntegralType(PrimitiveType type) {
|
||||
return type == S8 || type == S16 || type == S32 || type == S64;
|
||||
}
|
||||
@ -121,6 +128,7 @@ int BitWidth(PrimitiveType type) {
|
||||
case U64:
|
||||
case S64:
|
||||
case F64:
|
||||
case C64:
|
||||
return 64;
|
||||
|
||||
case TUPLE:
|
||||
@ -134,5 +142,15 @@ int BitWidth(PrimitiveType type) {
|
||||
}
|
||||
}
|
||||
|
||||
PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
|
||||
switch (complex_type) {
|
||||
case C64:
|
||||
return F32;
|
||||
default:
|
||||
LOG(FATAL) << "Primitive type is not complex: "
|
||||
<< PrimitiveType_Name(complex_type);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace primitive_util
|
||||
} // namespace xla
|
||||
|
@ -78,8 +78,14 @@ PrimitiveType NativeToPrimitiveType<double>();
|
||||
template <>
|
||||
PrimitiveType NativeToPrimitiveType<half>();
|
||||
|
||||
// Complex
|
||||
template <>
|
||||
PrimitiveType NativeToPrimitiveType<complex64>();
|
||||
|
||||
bool IsFloatingPointType(PrimitiveType type);
|
||||
|
||||
bool IsComplexType(PrimitiveType type);
|
||||
|
||||
bool IsSignedIntegralType(PrimitiveType type);
|
||||
|
||||
bool IsUnsignedIntegralType(PrimitiveType type);
|
||||
@ -89,6 +95,10 @@ bool IsIntegralType(PrimitiveType type);
|
||||
// Returns the number of bits in the representation for a given type.
|
||||
int BitWidth(PrimitiveType type);
|
||||
|
||||
// Returns the real, imag component type underlying the given complex type.
|
||||
// LOG(FATAL)'s if complex_type is not complex.
|
||||
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
|
||||
|
||||
// Returns the native type (eg, float) corresponding to the given template
|
||||
// parameter XLA primitive type (eg, F32).
|
||||
template <PrimitiveType>
|
||||
@ -157,6 +167,11 @@ struct PrimitiveTypeToNative<F16> {
|
||||
using type = half;
|
||||
};
|
||||
|
||||
// Complex
|
||||
template <>
|
||||
struct PrimitiveTypeToNative<C64> {
|
||||
using type = complex64;
|
||||
};
|
||||
} // namespace primitive_util
|
||||
} // namespace xla
|
||||
|
||||
|
@ -1265,6 +1265,9 @@ HloEvaluator::HloEvaluator() {
|
||||
});
|
||||
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
|
||||
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
|
||||
typed_visitors_[C64] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: C64.");
|
||||
});
|
||||
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
|
||||
return Unimplemented("unhandled primitive type: TUPLE.");
|
||||
});
|
||||
|
@ -281,6 +281,10 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) {
|
||||
return primitive_util::IsComplexType(shape.element_type());
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
|
||||
return primitive_util::IsFloatingPointType(shape.element_type());
|
||||
}
|
||||
@ -592,6 +596,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
return sizeof(float);
|
||||
case F64:
|
||||
return sizeof(double);
|
||||
case C64:
|
||||
return sizeof(complex64);
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
|
||||
}
|
||||
|
@ -291,6 +291,9 @@ class ShapeUtil {
|
||||
// Returns whether the element type of the shape is floating point.
|
||||
static bool ElementIsFloating(const Shape& shape);
|
||||
|
||||
// Returns whether the element type of the shape is complex.
|
||||
static bool ElementIsComplex(const Shape& shape);
|
||||
|
||||
// Returns whether the element type has the given bit width.
|
||||
static bool ElementHasBitWidth(const Shape& shape, int bits);
|
||||
|
||||
|
@ -218,6 +218,10 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
|
||||
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64));
|
||||
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {})));
|
||||
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20})));
|
||||
|
||||
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
|
||||
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
|
||||
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
|
||||
|
@ -254,7 +254,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
|
||||
const Shape* shape_with_layout) {
|
||||
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
|
||||
if (ShapeUtil::ElementIsFloating(expected.shape())) {
|
||||
if (ShapeUtil::ElementIsFloating(expected.shape()) ||
|
||||
ShapeUtil::ElementIsComplex(expected.shape())) {
|
||||
LOG(WARNING) << "performing exact comparison of floating point numbers";
|
||||
} else {
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
|
||||
@ -282,7 +283,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
|
||||
ComputationBuilder* builder, const Literal& expected,
|
||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
|
||||
const Shape* shape_with_layout) {
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()));
|
||||
TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) ||
|
||||
ShapeUtil::ElementIsComplex(expected.shape()));
|
||||
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
|
||||
auto expect_near = [&](const Literal& actual, const string& error_message) {
|
||||
LiteralTestUtil::ExpectNear(expected, actual, error, error_message);
|
||||
|
@ -156,6 +156,15 @@ template <>
|
||||
::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
|
||||
return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
|
||||
}
|
||||
template <>
|
||||
::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
|
||||
complex64 rhs) {
|
||||
auto res = CompareEqual<float>(lhs.real(), rhs.real());
|
||||
if (!res) {
|
||||
return res;
|
||||
}
|
||||
return CompareEqual<float>(lhs.imag(), rhs.imag());
|
||||
}
|
||||
|
||||
// A recursive function which iterates through every index of expected and
|
||||
// actual literal and compares their values elementwise. Returns true if all
|
||||
@ -235,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
|
||||
case F64:
|
||||
match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
|
||||
break;
|
||||
case C64:
|
||||
match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
|
||||
break;
|
||||
case TUPLE: {
|
||||
bool tuple_match = true;
|
||||
for (int i = 0; i < actual.tuple_literals_size(); ++i) {
|
||||
@ -325,6 +337,9 @@ class NearComparator {
|
||||
case F64:
|
||||
ExpectLiteralsNear<double>(expected, actual, 0);
|
||||
break;
|
||||
case C64:
|
||||
ExpectLiteralsNear<complex64>(expected, actual, 0);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported primitive type in near comparator: "
|
||||
<< PrimitiveType_Name(expected.shape().element_type())
|
||||
@ -365,6 +380,19 @@ class NearComparator {
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename NativeT>
|
||||
bool NanMismatch(NativeT lhs, NativeT rhs) {
|
||||
return std::isnan(lhs) != std::isnan(rhs);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void ExpectNear(NativeT expected, NativeT actual,
|
||||
const ::testing::Message& message) {
|
||||
EXPECT_NEAR(expected, actual, error_.abs)
|
||||
<< "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n"
|
||||
<< message;
|
||||
}
|
||||
|
||||
// EXPECTs that the two given scalar values are within the error bound. Keeps
|
||||
// track of how many mismatches have occurred to keep the size of the output
|
||||
// manageable.
|
||||
@ -390,7 +418,7 @@ class NearComparator {
|
||||
"index %s abs_diff %f rel_err %f",
|
||||
LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
|
||||
rel_err);
|
||||
bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
|
||||
bool nan_mismatch = NanMismatch<NativeT>(expected, actual);
|
||||
bool mismatch =
|
||||
(nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
|
||||
if (mismatch) {
|
||||
@ -398,11 +426,12 @@ class NearComparator {
|
||||
abs_expected_miscompare_sum_ += std::abs(expected);
|
||||
const int64 kMaxFailures = 2;
|
||||
if (num_miscompares_ < kMaxFailures) {
|
||||
EXPECT_NEAR(expected, actual, error_.abs)
|
||||
<< "mismatch at index "
|
||||
::testing::Message msg;
|
||||
msg << "mismatch at index "
|
||||
<< LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
|
||||
<< abs_diff << " rel err " << rel_err << " failure #"
|
||||
<< num_miscompares_;
|
||||
ExpectNear<NativeT>(expected, actual, msg);
|
||||
} else if (num_miscompares_ == kMaxFailures) {
|
||||
LOG(ERROR)
|
||||
<< "reached max 'loud' failure count; silently proceeding...";
|
||||
@ -470,6 +499,23 @@ class NearComparator {
|
||||
std::vector<int64> max_abs_multi_index_;
|
||||
};
|
||||
|
||||
template <>
|
||||
bool NearComparator::NanMismatch<complex64>(complex64 lhs, complex64 rhs) {
|
||||
return std::isnan(lhs.real()) != std::isnan(rhs.real()) ||
|
||||
std::isnan(lhs.imag()) != std::isnan(rhs.imag());
|
||||
}
|
||||
|
||||
template <>
|
||||
void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
|
||||
const ::testing::Message& message) {
|
||||
EXPECT_NEAR(expected.real(), actual.real(), error_.abs)
|
||||
<< "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n"
|
||||
<< message;
|
||||
EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs)
|
||||
<< "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n"
|
||||
<< message;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
|
||||
|
@ -35,6 +35,8 @@ using ::tensorflow::uint16;
|
||||
using ::tensorflow::uint32;
|
||||
using ::tensorflow::uint64;
|
||||
|
||||
typedef std::complex<float> complex64;
|
||||
|
||||
using ::Eigen::half;
|
||||
|
||||
} // namespace xla
|
||||
|
@ -48,6 +48,9 @@ enum PrimitiveType {
|
||||
F32 = 11;
|
||||
F64 = 12;
|
||||
|
||||
// Complex values of fixed width.
|
||||
C64 = 15;
|
||||
|
||||
// A tuple is a polymorphic sequence; e.g. a shape that holds different
|
||||
// sub-shapes. They are used for things like returning multiple values from a
|
||||
// computation; e.g. a computation that returns weights and biases may have a
|
||||
@ -305,6 +308,7 @@ message LiteralProto {
|
||||
repeated uint64 u64s = 7;
|
||||
repeated float f32s = 8;
|
||||
repeated double f64s = 9;
|
||||
repeated float c64s = 12; // Stored as interleaved real, imag floats.
|
||||
repeated LiteralProto tuple_literals = 10;
|
||||
bytes f16s = 11; // Note: the F16s are encoded in little endian byte order
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user