[XLA] Add complex128 support.

Minimally tested at the moment (tested via the TF tests).

PiperOrigin-RevId: 228931680
This commit is contained in:
Peter Hawkins 2019-01-11 13:05:56 -08:00 committed by TensorFlower Gardener
parent 2cecb70a63
commit 937ff1b4cf
39 changed files with 462 additions and 84 deletions

View File

@ -151,7 +151,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
EXPECT_EQ(clusters["A"], clusters["C"]);
}
TEST(XlaCompilationTest, Complex128Unsupported) {
TEST(XlaCompilationTest, StringUnsupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
GraphDef graphdef;
{
@ -159,10 +159,10 @@ TEST(XlaCompilationTest, Complex128Unsupported) {
Node* a = ops::SourceOp(
"Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_COMPLEX128)
.WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
.WithAttr("dtype", DT_STRING)
.WithAttr("value", Tensor(DT_STRING, TensorShape())));
Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}

View File

@ -83,9 +83,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
constexpr std::array<DataType, 13> kAllXlaCpuTypes = {
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);

View File

@ -26,9 +26,9 @@ namespace tensorflow {
const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";
constexpr std::array<DataType, 9> kExecAllTypes = {
constexpr std::array<DataType, 10> kExecAllTypes = {
{DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_BOOL, DT_BFLOAT16}};
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:

View File

@ -400,7 +400,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
def testComplexOps(self):
for dtype in self.complex_types:
ctypes = {np.complex64: np.float32}
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
self._testBinary(
math_ops.complex,
np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),

View File

@ -64,7 +64,7 @@ def tf_xla_py_test(
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128",
]
elif backend == "gpu":
backend_args += [

View File

@ -647,7 +647,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
ctypes = {np.complex64: np.float32}
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[3 - 4j, -1j, np.inf]], dtype=dtype),

View File

@ -79,8 +79,8 @@ class BitcastOp : public XlaOpKernel {
if (src_dtype_ == dst_dtype_) {
output = input;
} else {
// The only complex type in XLA is C64, so error out if the bitcast has a
// complex source or destination type and the bitcast is not trivial.
// Error out if the bitcast has a complex source or destination type and
// the bitcast is not trivial.
OP_REQUIRES(ctx,
!xla::primitive_util::IsComplexType(src_type_) &&
!xla::primitive_util::IsComplexType(dst_type_),

View File

@ -83,6 +83,17 @@ class ConstOp : public XlaOpKernel {
return;
}
break;
case DT_COMPLEX128:
  // Fast path for a splat complex128 constant: when the proto carries exactly
  // two dcomplex_val entries they are read as (real, imag) of one scalar,
  // which is then broadcast to the requested shape. The guard must test
  // dcomplex_val -- the field indexed below -- and not scomplex_val, which is
  // the single-precision (complex64) field; otherwise the fast path is never
  // taken for complex128 data and dcomplex_val could be indexed while empty.
  if (proto_.dcomplex_val_size() == 2) {
    ctx->SetOutput(
        0, xla::Broadcast(
               xla::ConstantR0<xla::complex128>(
                   b, xla::complex128(proto_.dcomplex_val(0),
                                      proto_.dcomplex_val(1))),
               shape.dim_sizes()));
    return;
  }
  // Anything else falls through to the handling after the switch.
  break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
ctx->SetOutput(

View File

@ -24,8 +24,8 @@ limitations under the License.
namespace tensorflow {
namespace {
constexpr std::array<DataType, 5> kMatmulTypes = {
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
constexpr std::array<DataType, 6> kMatmulTypes = {
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}};
class MatMulOp : public XlaOpKernel {
public:

View File

@ -54,6 +54,9 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::C64:
return xla::ConstantR0<xla::complex64>(builder, value);
break;
case xla::C128:
return xla::ConstantR0<xla::complex128>(builder, value);
break;
default:
LOG(FATAL) << "unhandled element type " << type;
}
@ -90,6 +93,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
case xla::C64:
literal = xla::LiteralUtil::CreateR0<complex64>(value);
break;
case xla::C128:
literal = xla::LiteralUtil::CreateR0<complex128>(value);
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
case xla::S16:

View File

@ -69,6 +69,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_COMPLEX64:
*type = xla::C64;
return Status::OK();
case tensorflow::DT_COMPLEX128:
*type = xla::C128;
return Status::OK();
default:
return errors::InvalidArgument(
"Unsupported type in DataTypeToPrimitiveType ",

View File

@ -47,13 +47,14 @@ extern const char* const DEVICE_XLA_GPU;
constexpr std::array<DataType, 4> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
constexpr std::array<DataType, 11> kNumericTypes = {
constexpr std::array<DataType, 12> kNumericTypes = {
{DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};
constexpr std::array<DataType, 14> kCpuAllTypes = {
constexpr std::array<DataType, 15> kCpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL}};
constexpr std::array<DataType, 15> kGpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,

View File

@ -56,6 +56,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
return ConstantR0<double>(builder, static_cast<double>(value));
case C64:
return ConstantR0<complex64>(builder, static_cast<complex64>(value));
case C128:
return ConstantR0<complex128>(builder, static_cast<complex128>(value));
case U8:
return ConstantR0<uint8>(builder, static_cast<uint8>(value));
case U32:

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -323,7 +324,8 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
XlaBuilder* builder = x.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
auto perform_conj = shape.element_type() == C64 && conjugate;
auto perform_conj =
primitive_util::IsComplexType(shape.element_type()) && conjugate;
return perform_conj ? Conj(x) : x;
});
}

View File

@ -29,10 +29,12 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@ -411,6 +413,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
COPY_ELEMENTS(F32, float);
COPY_ELEMENTS(F64, double);
COPY_ELEMENTS(C64, complex64);
COPY_ELEMENTS(C128, complex128);
COPY_ELEMENTS(PRED, bool);
#undef COPY_ELEMENTS
default:
@ -548,6 +551,9 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
case C64:
return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
copy_size);
case C128:
return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
copy_size);
case PRED:
return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
copy_size);
@ -766,6 +772,8 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
return SliceInternal<double>(result_shape, start_indices);
case C64:
return SliceInternal<complex64>(result_shape, start_indices);
case C128:
return SliceInternal<complex128>(result_shape, start_indices);
default:
LOG(FATAL) << "not yet implemented: "
<< PrimitiveType_Name(result_shape.element_type());
@ -814,6 +822,10 @@ string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
complex64 c = Get<complex64>(multi_index, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
case C128: {
complex128 c = Get<complex128>(multi_index, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
default:
LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
}
@ -868,6 +880,11 @@ string LiteralBase::GetSparseElementAsString(
GetSparseElement<complex64>(sparse_element_number, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
case C128: {
complex128 c =
GetSparseElement<complex128>(sparse_element_number, shape_index);
return StrCat("(", c.real(), ", ", c.imag(), ")");
}
default:
LOG(FATAL) << "Invalid element type for sparse arrays: "
<< PrimitiveType_Name(subshape.element_type());
@ -996,6 +1013,9 @@ void LiteralBase::Piece::SortSparseElements() {
case C64:
SortSparseElementsInternal<complex64>();
break;
case C128:
SortSparseElementsInternal<complex128>();
break;
case F16:
SortSparseElementsInternal<half>();
break;
@ -1230,7 +1250,24 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
}
template <typename NativeSrcT, typename NativeDestT>
Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
(std::is_same<NativeDestT, complex64>::value ||
std::is_same<NativeDestT, complex128>::value),
Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
};
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
}
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
(!std::is_same<NativeDestT, complex64>::value &&
!std::is_same<NativeDestT, complex128>::value),
Literal>::type
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
@ -1274,22 +1311,6 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
Literal ConvertToC64(const LiteralBase& src_literal) {
CHECK(src_literal.shape().IsArray());
Literal result_literal(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
absl::Span<complex64> dest_data = result_literal.data<complex64>();
int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
}
return result_literal;
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
@ -1332,10 +1353,15 @@ StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
case C64:
if (!bitcast) {
return ConvertToC64<primitive_src_type>(src_literal);
if (bitcast) {
break;
}
break;
return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
case C128:
if (bitcast) {
break;
}
return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
// Other types are not yet supported.
default:
break;
@ -1485,6 +1511,8 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
return EqualElementsInternal<bfloat16>(other, &multi_index);
case C64:
return EqualElementsInternal<complex64>(other, &multi_index);
case C128:
return EqualElementsInternal<complex128>(other, &multi_index);
default:
LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
<< PrimitiveType_Name(subshape().element_type());
@ -1628,6 +1656,9 @@ bool LiteralBase::IsAllComplex(complex64 value) const {
case C64:
return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
value);
case C128:
return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
value);
default:
return false;
}
@ -1707,6 +1738,11 @@ bool LiteralBase::IsAllFirst() const {
auto data = piece.data<uint64>();
return AllElementsEqualValue<uint64>(data, data[0]);
}
case C128: {
auto data = piece.data<complex128>();
return AllElementsEqualValue<complex128>(data, data[0]);
}
default:
return false;
}
@ -1756,6 +1792,8 @@ bool LiteralBase::IsR1Iota() const {
return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
case C64:
return Get<complex64>({idx}) == complex64(idx, 0.0f);
case C128:
return Get<complex128>({idx}) == complex128(idx, 0.0f);
case PRED:
return Get<bool>({idx}) == idx;
// token, opaque, tuple, etc. are all not iota.
@ -1799,6 +1837,8 @@ bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
return Get<double>(indices) == 0.0;
case C64:
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
case C128:
return Get<complex128>(indices) == complex128(0.0f, 0.0f);
case F16:
return Get<half>(indices) == static_cast<half>(0.0f);
case BF16:
@ -1886,6 +1926,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
proto->add_c64s(value.imag());
}
break;
case C128:
for (complex128 value : data<complex128>()) {
proto->add_c128s(value.real());
proto->add_c128s(value.imag());
}
break;
case TUPLE:
case TOKEN:
// Nothing to do but assign the shape which is done above.
@ -2018,7 +2064,17 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
for (int64 i = 0; i < complex_data.size(); ++i) {
complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
}
} break;
break;
}
case C128: {
auto complex_data = data<complex128>();
TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2);
for (int64 i = 0; i < complex_data.size(); ++i) {
complex_data[i] =
complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
}
break;
}
case TUPLE:
return InvalidArgument("Should not be called on tuple shapes: %s",
ShapeUtil::HumanString(subshape()));

View File

@ -90,6 +90,12 @@ bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
}
// complex128 specialization: two values are equal only when both the real and
// the imaginary components compare equal as doubles. The imaginary parts are
// not examined once the real parts differ.
template <>
bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
                              absl::Span<const int64> multi_index) {
  if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
    return false;
  }
  return CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
}
template <typename NativeT, typename UnsignedT>
Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
@ -143,6 +149,14 @@ Status MakeErrorStatus(complex64 lhs, complex64 rhs,
}
return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
}
// complex128 specialization: builds the error from the first component that
// differs. The real parts are checked first; when they compare equal the
// status is produced from the imaginary parts instead.
template <>
Status MakeErrorStatus(complex128 lhs, complex128 rhs,
                       absl::Span<const int64> multi_index) {
  const bool real_parts_equal =
      CompareEqual<double>(lhs.real(), rhs.real(), multi_index);
  return real_parts_equal
             ? MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index)
             : MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
}
// A recursive function which iterates through every index of expected and
// actual literal and compares their values elementwise. Returns true if all
@ -197,13 +211,6 @@ bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) {
}
}
template <>
bool NanMismatch<complex64>(complex64 expected, complex64 actual,
bool relaxed_nans) {
return NanMismatch<float>(expected.real(), actual.real(), relaxed_nans) ||
NanMismatch<float>(expected.imag(), actual.imag(), relaxed_nans);
}
template <>
bool NanMismatch<half>(half expected, half actual, bool relaxed_nans) {
return NanMismatch<float>(static_cast<float>(expected),
@ -232,6 +239,11 @@ string FpValueToString<complex64>(complex64 value) {
return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
}
// complex128 specialization: renders the value as "<real> + <imag>i", using
// the same format string as the complex64 specialization.
template <>
string FpValueToString<complex128>(complex128 value) {
  const auto real_part = value.real();
  const auto imag_part = value.imag();
  return absl::StrFormat("%8.4g + %8.4fi", real_part, imag_part);
}
// Returns the absolute value of the given floating point value. This function
// is used instead of std::abs directly in order to allow type-dependent
// implementations for NearComparator.
@ -434,7 +446,7 @@ class NearComparator {
mismatches_.data<bool>()[linear_index] = true;
}
// For complex64 types, we compare real and imaginary parts individually.
// For complex types, we compare real and imaginary parts individually.
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
bool mismatch = false;
CompareValues<float>(expected.real(), actual.real(), linear_index);
@ -457,6 +469,29 @@ class NearComparator {
mismatches_.data<bool>()[linear_index] = mismatch;
}
// Compares a complex128 pair by comparing the real and imaginary parts
// individually, while recording at most ONE mismatch for the whole element.
//
// Each scalar comparison bumps num_mismatches_ when its component differs, so
// the element mismatched iff the counter moved. The tally is derived from the
// counter rather than by re-reading mismatches_[linear_index] after each
// component: a flag set by the real part may still be set after a matching
// imaginary part (NOTE(review): depends on the scalar overload not clearing
// the flag on a match -- confirm), which would otherwise cause a second
// decrement and drop a real-only mismatch from the count.
void CompareValues(complex128 expected, complex128 actual,
                   int64 linear_index) {
  const auto mismatches_before = num_mismatches_;
  CompareValues<double>(expected.real(), actual.real(), linear_index);
  CompareValues<double>(expected.imag(), actual.imag(), linear_index);

  const bool mismatch = num_mismatches_ != mismatches_before;
  if (mismatch) {
    // Collapse one or two component mismatches into a single element
    // mismatch.
    num_mismatches_ = mismatches_before + 1;
  }
  mismatches_.data<bool>()[linear_index] = mismatch;
}
// Compares the two literals elementwise.
void CompareLiterals() {
// Fast path optimization for the case were layouts match.
@ -665,6 +700,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
case C64:
result = Equal<complex64>(expected, actual, index, 0);
break;
case C128:
result = Equal<complex128>(expected, actual, index, 0);
break;
case TUPLE: {
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
result.Update(EqualHelper(LiteralSlice(expected, {i}),
@ -749,6 +787,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
return NearComparator<complex64>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
break;
case C128:
return NearComparator<complex128>::Compare(
expected, actual, error, detailed_message, miscompare_callback);
break;
default:
LOG(FATAL) << "Unsupported primitive type in near comparator: "
<< PrimitiveType_Name(expected.shape().element_type())

View File

@ -118,6 +118,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());
auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14f, 2.78f});
EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());
@ -469,6 +472,21 @@ TEST_F(LiteralUtilTest, C64Equality) {
EXPECT_NE(vector, vector_reversed);
}
TEST_F(LiteralUtilTest, C128Equality) {
  // Two independently constructed rank-1 complex128 literals holding the same
  // elements in the same order must compare equal.
  auto original = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  auto same_elements =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(original, same_elements);

  // The same elements in swapped order must compare unequal.
  auto swapped = LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(original, swapped);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = LiteralUtil::CreateR0<float>(0.0);
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
@ -836,6 +854,13 @@ TEST_F(LiteralUtilTest, PopulateR1C64) {
EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateR1C128) {
  // Populating a one-element C128 literal in place must yield the same
  // literal as creating it directly from the same value.
  auto expected = LiteralUtil::CreateR1<complex128>({{77, 88}});

  Literal populated(ShapeUtil::MakeShape(C128, {1}));
  populated.PopulateR1<complex128>({{77, 88}});

  EXPECT_EQ(populated, expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
@ -897,6 +922,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
EXPECT_EQ(output, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
  // Filling a 2x2 C128 literal with a single value must match a literal that
  // spells out that value in every position.
  auto expected =
      LiteralUtil::CreateR2<complex128>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});

  Literal filled(ShapeUtil::MakeShape(C128, {2, 2}));
  filled.PopulateWithValue<complex128>({4, 2});

  EXPECT_EQ(filled, expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
Literal output(ShapeUtil::MakeShape(F16, {}));
half h(0.25f);
@ -1308,7 +1341,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
// clang-format on
auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
{{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
}}, layout_r4_dim0major_); // clang-format on
Literal conv;
conv = s8.Convert(U16).ConsumeValueOrDie();
@ -1374,10 +1411,20 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = s32.Convert(U16).ConsumeValueOrDie();
EXPECT_EQ(conv, u16);
conv = s32.Convert(C128).ConsumeValueOrDie();
EXPECT_EQ(conv, c128);
conv = f16.Convert(C128).ConsumeValueOrDie();
EXPECT_EQ(conv, c128);
EXPECT_EQ(s32.Convert(TUPLE).status().code(),
tensorflow::error::UNIMPLEMENTED);
EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_EQ(c128.Convert(F32).status().code(),
tensorflow::error::UNIMPLEMENTED);
EXPECT_EQ(c128.Convert(S32).status().code(),
tensorflow::error::UNIMPLEMENTED);
}
TEST_F(LiteralUtilTest, BitcastConvert) {
@ -1739,7 +1786,8 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
ShapeUtil::MakeShape(C128, {})}));
EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
@ -1747,6 +1795,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
EXPECT_EQ(tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
}
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@ -1756,6 +1805,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
auto vector_c128 =
LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
auto vector_half =
@ -1776,6 +1827,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
EXPECT_EQ(tuple, to_from_proto(tuple));

View File

@ -130,6 +130,8 @@ Literal ConvertType(LiteralSlice literal) {
return LiteralUtil::CreateR0<double>(0);
case C64:
return LiteralUtil::CreateR0<complex64>(0);
case C128:
return LiteralUtil::CreateR0<complex128>(0);
case PRED:
return LiteralUtil::CreateR0<bool>(false);
case TUPLE:
@ -165,6 +167,8 @@ Literal ConvertType(LiteralSlice literal) {
return LiteralUtil::CreateR0<double>(1);
case C64:
return LiteralUtil::CreateR0<complex64>(1);
case C128:
return LiteralUtil::CreateR0<complex128>(1);
case PRED:
return LiteralUtil::CreateR0<bool>(true);
case S16:
@ -201,6 +205,8 @@ Literal ConvertType(LiteralSlice literal) {
-std::numeric_limits<double>::infinity());
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case C128:
LOG(FATAL) << "C128 element type has no minimum value";
case PRED:
return LiteralUtil::CreateR0<bool>(false);
case S16:
@ -345,6 +351,10 @@ Literal ConvertType(LiteralSlice literal) {
new_literal.Set<complex64>(to_multi_index,
literal.Get<complex64>(from_multi_index));
break;
case C128:
new_literal.Set<complex128>(to_multi_index,
literal.Get<complex128>(from_multi_index));
break;
default:
LOG(FATAL) << "Unhandled primitive element type: "
<< PrimitiveType_Name(literal.shape().element_type());
@ -393,6 +403,10 @@ Literal ConvertType(LiteralSlice literal) {
return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
case U64:
return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
case C128:
return LiteralUtil::CreateR0<complex128>(
literal.GetFirstElement<complex128>());
default:
LOG(FATAL) << "Unhandled primitive type "
<< literal.shape().element_type();

View File

@ -27,7 +27,7 @@ bool IsFloatingPointType(PrimitiveType type) {
return type == F16 || type == F32 || type == F64 || type == BF16;
}
bool IsComplexType(PrimitiveType type) { return type == C64; }
bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
bool IsSignedIntegralType(PrimitiveType type) {
return type == S8 || type == S16 || type == S32 || type == S64;
@ -67,6 +67,9 @@ int BitWidth(PrimitiveType type) {
case C64:
return 64;
case C128:
return 128;
case TUPLE:
LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
@ -82,6 +85,8 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
switch (complex_type) {
case C64:
return F32;
case C128:
return F64;
default:
LOG(FATAL) << "Primitive type is not complex: "
<< PrimitiveType_Name(complex_type);

View File

@ -126,6 +126,11 @@ inline PrimitiveType NativeToPrimitiveType<complex64>() {
return C64;
}
// Specialization mapping the native complex128 type to the C128 primitive
// type.
template <>
inline PrimitiveType NativeToPrimitiveType<complex128>() {
return C128;
}
bool IsFloatingPointType(PrimitiveType type);
bool IsComplexType(PrimitiveType type);
@ -225,6 +230,11 @@ struct PrimitiveTypeToNative<C64> {
using type = complex64;
};
// Specialization mapping the C128 primitive type back to its native type,
// complex128.
template <>
struct PrimitiveTypeToNative<C128> {
using type = complex128;
};
// Returns the lower-case name of the given primitive type.
const string& LowercasePrimitiveTypeName(PrimitiveType s);

View File

@ -54,6 +54,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
return NPY_FLOAT64;
case C64:
return NPY_COMPLEX64;
case C128:
return NPY_COMPLEX128;
case TUPLE:
return NPY_OBJECT;
default:
@ -89,6 +91,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
return F64;
case NPY_COMPLEX64:
return C64;
case NPY_COMPLEX128:
return C128;
case NPY_OBJECT:
return TUPLE;
default:
@ -111,6 +115,7 @@ bool NumpyTypeIsValid(int np_type) {
case NPY_FLOAT32:
case NPY_FLOAT64:
case NPY_COMPLEX64:
case NPY_COMPLEX128:
case NPY_OBJECT:
return true;
default:
@ -430,6 +435,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
case NPY_COMPLEX64:
CopyNumpyArrayToLiteral<complex64>(py_array, literal);
break;
case NPY_COMPLEX128:
CopyNumpyArrayToLiteral<complex128>(py_array, literal);
break;
default:
return InvalidArgument(
"No XLA literal container for Numpy type number: %d", np_type);
@ -470,6 +478,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
case NPY_COMPLEX64:
CopyLiteralToNumpyArray<complex64>(literal, py_array);
break;
case NPY_COMPLEX128:
CopyLiteralToNumpyArray<complex128>(literal, py_array);
break;
default:
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
}

View File

@ -199,6 +199,7 @@ XLA_ELEMENT_TYPE_TO_DTYPE = {
xla_data_pb2.F32: np.dtype('float32'),
xla_data_pb2.F64: np.dtype('float64'),
xla_data_pb2.C64: np.dtype('complex64'),
xla_data_pb2.C128: np.dtype('complex128'),
xla_data_pb2.TUPLE: np.dtype(np.object),
}

View File

@ -223,6 +223,7 @@ cc_library(
"hlo_evaluator_typed_visitor.h",
"hlo_evaluator_typed_visitor_bfloat16.cc",
"hlo_evaluator_typed_visitor_bool.cc",
"hlo_evaluator_typed_visitor_complex128.cc",
"hlo_evaluator_typed_visitor_complex64.cc",
"hlo_evaluator_typed_visitor_double.cc",
"hlo_evaluator_typed_visitor_float.cc",
@ -259,6 +260,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",

View File

@ -877,6 +877,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
case C64:
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
break;
case C128:
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
break;
default:
return Status::OK();
}

View File

@ -856,7 +856,8 @@ Status EmitNonBatchDotOperation(
const HloModuleConfig& hlo_module_config,
const TargetMachineFeatures& target_machine_features) {
PrimitiveType type = target_array.GetShape().element_type();
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type ||
C128 == type);
DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
target_array, lhs_array, rhs_array, addend_array,
executable_run_options_value, b, hlo_module_config,

View File

@ -239,10 +239,12 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
DCHECK_GE(byte_size, 0);
// Largest scalar is a complex64 so we don't need to worry about the
// Largest scalar is a complex128 so we don't need to worry about the
// int64->int truncation here.
DCHECK_LE(byte_size, 8);
return byte_size;
DCHECK_LE(byte_size, 16);
// Allocations may be 8-byte aligned if part of a small block.
return std::min(8LL, byte_size);
}
int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
@ -942,7 +944,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
auto rhs = dot->operand(1);
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
/*supported_types=*/{F16, F32, F64, C64}));
/*supported_types=*/{F16, F32, F64, C64, C128}));
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
if (dnums.lhs_contracting_dimensions_size() != 1) {
@ -1114,7 +1116,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
auto rhs = convolution->operand(1);
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
/*supported_types=*/{F16, F32, C64}));
/*supported_types=*/{F16, F32, C64, C128}));
// TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support
// different data layouts.

