Import and export constants of complex types for HLO

PiperOrigin-RevId: 310660936
Change-Id: I732ec0d8f16a71b0529408a677f6c144ce299228
This commit is contained in:
Smit Hinsu 2020-05-08 17:44:10 -07:00 committed by TensorFlower Gardener
parent 4385e797a9
commit d2fdc7b012
4 changed files with 19 additions and 3 deletions

View File

@ -139,6 +139,10 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
return CreateDenseAttrFromLiteral<uint32>(type, literal);
case PrimitiveType::U64:
return CreateDenseAttrFromLiteral<uint64>(type, literal);
case PrimitiveType::C64:
return CreateDenseAttrFromLiteral<complex64>(type, literal);
case PrimitiveType::C128:
return CreateDenseAttrFromLiteral<complex128>(type, literal);
default:
return tensorflow::errors::Internal(
absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));

View File

@ -933,6 +933,8 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
case xla::PrimitiveType::F16: {
llvm::SmallVector<xla::half, 16> values;
values.reserve(attr.getNumElements());

View File

@ -294,6 +294,12 @@ func @main() {
// CHECK: f16[4] constant({1, -4, -65504, 0.015625}
%cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
// CHECK: c64[] constant((1, 0))
%cst_9 = constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
// CHECK: c128[] constant((1, 0))
%cst_10 = constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
return
}

View File

@ -212,10 +212,14 @@ add {
// CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
%constant.3 = bf16[4] constant({1, 2, 3, 4})
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%constant.4 = c64[] constant((1, 0))
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
%constant.5 = c128[] constant((1, 0))
// CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
ROOT %constant.4 = f16[4] constant({1, -4, -65504, 0.015625})
ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625})
}
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual