[XLA] Add complex128 support.

Minimally tested at the moment (tested via the TF tests).

PiperOrigin-RevId: 228931680
parent 2cecb70a63
commit 937ff1b4cf
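
The diff touches every layer that has to agree on the new element type: the TF-to-XLA bridge type lists, XLA literals and their proto encoding, the HLO evaluator, and the CPU/GPU emitters. As a rough orientation, the sketch below shows the kind of usage the commit enables at the literal level. It is a minimal illustration, not part of the commit, assuming the post-commit APIs that appear in the hunks below (LiteralUtil::CreateR1<complex128>, Literal::Convert(C128)); the function name Complex128LiteralDemo is invented for this example.

// Minimal sketch (not from the commit): exercising the new complex128
// literal support. Assumes the post-commit XLA headers; the function name
// Complex128LiteralDemo is invented for illustration.
#include <complex>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void Complex128LiteralDemo() {
  // An R1 literal holding two complex128 (std::complex<double>) values.
  Literal c = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});

  // Element access uses the same templated accessors as every other type.
  CHECK_EQ(c.Get<complex128>({0}), complex128(1.0, 2.0));

  // Conversions into C128 now go through ConvertIfDestTypeMatches, mirroring
  // what the literal tests below exercise for S32 and F16 sources.
  Literal s32 = LiteralUtil::CreateR1<int32>({1, 3});
  Literal conv = s32.Convert(C128).ConsumeValueOrDie();
  CHECK_EQ(conv.Get<complex128>({1}), complex128(3.0, 0.0));
}

}  // namespace xla
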
@@ -151,7 +151,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
   EXPECT_EQ(clusters["A"], clusters["C"]);
 }
 
-TEST(XlaCompilationTest, Complex128Unsupported) {
+TEST(XlaCompilationTest, StringUnsupported) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   GraphDef graphdef;
   {
@@ -159,10 +159,10 @@ TEST(XlaCompilationTest, Complex128Unsupported) {
     Node* a = ops::SourceOp(
         "Const", builder.opts()
                      .WithName("A")
-                     .WithAttr("dtype", DT_COMPLEX128)
-                     .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape())));
-    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
-    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+                     .WithAttr("dtype", DT_STRING)
+                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
+    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
+    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
   }
@ -83,9 +83,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
|
||||
|
||||
// Kernel registrations
|
||||
|
||||
constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
|
||||
constexpr std::array<DataType, 13> kAllXlaCpuTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
|
||||
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
|
||||
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}};
|
||||
|
||||
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
|
||||
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
|
||||
|
@ -26,9 +26,9 @@ namespace tensorflow {
|
||||
const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
|
||||
const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";
|
||||
|
||||
constexpr std::array<DataType, 9> kExecAllTypes = {
|
||||
constexpr std::array<DataType, 10> kExecAllTypes = {
|
||||
{DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_BOOL, DT_BFLOAT16}};
|
||||
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
|
||||
|
||||
class XlaInterpreterDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
|
@ -400,7 +400,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
|
||||
def testComplexOps(self):
|
||||
for dtype in self.complex_types:
|
||||
ctypes = {np.complex64: np.float32}
|
||||
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
|
||||
self._testBinary(
|
||||
math_ops.complex,
|
||||
np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),
|
||||
|
@ -64,7 +64,7 @@ def tf_xla_py_test(
|
||||
if backend == "cpu":
|
||||
backend_args += [
|
||||
"--test_device=XLA_CPU",
|
||||
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
|
||||
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128",
|
||||
]
|
||||
elif backend == "gpu":
|
||||
backend_args += [
|
||||
|
@ -647,7 +647,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
|
||||
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
|
||||
|
||||
ctypes = {np.complex64: np.float32}
|
||||
ctypes = {np.complex64: np.float32, np.complex128: np.float64}
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.abs,
|
||||
np.array([[3 - 4j, -1j, np.inf]], dtype=dtype),
|
||||
|
@ -79,8 +79,8 @@ class BitcastOp : public XlaOpKernel {
|
||||
if (src_dtype_ == dst_dtype_) {
|
||||
output = input;
|
||||
} else {
|
||||
// The only complex type in XLA is C64, so error out if the bitcast has a
|
||||
// complex source or destination type and the bitcast is not trivial.
|
||||
// Error out if the bitcast has a complex source or destination type and
|
||||
// the bitcast is not trivial.
|
||||
OP_REQUIRES(ctx,
|
||||
!xla::primitive_util::IsComplexType(src_type_) &&
|
||||
!xla::primitive_util::IsComplexType(dst_type_),
|
||||
|
@ -83,6 +83,17 @@ class ConstOp : public XlaOpKernel {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
if (proto_.scomplex_val_size() == 2) {
|
||||
ctx->SetOutput(
|
||||
0,
|
||||
xla::Broadcast(xla::ConstantR0<xla::complex128>(
|
||||
b, xla::complex128(proto_.dcomplex_val(0),
|
||||
proto_.dcomplex_val(1))),
|
||||
shape.dim_sizes()));
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case DT_INT32:
|
||||
if (proto_.int_val_size() == 1) {
|
||||
ctx->SetOutput(
|
||||
|
@ -24,8 +24,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr std::array<DataType, 5> kMatmulTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
|
||||
constexpr std::array<DataType, 6> kMatmulTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}};
|
||||
|
||||
class MatMulOp : public XlaOpKernel {
|
||||
public:
|
||||
|
@ -54,6 +54,9 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
case xla::C64:
|
||||
return xla::ConstantR0<xla::complex64>(builder, value);
|
||||
break;
|
||||
case xla::C128:
|
||||
return xla::ConstantR0<xla::complex128>(builder, value);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled element type " << type;
|
||||
}
|
||||
@ -90,6 +93,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
|
||||
case xla::C64:
|
||||
literal = xla::LiteralUtil::CreateR0<complex64>(value);
|
||||
break;
|
||||
case xla::C128:
|
||||
literal = xla::LiteralUtil::CreateR0<complex128>(value);
|
||||
break;
|
||||
case xla::PRED:
|
||||
LOG(FATAL) << "pred element type is not integral";
|
||||
case xla::S16:
|
||||
|
@ -69,6 +69,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
|
||||
case tensorflow::DT_COMPLEX64:
|
||||
*type = xla::C64;
|
||||
return Status::OK();
|
||||
case tensorflow::DT_COMPLEX128:
|
||||
*type = xla::C128;
|
||||
return Status::OK();
|
||||
default:
|
||||
return errors::InvalidArgument(
|
||||
"Unsupported type in DataTypeToPrimitiveType ",
|
||||
|
@ -47,13 +47,14 @@ extern const char* const DEVICE_XLA_GPU;
|
||||
|
||||
constexpr std::array<DataType, 4> kFloatTypes = {
|
||||
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
|
||||
constexpr std::array<DataType, 11> kNumericTypes = {
|
||||
constexpr std::array<DataType, 12> kNumericTypes = {
|
||||
{DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
|
||||
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
|
||||
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};
|
||||
|
||||
constexpr std::array<DataType, 14> kCpuAllTypes = {
|
||||
constexpr std::array<DataType, 15> kCpuAllTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
|
||||
DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
|
||||
DT_COMPLEX128, DT_BOOL}};
|
||||
|
||||
constexpr std::array<DataType, 15> kGpuAllTypes = {
|
||||
{DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
|
||||
|
@ -56,6 +56,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
|
||||
return ConstantR0<double>(builder, static_cast<double>(value));
|
||||
case C64:
|
||||
return ConstantR0<complex64>(builder, static_cast<complex64>(value));
|
||||
case C128:
|
||||
return ConstantR0<complex128>(builder, static_cast<complex128>(value));
|
||||
case U8:
|
||||
return ConstantR0<uint8>(builder, static_cast<uint8>(value));
|
||||
case U32:
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
@ -323,7 +324,8 @@ XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
|
||||
auto perform_conj = shape.element_type() == C64 && conjugate;
|
||||
auto perform_conj =
|
||||
primitive_util::IsComplexType(shape.element_type()) && conjugate;
|
||||
return perform_conj ? Conj(x) : x;
|
||||
});
|
||||
}
|
||||
|
@ -29,10 +29,12 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -411,6 +413,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
|
||||
COPY_ELEMENTS(F32, float);
|
||||
COPY_ELEMENTS(F64, double);
|
||||
COPY_ELEMENTS(C64, complex64);
|
||||
COPY_ELEMENTS(C128, complex128);
|
||||
COPY_ELEMENTS(PRED, bool);
|
||||
#undef COPY_ELEMENTS
|
||||
default:
|
||||
@ -548,6 +551,9 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal,
|
||||
case C64:
|
||||
return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
|
||||
copy_size);
|
||||
case C128:
|
||||
return CopySliceFromInternal<complex128>(src_literal, src_base, dest_base,
|
||||
copy_size);
|
||||
case PRED:
|
||||
return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
|
||||
copy_size);
|
||||
@ -766,6 +772,8 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
|
||||
return SliceInternal<double>(result_shape, start_indices);
|
||||
case C64:
|
||||
return SliceInternal<complex64>(result_shape, start_indices);
|
||||
case C128:
|
||||
return SliceInternal<complex128>(result_shape, start_indices);
|
||||
default:
|
||||
LOG(FATAL) << "not yet implemented: "
|
||||
<< PrimitiveType_Name(result_shape.element_type());
|
||||
@ -814,6 +822,10 @@ string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
|
||||
complex64 c = Get<complex64>(multi_index, shape_index);
|
||||
return StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
case C128: {
|
||||
complex128 c = Get<complex128>(multi_index, shape_index);
|
||||
return StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
|
||||
}
|
||||
@ -868,6 +880,11 @@ string LiteralBase::GetSparseElementAsString(
|
||||
GetSparseElement<complex64>(sparse_element_number, shape_index);
|
||||
return StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
case C128: {
|
||||
complex128 c =
|
||||
GetSparseElement<complex128>(sparse_element_number, shape_index);
|
||||
return StrCat("(", c.real(), ", ", c.imag(), ")");
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Invalid element type for sparse arrays: "
|
||||
<< PrimitiveType_Name(subshape.element_type());
|
||||
@ -996,6 +1013,9 @@ void LiteralBase::Piece::SortSparseElements() {
|
||||
case C64:
|
||||
SortSparseElementsInternal<complex64>();
|
||||
break;
|
||||
case C128:
|
||||
SortSparseElementsInternal<complex128>();
|
||||
break;
|
||||
case F16:
|
||||
SortSparseElementsInternal<half>();
|
||||
break;
|
||||
@ -1230,7 +1250,24 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
|
||||
}
|
||||
|
||||
template <typename NativeSrcT, typename NativeDestT>
|
||||
Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
|
||||
typename std::enable_if<(std::is_same<NativeSrcT, Eigen::half>::value) &&
|
||||
(std::is_same<NativeDestT, complex64>::value ||
|
||||
std::is_same<NativeDestT, complex128>::value),
|
||||
Literal>::type
|
||||
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
|
||||
auto converter = [](NativeSrcT src) {
|
||||
return NativeDestT(static_cast<typename NativeDestT::value_type>(src));
|
||||
};
|
||||
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
|
||||
src_literal, converter);
|
||||
}
|
||||
|
||||
template <typename NativeSrcT, typename NativeDestT>
|
||||
typename std::enable_if<(!std::is_same<NativeSrcT, Eigen::half>::value) ||
|
||||
(!std::is_same<NativeDestT, complex64>::value &&
|
||||
!std::is_same<NativeDestT, complex128>::value),
|
||||
Literal>::type
|
||||
ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
|
||||
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
|
||||
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
|
||||
src_literal, converter);
|
||||
@ -1274,22 +1311,6 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
|
||||
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
|
||||
}
|
||||
|
||||
template <PrimitiveType primitive_src_type>
|
||||
Literal ConvertToC64(const LiteralBase& src_literal) {
|
||||
CHECK(src_literal.shape().IsArray());
|
||||
Literal result_literal(
|
||||
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
|
||||
using NativeSrcT =
|
||||
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
|
||||
absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
|
||||
absl::Span<complex64> dest_data = result_literal.data<complex64>();
|
||||
int64 num_elements = src_literal.element_count();
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
|
||||
}
|
||||
return result_literal;
|
||||
}
|
||||
|
||||
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
|
||||
Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
|
||||
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
|
||||
@ -1332,10 +1353,15 @@ StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
|
||||
CONVERT_IF_TYPES_MATCH(BF16)
|
||||
#undef CONVERT_IF_TYPES_MATCH
|
||||
case C64:
|
||||
if (!bitcast) {
|
||||
return ConvertToC64<primitive_src_type>(src_literal);
|
||||
if (bitcast) {
|
||||
break;
|
||||
}
|
||||
break;
|
||||
return ConvertIfTypesMatch<primitive_src_type, C64>(src_literal, false);
|
||||
case C128:
|
||||
if (bitcast) {
|
||||
break;
|
||||
}
|
||||
return ConvertIfTypesMatch<primitive_src_type, C128>(src_literal, false);
|
||||
// Other types are not yet supported.
|
||||
default:
|
||||
break;
|
||||
@ -1485,6 +1511,8 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
|
||||
return EqualElementsInternal<bfloat16>(other, &multi_index);
|
||||
case C64:
|
||||
return EqualElementsInternal<complex64>(other, &multi_index);
|
||||
case C128:
|
||||
return EqualElementsInternal<complex128>(other, &multi_index);
|
||||
default:
|
||||
LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
|
||||
<< PrimitiveType_Name(subshape().element_type());
|
||||
@ -1628,6 +1656,9 @@ bool LiteralBase::IsAllComplex(complex64 value) const {
|
||||
case C64:
|
||||
return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
|
||||
value);
|
||||
case C128:
|
||||
return AllElementsEqualValue<complex128>(root_piece().data<complex128>(),
|
||||
value);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1707,6 +1738,11 @@ bool LiteralBase::IsAllFirst() const {
|
||||
auto data = piece.data<uint64>();
|
||||
return AllElementsEqualValue<uint64>(data, data[0]);
|
||||
}
|
||||
|
||||
case C128: {
|
||||
auto data = piece.data<complex128>();
|
||||
return AllElementsEqualValue<complex128>(data, data[0]);
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1756,6 +1792,8 @@ bool LiteralBase::IsR1Iota() const {
|
||||
return Get<bfloat16>({idx}) == static_cast<bfloat16>(idx);
|
||||
case C64:
|
||||
return Get<complex64>({idx}) == complex64(idx, 0.0f);
|
||||
case C128:
|
||||
return Get<complex128>({idx}) == complex128(idx, 0.0f);
|
||||
case PRED:
|
||||
return Get<bool>({idx}) == idx;
|
||||
// token, opaque, tuple, etc. are all not iota.
|
||||
@ -1799,6 +1837,8 @@ bool LiteralBase::IsZero(absl::Span<const int64> indices) const {
|
||||
return Get<double>(indices) == 0.0;
|
||||
case C64:
|
||||
return Get<complex64>(indices) == complex64(0.0f, 0.0f);
|
||||
case C128:
|
||||
return Get<complex128>(indices) == complex128(0.0f, 0.0f);
|
||||
case F16:
|
||||
return Get<half>(indices) == static_cast<half>(0.0f);
|
||||
case BF16:
|
||||
@ -1886,6 +1926,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
|
||||
proto->add_c64s(value.imag());
|
||||
}
|
||||
break;
|
||||
case C128:
|
||||
for (complex128 value : data<complex128>()) {
|
||||
proto->add_c128s(value.real());
|
||||
proto->add_c128s(value.imag());
|
||||
}
|
||||
break;
|
||||
case TUPLE:
|
||||
case TOKEN:
|
||||
// Nothing to do but assign the shape which is done above.
|
||||
@ -2018,7 +2064,17 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
|
||||
for (int64 i = 0; i < complex_data.size(); ++i) {
|
||||
complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
|
||||
}
|
||||
} break;
|
||||
break;
|
||||
}
|
||||
case C128: {
|
||||
auto complex_data = data<complex128>();
|
||||
TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2);
|
||||
for (int64 i = 0; i < complex_data.size(); ++i) {
|
||||
complex_data[i] =
|
||||
complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)};
|
||||
}
|
||||
break;
|
||||
}
|
||||
case TUPLE:
|
||||
return InvalidArgument("Should not be called on tuple shapes: %s",
|
||||
ShapeUtil::HumanString(subshape()));
|
||||
|
@ -90,6 +90,12 @@ bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
|
||||
return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
|
||||
CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
|
||||
}
|
||||
template <>
|
||||
bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
|
||||
absl::Span<const int64> multi_index) {
|
||||
return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
|
||||
CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
|
||||
}
|
||||
|
||||
template <typename NativeT, typename UnsignedT>
|
||||
Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
|
||||
@ -143,6 +149,14 @@ Status MakeErrorStatus(complex64 lhs, complex64 rhs,
|
||||
}
|
||||
return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
|
||||
}
|
||||
template <>
|
||||
Status MakeErrorStatus(complex128 lhs, complex128 rhs,
|
||||
absl::Span<const int64> multi_index) {
|
||||
if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
|
||||
return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
|
||||
}
|
||||
return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
|
||||
}
|
||||
|
||||
// A recursive function which iterates through every index of expected and
|
||||
// actual literal and compares their values elementwise. Returns true if all
|
||||
@ -197,13 +211,6 @@ bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) {
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
bool NanMismatch<complex64>(complex64 expected, complex64 actual,
|
||||
bool relaxed_nans) {
|
||||
return NanMismatch<float>(expected.real(), actual.real(), relaxed_nans) ||
|
||||
NanMismatch<float>(expected.imag(), actual.imag(), relaxed_nans);
|
||||
}
|
||||
|
||||
template <>
|
||||
bool NanMismatch<half>(half expected, half actual, bool relaxed_nans) {
|
||||
return NanMismatch<float>(static_cast<float>(expected),
|
||||
@ -232,6 +239,11 @@ string FpValueToString<complex64>(complex64 value) {
|
||||
return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
|
||||
}
|
||||
|
||||
template <>
|
||||
string FpValueToString<complex128>(complex128 value) {
|
||||
return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag());
|
||||
}
|
||||
|
||||
// Returns the absolute value of the given floating point value. This function
|
||||
// is used instead of std::abs directly in order to allow type-dependent
|
||||
// implementations for NearComparator.
|
||||
@ -434,7 +446,7 @@ class NearComparator {
|
||||
mismatches_.data<bool>()[linear_index] = true;
|
||||
}
|
||||
|
||||
// For complex64 types, we compare real and imaginary parts individually.
|
||||
// For complex types, we compare real and imaginary parts individually.
|
||||
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) {
|
||||
bool mismatch = false;
|
||||
CompareValues<float>(expected.real(), actual.real(), linear_index);
|
||||
@ -457,6 +469,29 @@ class NearComparator {
|
||||
mismatches_.data<bool>()[linear_index] = mismatch;
|
||||
}
|
||||
|
||||
void CompareValues(complex128 expected, complex128 actual,
|
||||
int64 linear_index) {
|
||||
bool mismatch = false;
|
||||
CompareValues<double>(expected.real(), actual.real(), linear_index);
|
||||
if (mismatches_.data<bool>()[linear_index] == true) {
|
||||
mismatch = true;
|
||||
// Delay the mismatch count increase for real part, instead increase
|
||||
// mismatch by 1 for the entire complex number.
|
||||
num_mismatches_--;
|
||||
}
|
||||
CompareValues<double>(expected.imag(), actual.imag(), linear_index);
|
||||
if (mismatches_.data<bool>()[linear_index] == true) {
|
||||
mismatch = true;
|
||||
// Delay the mismatch count increase for imag part, instead increase
|
||||
// mismatch by 1 for the entire complex number.
|
||||
num_mismatches_--;
|
||||
}
|
||||
if (mismatch == true) {
|
||||
num_mismatches_++;
|
||||
}
|
||||
mismatches_.data<bool>()[linear_index] = mismatch;
|
||||
}
|
||||
|
||||
// Compares the two literals elementwise.
|
||||
void CompareLiterals() {
|
||||
// Fast path optimization for the case were layouts match.
|
||||
@ -665,6 +700,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) {
|
||||
case C64:
|
||||
result = Equal<complex64>(expected, actual, index, 0);
|
||||
break;
|
||||
case C128:
|
||||
result = Equal<complex128>(expected, actual, index, 0);
|
||||
break;
|
||||
case TUPLE: {
|
||||
for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
|
||||
result.Update(EqualHelper(LiteralSlice(expected, {i}),
|
||||
@ -749,6 +787,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
|
||||
return NearComparator<complex64>::Compare(
|
||||
expected, actual, error, detailed_message, miscompare_callback);
|
||||
break;
|
||||
case C128:
|
||||
return NearComparator<complex128>::Compare(
|
||||
expected, actual, error, detailed_message, miscompare_callback);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported primitive type in near comparator: "
|
||||
<< PrimitiveType_Name(expected.shape().element_type())
|
||||
|
@ -118,6 +118,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
|
||||
auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
|
||||
EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());
|
||||
|
||||
auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14f, 2.78f});
|
||||
EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());
|
||||
|
||||
auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
|
||||
EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());
|
||||
|
||||
@ -469,6 +472,21 @@ TEST_F(LiteralUtilTest, C64Equality) {
|
||||
EXPECT_NE(vector, vector_reversed);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, C128Equality) {
|
||||
// Test equality with tuples.
|
||||
auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
|
||||
// Tuple with the same elements. One element is shared with the original
|
||||
// tuple, the other is a clone of the element in the original tuple.
|
||||
auto vector_clone =
|
||||
LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
EXPECT_EQ(vector, vector_clone);
|
||||
|
||||
auto vector_reversed =
|
||||
LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
|
||||
EXPECT_NE(vector, vector_reversed);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, IsAllTuple) {
|
||||
auto element1 = LiteralUtil::CreateR0<float>(0.0);
|
||||
auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
|
||||
@ -836,6 +854,13 @@ TEST_F(LiteralUtilTest, PopulateR1C64) {
|
||||
EXPECT_EQ(output, expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateR1C128) {
|
||||
Literal output(ShapeUtil::MakeShape(C128, {1}));
|
||||
output.PopulateR1<complex128>({{77, 88}});
|
||||
auto expected = LiteralUtil::CreateR1<complex128>({{77, 88}});
|
||||
EXPECT_EQ(output, expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateR2C64) {
|
||||
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
|
||||
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
|
||||
@ -897,6 +922,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
|
||||
EXPECT_EQ(output, expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
|
||||
Literal output(ShapeUtil::MakeShape(C128, {2, 2}));
|
||||
output.PopulateWithValue<complex128>({4, 2});
|
||||
auto expected =
|
||||
LiteralUtil::CreateR2<complex128>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
|
||||
EXPECT_EQ(output, expected);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
|
||||
Literal output(ShapeUtil::MakeShape(F16, {}));
|
||||
half h(0.25f);
|
||||
@ -1308,7 +1341,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
|
||||
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
|
||||
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
|
||||
}}, layout_r4_dim0major_);
|
||||
// clang-format on
|
||||
auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
|
||||
{{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
|
||||
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
|
||||
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
|
||||
}}, layout_r4_dim0major_); // clang-format on
|
||||
Literal conv;
|
||||
|
||||
conv = s8.Convert(U16).ConsumeValueOrDie();
|
||||
@ -1374,10 +1411,20 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
|
||||
conv = s32.Convert(U16).ConsumeValueOrDie();
|
||||
EXPECT_EQ(conv, u16);
|
||||
|
||||
conv = s32.Convert(C128).ConsumeValueOrDie();
|
||||
EXPECT_EQ(conv, c128);
|
||||
|
||||
conv = f16.Convert(C128).ConsumeValueOrDie();
|
||||
EXPECT_EQ(conv, c128);
|
||||
|
||||
EXPECT_EQ(s32.Convert(TUPLE).status().code(),
|
||||
tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_EQ(c128.Convert(F32).status().code(),
|
||||
tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_EQ(c128.Convert(S32).status().code(),
|
||||
tensorflow::error::UNIMPLEMENTED);
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, BitcastConvert) {
|
||||
@ -1739,7 +1786,8 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
|
||||
|
||||
Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
|
||||
ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
|
||||
ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
|
||||
ShapeUtil::MakeShape(C128, {})}));
|
||||
|
||||
EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
|
||||
EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
|
||||
@ -1747,6 +1795,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
|
||||
EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
|
||||
EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
|
||||
EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
|
||||
EXPECT_EQ(tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
|
||||
@ -1756,6 +1805,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
|
||||
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
|
||||
auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
|
||||
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto vector_c128 =
|
||||
LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
|
||||
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
|
||||
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
|
||||
auto vector_half =
|
||||
@ -1776,6 +1827,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
|
||||
EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
|
||||
EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
|
||||
EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
|
||||
EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
|
||||
EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
|
||||
EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
|
||||
EXPECT_EQ(tuple, to_from_proto(tuple));
|
||||
|
@ -130,6 +130,8 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
return LiteralUtil::CreateR0<double>(0);
|
||||
case C64:
|
||||
return LiteralUtil::CreateR0<complex64>(0);
|
||||
case C128:
|
||||
return LiteralUtil::CreateR0<complex128>(0);
|
||||
case PRED:
|
||||
return LiteralUtil::CreateR0<bool>(false);
|
||||
case TUPLE:
|
||||
@ -165,6 +167,8 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
return LiteralUtil::CreateR0<double>(1);
|
||||
case C64:
|
||||
return LiteralUtil::CreateR0<complex64>(1);
|
||||
case C128:
|
||||
return LiteralUtil::CreateR0<complex128>(1);
|
||||
case PRED:
|
||||
return LiteralUtil::CreateR0<bool>(true);
|
||||
case S16:
|
||||
@ -201,6 +205,8 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
-std::numeric_limits<double>::infinity());
|
||||
case C64:
|
||||
LOG(FATAL) << "C64 element type has no minimum value";
|
||||
case C128:
|
||||
LOG(FATAL) << "C128 element type has no minimum value";
|
||||
case PRED:
|
||||
return LiteralUtil::CreateR0<bool>(false);
|
||||
case S16:
|
||||
@ -345,6 +351,10 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
new_literal.Set<complex64>(to_multi_index,
|
||||
literal.Get<complex64>(from_multi_index));
|
||||
break;
|
||||
case C128:
|
||||
new_literal.Set<complex128>(to_multi_index,
|
||||
literal.Get<complex128>(from_multi_index));
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled primitive element type: "
|
||||
<< PrimitiveType_Name(literal.shape().element_type());
|
||||
@ -393,6 +403,10 @@ Literal ConvertType(LiteralSlice literal) {
|
||||
return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
|
||||
case U64:
|
||||
return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
|
||||
|
||||
case C128:
|
||||
return LiteralUtil::CreateR0<complex128>(
|
||||
literal.GetFirstElement<complex128>());
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled primitive type "
|
||||
<< literal.shape().element_type();
|
||||
|
@ -27,7 +27,7 @@ bool IsFloatingPointType(PrimitiveType type) {
|
||||
return type == F16 || type == F32 || type == F64 || type == BF16;
|
||||
}
|
||||
|
||||
bool IsComplexType(PrimitiveType type) { return type == C64; }
|
||||
bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; }
|
||||
|
||||
bool IsSignedIntegralType(PrimitiveType type) {
|
||||
return type == S8 || type == S16 || type == S32 || type == S64;
|
||||
@ -67,6 +67,9 @@ int BitWidth(PrimitiveType type) {
|
||||
case C64:
|
||||
return 64;
|
||||
|
||||
case C128:
|
||||
return 128;
|
||||
|
||||
case TUPLE:
|
||||
LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
|
||||
|
||||
@ -82,6 +85,8 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
|
||||
switch (complex_type) {
|
||||
case C64:
|
||||
return F32;
|
||||
case C128:
|
||||
return F64;
|
||||
default:
|
||||
LOG(FATAL) << "Primitive type is not complex: "
|
||||
<< PrimitiveType_Name(complex_type);
|
||||
|
@ -126,6 +126,11 @@ inline PrimitiveType NativeToPrimitiveType<complex64>() {
|
||||
return C64;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline PrimitiveType NativeToPrimitiveType<complex128>() {
|
||||
return C128;
|
||||
}
|
||||
|
||||
bool IsFloatingPointType(PrimitiveType type);
|
||||
|
||||
bool IsComplexType(PrimitiveType type);
|
||||
@ -225,6 +230,11 @@ struct PrimitiveTypeToNative<C64> {
|
||||
using type = complex64;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PrimitiveTypeToNative<C128> {
|
||||
using type = complex128;
|
||||
};
|
||||
|
||||
// Returns the lower-case name of the given primitive type.
|
||||
const string& LowercasePrimitiveTypeName(PrimitiveType s);
|
||||
|
||||
|
@ -54,6 +54,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
|
||||
return NPY_FLOAT64;
|
||||
case C64:
|
||||
return NPY_COMPLEX64;
|
||||
case C128:
|
||||
return NPY_COMPLEX128;
|
||||
case TUPLE:
|
||||
return NPY_OBJECT;
|
||||
default:
|
||||
@ -89,6 +91,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
|
||||
return F64;
|
||||
case NPY_COMPLEX64:
|
||||
return C64;
|
||||
case NPY_COMPLEX128:
|
||||
return C128;
|
||||
case NPY_OBJECT:
|
||||
return TUPLE;
|
||||
default:
|
||||
@ -111,6 +115,7 @@ bool NumpyTypeIsValid(int np_type) {
|
||||
case NPY_FLOAT32:
|
||||
case NPY_FLOAT64:
|
||||
case NPY_COMPLEX64:
|
||||
case NPY_COMPLEX128:
|
||||
case NPY_OBJECT:
|
||||
return true;
|
||||
default:
|
||||
@ -430,6 +435,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
||||
case NPY_COMPLEX64:
|
||||
CopyNumpyArrayToLiteral<complex64>(py_array, literal);
|
||||
break;
|
||||
case NPY_COMPLEX128:
|
||||
CopyNumpyArrayToLiteral<complex128>(py_array, literal);
|
||||
break;
|
||||
default:
|
||||
return InvalidArgument(
|
||||
"No XLA literal container for Numpy type number: %d", np_type);
|
||||
@ -470,6 +478,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
case NPY_COMPLEX64:
|
||||
CopyLiteralToNumpyArray<complex64>(literal, py_array);
|
||||
break;
|
||||
case NPY_COMPLEX128:
|
||||
CopyLiteralToNumpyArray<complex128>(literal, py_array);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
|
||||
}
|
||||
|
@ -199,6 +199,7 @@ XLA_ELEMENT_TYPE_TO_DTYPE = {
|
||||
xla_data_pb2.F32: np.dtype('float32'),
|
||||
xla_data_pb2.F64: np.dtype('float64'),
|
||||
xla_data_pb2.C64: np.dtype('complex64'),
|
||||
xla_data_pb2.C128: np.dtype('complex128'),
|
||||
xla_data_pb2.TUPLE: np.dtype(np.object),
|
||||
}
|
||||
|
||||
|
@ -223,6 +223,7 @@ cc_library(
|
||||
"hlo_evaluator_typed_visitor.h",
|
||||
"hlo_evaluator_typed_visitor_bfloat16.cc",
|
||||
"hlo_evaluator_typed_visitor_bool.cc",
|
||||
"hlo_evaluator_typed_visitor_complex128.cc",
|
||||
"hlo_evaluator_typed_visitor_complex64.cc",
|
||||
"hlo_evaluator_typed_visitor_double.cc",
|
||||
"hlo_evaluator_typed_visitor_float.cc",
|
||||
@ -259,6 +260,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/meta:type_traits",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -877,6 +877,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
|
||||
case C64:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
|
||||
break;
|
||||
case C128:
|
||||
TF_RETURN_IF_ERROR(InvertConstant<complex128>(*b, &new_literal));
|
||||
break;
|
||||
default:
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -856,7 +856,8 @@ Status EmitNonBatchDotOperation(
|
||||
const HloModuleConfig& hlo_module_config,
|
||||
const TargetMachineFeatures& target_machine_features) {
|
||||
PrimitiveType type = target_array.GetShape().element_type();
|
||||
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type);
|
||||
TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type ||
|
||||
C128 == type);
|
||||
DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
|
||||
target_array, lhs_array, rhs_array, addend_array,
|
||||
executable_run_options_value, b, hlo_module_config,
|
||||
|
@ -239,10 +239,12 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
|
||||
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
|
||||
int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
|
||||
DCHECK_GE(byte_size, 0);
|
||||
// Largest scalar is a complex64 so we don't need to worry about the
|
||||
// Largest scalar is a complex128 so we don't need to worry about the
|
||||
// int64->int truncation here.
|
||||
DCHECK_LE(byte_size, 8);
|
||||
return byte_size;
|
||||
DCHECK_LE(byte_size, 16);
|
||||
|
||||
// Allocations may be 8-byte aligned if part of a small block.
|
||||
return std::min(8LL, byte_size);
|
||||
}
|
||||
|
||||
int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
|
||||
@ -942,7 +944,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
|
||||
auto rhs = dot->operand(1);
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
|
||||
/*supported_types=*/{F16, F32, F64, C64}));
|
||||
/*supported_types=*/{F16, F32, F64, C64, C128}));
|
||||
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
|
||||
|
||||
if (dnums.lhs_contracting_dimensions_size() != 1) {
|
||||
@ -1114,7 +1116,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
|
||||
auto rhs = convolution->operand(1);
|
||||
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
|
||||
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
|
||||
/*supported_types=*/{F16, F32, C64}));
|
||||
/*supported_types=*/{F16, F32, C64, C128}));
|
||||
|
||||
// TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support
|
||||
// different data layouts.
|
||||
|
@ -206,6 +206,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
|
||||
return &DoGemm<double>;
|
||||
case C64:
|
||||
return &DoGemm<std::complex<float>>;
|
||||
case C128:
|
||||
return &DoGemm<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
@ -221,6 +223,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type)
|
||||
return &DoGemmWithAlgorithm<double>;
|
||||
case C64:
|
||||
return &DoGemmWithAlgorithm<std::complex<float>>;
|
||||
case C128:
|
||||
return &DoGemmWithAlgorithm<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
@ -235,6 +239,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
|
||||
return &DoGemmAutotune<double>;
|
||||
case C64:
|
||||
return &DoGemmAutotune<std::complex<float>>;
|
||||
case C128:
|
||||
return &DoGemmAutotune<std::complex<double>>;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
@ -255,6 +261,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
|
||||
return se::blas::ComputationType::kF64;
|
||||
case C64:
|
||||
return se::blas::ComputationType::kComplexF32;
|
||||
case C128:
|
||||
return se::blas::ComputationType::kComplexF64;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported type.";
|
||||
}
|
||||
|
@ -54,7 +54,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
|
||||
PrimitiveType output_primitive_type = output_shape.element_type();
|
||||
bool type_is_allowed =
|
||||
(output_primitive_type == F16 || output_primitive_type == F32 ||
|
||||
output_primitive_type == F64 || output_primitive_type == C64);
|
||||
output_primitive_type == F64 || output_primitive_type == C64 ||
|
||||
output_primitive_type == C128);
|
||||
return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
|
||||
IsRank2(rhs_shape, batch_dimensions_size) &&
|
||||
IsRank2(output_shape, batch_dimensions_size) &&
|
||||
|
@ -136,6 +136,37 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
StatusOr<Literal> Compare<complex128>(const Shape& shape, HloOpcode opcode,
|
||||
LiteralSlice lhs_literal,
|
||||
LiteralSlice rhs_literal) {
|
||||
std::function<bool(complex128, complex128)> compare_op;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kEq:
|
||||
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
||||
return lhs_el == rhs_el;
|
||||
};
|
||||
break;
|
||||
case HloOpcode::kNe:
|
||||
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
||||
return lhs_el != rhs_el;
|
||||
};
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
||||
<< HloOpcodeString(opcode);
|
||||
}
|
||||
|
||||
Literal result(shape);
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.Populate<bool>([&](absl::Span<const int64> multi_index) {
|
||||
return compare_op(lhs_literal.Get<complex128>(multi_index),
|
||||
rhs_literal.Get<complex128>(multi_index));
|
||||
}));
|
||||
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Note that unsupported types by the typed visitor does not necessarily imply
|
||||
@ -170,6 +201,8 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations)
|
||||
absl::make_unique<HloEvaluatorTypedVisitor<double>>(this);
|
||||
typed_visitors_[C64] =
|
||||
absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this);
|
||||
typed_visitors_[C128] =
|
||||
absl::make_unique<HloEvaluatorTypedVisitor<complex128>>(this);
|
||||
|
||||
// Most of the evaluator computations we use don't support BF16 (e.g.,
|
||||
// std::ceil, std::tanh). To make evaluator work with BF16, we set all
|
||||
@ -500,6 +533,13 @@ Status HloEvaluator::HandleReal(HloInstruction* real) {
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case C128: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<float, complex128>(
|
||||
real, [](complex128 elem_operand) { return std::real(elem_operand); },
|
||||
GetEvaluatedLiteralFor(operand));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case F16: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
|
||||
real, [](Eigen::half elem_operand) { return elem_operand; },
|
||||
@ -530,11 +570,29 @@ Status HloEvaluator::HandleReal(HloInstruction* real) {
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleImag(HloInstruction* imag) {
|
||||
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
|
||||
imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
|
||||
GetEvaluatedLiteralFor(imag->operand(0)));
|
||||
auto operand = imag->operand(0);
|
||||
switch (operand->shape().element_type()) {
|
||||
case C64: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
|
||||
imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
|
||||
GetEvaluatedLiteralFor(imag->operand(0)));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case C128: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<double, complex128>(
|
||||
imag, [](complex128 elem_operand) { return std::imag(elem_operand); },
|
||||
GetEvaluatedLiteralFor(imag->operand(0)));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: "
|
||||
<< PrimitiveType_Name(operand->shape().element_type());
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -544,11 +602,27 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) {
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape()));
|
||||
|
||||
Literal result(complex->shape());
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
|
||||
return std::complex<float>(real.Get<float>(multi_index),
|
||||
imag.Get<float>(multi_index));
|
||||
}));
|
||||
switch (complex->shape().element_type()) {
|
||||
case C64: {
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.Populate<complex64>([&](absl::Span<const int64> multi_index) {
|
||||
return std::complex<float>(real.Get<float>(multi_index),
|
||||
imag.Get<float>(multi_index));
|
||||
}));
|
||||
break;
|
||||
}
|
||||
case C128: {
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.Populate<complex128>([&](absl::Span<const int64> multi_index) {
|
||||
return std::complex<float>(real.Get<double>(multi_index),
|
||||
imag.Get<double>(multi_index));
|
||||
}));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: "
|
||||
<< PrimitiveType_Name(complex->shape().element_type());
|
||||
}
|
||||
|
||||
evaluated_[complex] = std::move(result);
|
||||
return Status::OK();
|
||||
@ -647,6 +721,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
||||
Compare<complex64>(compare->shape(), opcode,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
case C128: {
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||
Compare<complex128>(compare->shape(), opcode,
|
||||
lhs_literal, rhs_literal));
|
||||
} break;
|
||||
default:
|
||||
LOG(FATAL) << "HandleCompare: unknown primitive type: "
|
||||
<< PrimitiveType_Name(lhs->shape().element_type());
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
@ -39,9 +40,8 @@ namespace xla {
|
||||
// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is
|
||||
// a "private" header that's not exposed outside of hlo_evaluator.cc.
|
||||
template <typename T>
|
||||
using is_complex_t = std::is_same<T, complex64>;
|
||||
template <typename T>
|
||||
using is_complex64_t = std::is_same<T, complex64>;
|
||||
using is_complex_t =
|
||||
absl::disjunction<std::is_same<T, complex64>, std::is_same<T, complex128>>;
|
||||
|
||||
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
|
||||
// "safe" less functions which are actually strict weak orders. -NaN and NaN
|
||||
@ -212,7 +212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
template <
|
||||
typename NativeT,
|
||||
typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
|
||||
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
|
||||
Status HandleAbs(HloInstruction* abs) {
|
||||
const Literal& operand_literal =
|
||||
parent_->GetEvaluatedLiteralFor(abs->operand(0));
|
||||
@ -231,6 +231,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
// specifying the ElementwiseT explicitly as C64 is needed below.
|
||||
if (abs->operand(0)->shape().element_type() == C64) {
|
||||
return HandleAbs<complex64>(abs);
|
||||
} else if (abs->operand(0)->shape().element_type() == C128) {
|
||||
return HandleAbs<complex128>(abs);
|
||||
}
|
||||
return HandleAbs<ElementwiseT>(abs);
|
||||
}
|
||||
@ -1616,6 +1618,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
|
||||
break;
|
||||
}
|
||||
case C128: {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex128>(map));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "HandleMap: unhandled primitive type for "
|
||||
"input operand: "
|
||||
@ -3040,6 +3046,7 @@ extern template class HloEvaluatorTypedVisitor<Eigen::half, float>;
|
||||
extern template class HloEvaluatorTypedVisitor<float>;
|
||||
extern template class HloEvaluatorTypedVisitor<double>;
|
||||
extern template class HloEvaluatorTypedVisitor<complex64>;
|
||||
extern template class HloEvaluatorTypedVisitor<complex128>;
|
||||
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
|
||||
|
||||
} // namespace xla
|
||||
|
@ -0,0 +1,22 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
|
||||
|
||||
namespace xla {
|
||||
template class HloEvaluatorTypedVisitor<complex128>;
|
||||
} // namespace xla
|
@ -551,6 +551,17 @@ ENTRY %Transpose.v2 () -> s32[1,2,3] {
|
||||
ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
{
|
||||
"TransposeC128",
|
||||
R"(HloModule TransposeC128_module
|
||||
|
||||
ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
|
||||
%input = c128[1,2,3]{2,1,0} parameter(0)
|
||||
ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// Dynamic slice
|
||||
|
@ -188,7 +188,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
|
||||
}
|
||||
return cplx_t;
|
||||
}
|
||||
// A Tuple contains an array of pointers. Use i8*.
|
||||
case C128: {
|
||||
auto cplx_t = module->getTypeByName("complex128");
|
||||
if (cplx_t == nullptr) {
|
||||
return llvm::StructType::create(
|
||||
{llvm::Type::getDoubleTy(module->getContext()),
|
||||
llvm::Type::getDoubleTy(module->getContext())},
|
||||
"complex128", /*isPacked=*/true);
|
||||
}
|
||||
return cplx_t;
|
||||
} // A Tuple contains an array of pointers. Use i8*.
|
||||
case TUPLE:
|
||||
// An Opaque is like a void*, use i8*.
|
||||
case OPAQUE:
|
||||
|
@ -906,6 +906,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
broadcast_dimensions));
|
||||
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
|
||||
return ShapeUtil::ChangeElementType(shape, C64);
|
||||
} else if (lhs.element_type() == F64 && rhs.element_type() == F64) {
|
||||
return ShapeUtil::ChangeElementType(shape, C128);
|
||||
} else {
|
||||
return Unimplemented("Complex component type is not implemented.");
|
||||
}
|
||||
@ -1733,7 +1735,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
case FFT:
|
||||
case IFFT:
|
||||
if (in.element_type() != C64) {
|
||||
return InvalidArgument("%s requires C64 input type, found %s.",
|
||||
return InvalidArgument("%s requires complex input type, found %s.",
|
||||
FftType_Name(fft_type),
|
||||
PrimitiveType_Name(in.element_type()));
|
||||
}
|
||||
|
@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test {
|
||||
protected:
|
||||
// Some handy scalar shapes.
|
||||
const Shape s32_ = ShapeUtil::MakeShape(S32, {});
|
||||
const Shape f16_ = ShapeUtil::MakeShape(F16, {});
|
||||
const Shape f32_ = ShapeUtil::MakeShape(F32, {});
|
||||
const Shape f64_ = ShapeUtil::MakeShape(F64, {});
|
||||
const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
|
||||
@ -260,8 +261,8 @@ TEST_F(ShapeInferenceTest, Complex) {
|
||||
ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
|
||||
// Component types must match.
|
||||
ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
|
||||
// Only F32->C64 supported.
|
||||
ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok());
|
||||
// Only F32->C64 and F64->C128 supported.
|
||||
ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
|
||||
// Validate correct uses.
|
||||
Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
|
||||
TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
|
||||
@ -285,6 +286,9 @@ TEST_F(ShapeInferenceTest, Complex) {
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
|
||||
ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
|
||||
|
@ -378,6 +378,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
case U32:
|
||||
case U64:
|
||||
case C64:
|
||||
case C128:
|
||||
case TUPLE:
|
||||
case OPAQUE:
|
||||
case TOKEN:
|
||||
@ -639,6 +640,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
return sizeof(double);
|
||||
case C64:
|
||||
return sizeof(complex64);
|
||||
case C128:
|
||||
return sizeof(complex128);
|
||||
case TOKEN:
|
||||
// Tokens require no space.
|
||||
return 0;
|
||||
|
@ -431,7 +431,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, half>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
std::is_same<NativeT, complex64>::value ||
|
||||
std::is_same<NativeT, complex128>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
|
||||
@ -455,7 +456,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, half>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
std::is_same<NativeT, complex64>::value ||
|
||||
std::is_same<NativeT, complex128>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
|
||||
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
|
||||
@ -480,7 +482,8 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, half>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
std::is_same<NativeT, complex64>::value ||
|
||||
std::is_same<NativeT, complex128>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
Literal expected_literal =
|
||||
LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
|
||||
@ -506,7 +509,8 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, half>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
std::is_same<NativeT, complex64>::value ||
|
||||
std::is_same<NativeT, complex128>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
Literal expected_literal =
|
||||
LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
|
||||
@ -532,7 +536,8 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
|
||||
std::is_same<NativeT, double>::value ||
|
||||
std::is_same<NativeT, bfloat16>::value ||
|
||||
std::is_same<NativeT, half>::value ||
|
||||
std::is_same<NativeT, complex64>::value,
|
||||
std::is_same<NativeT, complex64>::value ||
|
||||
std::is_same<NativeT, complex128>::value,
|
||||
"Float or complex type required when specifying an ErrorSpec");
|
||||
Literal expected_literal =
|
||||
LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
|
||||
|
@ -41,6 +41,7 @@ using ::tensorflow::uint32;
|
||||
using ::tensorflow::uint64;
|
||||
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
|
||||
using ::Eigen::half;
|
||||
|
||||
|
@@ -56,6 +56,7 @@ enum PrimitiveType {
 
   // Complex values of fixed width.
   C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
+  C128 = 18; // Paired F64 (real, imag), as in std::complex<double>.
 
   // A tuple is a polymorphic sequence; e.g. a shape that holds different
   // sub-shapes. They are used for things like returning multiple values from a
@@ -75,7 +76,7 @@ enum PrimitiveType {
   // primitive type will have empty dimensions and tuple_shapes fields.
   TOKEN = 17;
 
-  // Next = 18
+  // Next = 19
 }
 
 // Describes the padding configuration for Pad operation. The padding amount on
@@ -367,6 +368,7 @@ message LiteralProto {
   repeated float f32s = 8;
   repeated double f64s = 9;
   repeated float c64s = 12; // Stored as interleaved real, imag floats.
+  repeated double c128s = 18; // Stored as interleaved real, imag doubles.
   repeated LiteralProto tuple_literals = 10;
   // The F16s, BF16s, U16s and S16s are encoded in little endian byte order
   bytes f16s = 11;
@@ -374,7 +376,7 @@ message LiteralProto {
   bytes u16s = 16;
   bytes s16s = 17;
   repeated int64 sparse_indices = 14;
-  // Next = 18
+  // Next = 19
 }
 
 message WindowDimension {
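
Taken together with the builder-side changes above (ConstantR0WithType, MaybeConjugate, and the new C128 enum value and c128s proto field), a client computation can now be built directly on complex128 operands. Below is a hedged sketch, not part of the commit, using only client functions that are touched or referenced elsewhere in this diff plus the pre-existing Add op; BuildComplex128Example is an invented name.

// Hedged sketch (not from the commit): building a small C128 computation
// with the XLA client API.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

XlaOp BuildComplex128Example(XlaBuilder* builder) {
  // A C128 scalar constant, the same constant path the ConstOp change uses.
  XlaOp a = ConstantR0<complex128>(builder, complex128(1.0, -2.0));

  // MaybeConjugate now treats any complex element type (C64 or C128) as
  // conjugatable, per the math.cc change above.
  XlaOp a_conj = MaybeConjugate(a, /*conjugate=*/true);

  // a + conj(a) == 2 * real(a), still carried as a complex128 scalar.
  return Add(a, a_conj);
}

}  // namespace xla
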