View File

@ -206,6 +206,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
return &DoGemm<double>;
case C64:
return &DoGemm<std::complex<float>>;
case C128:
return &DoGemm<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -221,6 +223,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type)
return &DoGemmWithAlgorithm<double>;
case C64:
return &DoGemmWithAlgorithm<std::complex<float>>;
case C128:
return &DoGemmWithAlgorithm<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -235,6 +239,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
return &DoGemmAutotune<double>;
case C64:
return &DoGemmAutotune<std::complex<float>>;
case C128:
return &DoGemmAutotune<std::complex<double>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@ -255,6 +261,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
return se::blas::ComputationType::kF64;
case C64:
return se::blas::ComputationType::kComplexF32;
case C128:
return se::blas::ComputationType::kComplexF64;
default:
LOG(FATAL) << "Unsupported type.";
}

View File

@ -54,7 +54,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
PrimitiveType output_primitive_type = output_shape.element_type();
bool type_is_allowed =
(output_primitive_type == F16 || output_primitive_type == F32 ||
output_primitive_type == F64 || output_primitive_type == C64);
output_primitive_type == F64 || output_primitive_type == C64 ||
output_primitive_type == C128);
return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
IsRank2(rhs_shape, batch_dimensions_size) &&
IsRank2(output_shape, batch_dimensions_size) &&

View File

@ -136,6 +136,37 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
return std::move(result);
}
// Specialization of Compare for complex128 operands. Only equality (kEq) and
// inequality (kNe) are evaluated; any other opcode aborts via LOG(FATAL).
// Returns a literal of the given `shape` populated with the bool outcome of
// the comparison at each index.
template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape, HloOpcode opcode,
                                      LiteralSlice lhs_literal,
                                      LiteralSlice rhs_literal) {
  // Both supported opcodes reduce to one equality test whose outcome is either
  // kept (kEq) or negated (kNe).
  bool expect_equal = false;
  switch (opcode) {
    case HloOpcode::kEq:
      expect_equal = true;
      break;
    case HloOpcode::kNe:
      expect_equal = false;
      break;
    default:
      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                 << HloOpcodeString(opcode);
  }

  Literal result(shape);
  TF_RETURN_IF_ERROR(
      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
        const complex128 lhs_el = lhs_literal.Get<complex128>(multi_index);
        const complex128 rhs_el = rhs_literal.Get<complex128>(multi_index);
        return (lhs_el == rhs_el) == expect_equal;
      }));
  return std::move(result);
}
} // namespace
// Note that unsupported types by the typed visitor does not necessarily imply
@ -170,6 +201,8 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
typed_visitors_[C64] =
absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
typed_visitors_[C128] =
absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
// Most of the evaluator computations we use don't support BF16 (e.g.,
// std::ceil, std::tanh). To make evaluator work with BF16, we set all
@ -500,6 +533,13 @@ Status HloEvaluator::HandleReal(HloInstruction* real) {
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
break;
}
case C128: {
  // The real part of a complex128 (std::complex<double>) is a double, so the
  // result must be populated as double. Using float here would narrow the
  // value and disagree with the C128 case of HandleImag, which yields double.
  auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
      real, [](complex128 elem_operand) { return std::real(elem_operand); },
      GetEvaluatedLiteralFor(operand));
  TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
  break;
}
case F16: {
auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
real, [](Eigen::half elem_operand) { return elem_operand; },
@ -530,11 +570,29 @@ Status HloEvaluator::HandleReal(HloInstruction* real) {
}
Status HloEvaluator::HandleImag(HloInstruction* imag) {
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
GetEvaluatedLiteralFor(imag->operand(0)));
auto operand = imag->operand(0);
switch (operand->shape().element_type()) {
case C64: {
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
GetEvaluatedLiteralFor(imag->operand(0)));
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
break;
}
case C128: {
auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
GetEvaluatedLiteralFor(imag->operand(0)));
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
break;
}
default:
LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
<< PrimitiveType_Name(operand->shape().element_type());
}
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
return Status::OK();
}
@ -544,11 +602,27 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) {
TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
Literal result(complex->shape());
TF_RETURN_IF_ERROR(
result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
return std::complex<float>(real.Get<float>(multi_index),
imag.Get<float>(multi_index));
}));
switch (complex->shape().element_type()) {
case C64: {
TF_RETURN_IF_ERROR(
result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
return std::complex<float>(real.Get<float>(multi_index),
imag.Get<float>(multi_index));
}));
break;
}
case C128: {
  // Combine the two double operands at full precision. Constructing a
  // std::complex<float> here would round both components to float before
  // they are widened back into the complex128 result.
  TF_RETURN_IF_ERROR(
      result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
        return std::complex<double>(real.Get<double>(multi_index),
                                    imag.Get<double>(multi_index));
      }));
  break;
}
default:
LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
<< PrimitiveType_Name(complex->shape().element_type());
}
evaluated_[complex] = std::move(result);
return Status::OK();
@ -647,6 +721,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
Compare<complex64>(compare->shape(), opcode,
lhs_literal, rhs_literal));
} break;
case C128: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<complex128>(compare->shape(), opcode,
lhs_literal, rhs_literal));
} break;
default:
LOG(FATAL) << "HandleCompare: unknown primitive type: "
<< PrimitiveType_Name(lhs->shape().element_type());

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/meta/type_traits.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -39,9 +40,8 @@ namespace xla {
// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is
// a "private" header that's not exposed outside of hlo_evaluator.cc.
template <typename T>
using is_complex_t = std::is_same<T, complex64>;
template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
using is_complex_t =
absl::disjunction<std::is_same<T, complex64>, std::is_same<T, complex128>>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
// "safe" less functions which are actually strict weak orders. -NaN and NaN
@ -212,7 +212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
template <
typename NativeT,
typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
Status HandleAbs(HloInstruction* abs) {
const Literal& operand_literal =
parent_->GetEvaluatedLiteralFor(abs->operand(0));
@ -231,6 +231,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// specifying the ElementwiseT explicitly as C64 is needed below.
if (abs->operand(0)->shape().element_type() == C64) {
return HandleAbs<complex64>(abs);
} else if (abs->operand(0)->shape().element_type() == C128) {
return HandleAbs<complex128>(abs);
}
return HandleAbs<ElementwiseT>(abs);
}
@ -1616,6 +1618,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
break;
}
case C128: {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex128>(map));
break;
}
default:
LOG(FATAL) << "HandleMap: unhandled primitive type for "
"input operand: "
@ -3040,6 +3046,7 @@ extern template class HloEvaluatorTypedVisitor<Eigen::half, float>;
extern template class HloEvaluatorTypedVisitor<float>;
extern template class HloEvaluatorTypedVisitor<double>;
extern template class HloEvaluatorTypedVisitor<complex64>;
extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
} // namespace xla

View File

@ -0,0 +1,22 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
namespace xla {
// Explicit instantiation definition of the typed evaluator visitor for
// complex128 elements. It is the counterpart of an `extern template`
// declaration of the same specialization, so the template is instantiated
// in this translation unit only.
template class HloEvaluatorTypedVisitor<complex128>;
} // namespace xla

View File

@ -551,6 +551,17 @@ ENTRY %Transpose.v2 () -> s32[1,2,3] {
ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
}
)"
},
{
"TransposeC128",
R"(HloModule TransposeC128_module
ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
%input = c128[1,2,3]{2,1,0} parameter(0)
ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
}
)"
},
// Dynamic slice

View File

@ -188,7 +188,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
}
return cplx_t;
}
// C128 maps to a named LLVM struct "complex128" made of two doubles. A type
// of that name already present in the module is reused; otherwise it is
// created as a packed struct.
case C128: {
auto cplx_t = module->getTypeByName("complex128");
if (cplx_t == nullptr) {
return llvm::StructType::create(
{llvm::Type::getDoubleTy(module->getContext()),
llvm::Type::getDoubleTy(module->getContext())},
"complex128", /*isPacked=*/true);
}
return cplx_t;
}
// A Tuple contains an array of pointers. Use i8*.
case TUPLE:
// An Opaque is like a void*, use i8*.
case OPAQUE:

View File

@ -906,6 +906,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
broadcast_dimensions));
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
} else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
return ShapeUtil::ChangeElementType(shape, C128);
} else {
return Unimplemented("Complex component type is not implemented.");
}
@ -1733,7 +1735,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case FFT:
case IFFT:
if (in.element_type() != C64) {
return InvalidArgument("%s requires C64 input type, found %s.",
return InvalidArgument("%s requires complex input type, found %s.",
FftType_Name(fft_type),
PrimitiveType_Name(in.element_type()));
}

View File

@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test {
protected:
// Some handy scalar shapes.
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
const Shape f16_ = ShapeUtil::MakeShape(F16, {});
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
const Shape f64_ = ShapeUtil::MakeShape(F64, {});
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
@ -260,8 +261,8 @@ TEST_F(ShapeInferenceTest, Complex) {
ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
// Component types must match.
ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
// Only F32->C64 supported.
ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
// Only F32->C64 and F64->C128 supported.
ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
// Validate correct uses.
Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
@ -285,6 +286,9 @@ TEST_F(ShapeInferenceTest, Complex) {
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {

View File

@ -378,6 +378,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
case U32:
case U64:
case C64:
case C128:
case TUPLE:
case OPAQUE:
case TOKEN:
@ -639,6 +640,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return sizeof(double);
case C64:
return sizeof(complex64);
case C128:
return sizeof(complex128);
case TOKEN:
// Tokens require no space.
return 0;

View File

@ -431,7 +431,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
std::is_same<NativeT, complex64>::value ||
std::is_same<NativeT, complex128>::value,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
@ -455,7 +456,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
std::is_same<NativeT, complex64>::value ||
std::is_same<NativeT, complex128>::value,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
@ -480,7 +482,8 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
std::is_same<NativeT, complex64>::value ||
std::is_same<NativeT, complex128>::value,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal =
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
@ -506,7 +509,8 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
std::is_same<NativeT, complex64>::value ||
std::is_same<NativeT, complex128>::value,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal =
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
@ -532,7 +536,8 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, double>::value ||
std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, half>::value ||
std::is_same<NativeT, complex64>::value,
std::is_same<NativeT, complex64>::value ||
std::is_same<NativeT, complex128>::value,
"Float or complex type required when specifying an ErrorSpec");
Literal expected_literal =
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);

View File

@ -41,6 +41,7 @@ using ::tensorflow::uint32;
using ::tensorflow::uint64;
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
using ::Eigen::half;

View File

@ -56,6 +56,7 @@ enum PrimitiveType {
// Complex values of fixed width.
C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
C128 = 18; // Paired F64 (real, imag), as in std::complex<double>.
// A tuple is a polymorphic sequence; e.g. a shape that holds different
// sub-shapes. They are used for things like returning multiple values from a
@ -75,7 +76,7 @@ enum PrimitiveType {
// primitive type will have empty dimensions and tuple_shapes fields.
TOKEN = 17;
// Next = 18
// Next = 19
}
// Describes the padding configuration for Pad operation. The padding amount on
@ -367,6 +368,7 @@ message LiteralProto {
repeated float f32s = 8;
repeated double f64s = 9;
repeated float c64s = 12; // Stored as interleaved real, imag floats.
repeated double c128s = 18; // Stored as interleaved real, imag doubles.
repeated LiteralProto tuple_literals = 10;
// The F16s, BF16s, U16s and S16s are encoded in little endian byte order
bytes f16s = 11;
@ -374,7 +376,7 @@ message LiteralProto {
bytes u16s = 16;
bytes s16s = 17;
repeated int64 sparse_indices = 14;
// Next = 18
// Next = 19
}
message WindowDimension {