Convert complex and unsigned integer tensors to and from dense elements attr

PiperOrigin-RevId: 311022341
Change-Id: Ib1f5bd9f3a0a857e60c50cf999d4371c987c091d
This commit is contained in:
Smit Hinsu 2020-05-11 16:54:05 -07:00 committed by TensorFlower Gardener
parent 1bfde45145
commit abaffb8ad1
3 changed files with 96 additions and 38 deletions
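In short, the convert_tensor utilities now round-trip complex (DT_COMPLEX64, DT_COMPLEX128) and unsigned integer tensors through mlir::DenseElementsAttr. Below is a minimal, test-style sketch of that round trip using the ConvertTensor and ConvertToTensor entry points exercised by the updated test; it assumes the relevant dialects are already registered (the test does this via RegisterDialects()) and that it runs inside namespace tensorflow with convert_tensor.h included.

    // Round-trip a DT_COMPLEX64 tensor through a dense elements attribute.
    mlir::MLIRContext context;
    mlir::Builder b(&context);
    Tensor tensor(DT_COMPLEX64, TensorShape({2}));
    tensor.flat<std::complex<float>>().setValues({{0.0f, 1.0f}, {1.0f, 0.0f}});

    auto attr_or = ConvertTensor(tensor, &b);  // StatusOr<mlir::ElementsAttr>
    TF_ASSERT_OK(attr_or.status());

    // Converting back should reproduce the original tensor.
    Tensor out;
    TF_ASSERT_OK(ConvertToTensor(attr_or.ValueOrDie(), &out));
    test::ExpectTensorEqual<std::complex<float>>(tensor, out);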

View File

@@ -823,6 +823,7 @@ cc_library(
":mangling_util",
":tensorflow_attributes",
":tensorflow_types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -132,13 +133,21 @@ StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
case DTYPE: \
return ConvertFlatTensor<CTYPE>(input_tensor, type);
// TODO(fengliuai): customize the conversions for more types.
// TODO(fengliuai): customize the conversions for quantized and string types.
switch (input_dtype) {
CONVERT_FLAT(DT_BOOL, bool)
CONVERT_FLAT(DT_FLOAT, float)
CONVERT_FLAT(DT_DOUBLE, double)
CONVERT_FLAT(DT_INT8, int8)
CONVERT_FLAT(DT_INT16, int16)
CONVERT_FLAT(DT_INT32, int32)
CONVERT_FLAT(DT_INT64, int64)
CONVERT_FLAT(DT_UINT8, uint8)
CONVERT_FLAT(DT_UINT16, uint16)
CONVERT_FLAT(DT_UINT32, uint32)
CONVERT_FLAT(DT_UINT64, uint64)
CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
// BFLOAT16 is a special case: it needs to be cast to double to match its
// storage type.
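For reference, the CONVERT_FLAT macro defined just above the switch maps each DataType to a ConvertFlatTensor call with the matching C++ element type, so the new DT_COMPLEX64 entry expands to roughly:

    case DT_COMPLEX64:
      return ConvertFlatTensor<std::complex<float>>(input_tensor, type);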
@@ -215,6 +224,15 @@ void ConvertStringElementsAttr(
output->Add({val.data(), val.size()});
}
template <typename T>
void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
protobuf::RepeatedField<T>* output) {
for (const auto& val : attr.getValues<std::complex<T>>()) {
output->Add(val.real());
output->Add(val.imag());
}
}
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
@@ -310,6 +328,12 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
output->mutable_int_val());
break;
case DT_UINT32:
ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
break;
case DT_UINT64:
ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
break;
case DT_INT64:
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
break;
@@ -324,6 +348,12 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
output->mutable_string_val());
break;
case DT_COMPLEX64:
ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
break;
case DT_COMPLEX128:
ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
break;
default:
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
DataTypeString(output_dtype)));
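The new DT_UINT32/DT_UINT64 cases reuse the existing ConvertElementsAttr template and write into the proto's uint32_val/uint64_val fields, while ConvertComplexElementsAttr (added above) flattens each complex value into its real and imaginary parts. That matches how TensorProto stores complex data: the repeated scomplex_val (float) and dcomplex_val (double) fields hold [real0, imag0, real1, imag1, ...]. As an illustration only, a hypothetical helper (not part of the commit, assuming the usual <complex>/<vector> includes) that unpacks that layout could look like:

    // Hypothetical helper: recover std::complex<float> values from a TensorProto
    // whose scomplex_val field was filled by ConvertComplexElementsAttr.
    std::vector<std::complex<float>> UnpackComplex64(const TensorProto& proto) {
      std::vector<std::complex<float>> values;
      values.reserve(proto.scomplex_val_size() / 2);
      for (int i = 0; i + 1 < proto.scomplex_val_size(); i += 2) {
        values.emplace_back(proto.scomplex_val(i), proto.scomplex_val(i + 1));
      }
      return values;
    }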

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include <cstring>
#include <initializer_list>
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -99,48 +100,74 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
EXPECT_EQ(string_values[3], mlir::StringRef("four"));
}
TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) {
class ConvertTensorTest : public ::testing::Test {
protected:
template <typename T>
void VerifyConversion(std::initializer_list<T> values, DataType dtype,
mlir::Type expected_ty) {
mlir::Builder b(expected_ty.getContext());
Tensor tensor(dtype, TensorShape({static_cast<int64>(values.size())}));
tensor.flat<T>().setValues(values);
auto value_or = ConvertTensor(tensor, &b);
TF_ASSERT_OK(value_or.status());
auto attr = value_or.ValueOrDie();
EXPECT_EQ(attr.getType().getElementType(), expected_ty);
Tensor out;
TF_ASSERT_OK(ConvertToTensor(attr, &out));
test::ExpectTensorEqual<T>(tensor, out);
}
};
TEST_F(ConvertTensorTest, Simple) {
RegisterDialects();
mlir::MLIRContext context;
mlir::Builder b(&context);
ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
{Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
ASSERT_NO_FATAL_FAILURE(
VerifyConversion<bfloat16>({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16,
mlir::FloatType::getBF16(&context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<float>(
{1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<double>(
{1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
{
// Create the sample tensor to convert.
Tensor tensor(DT_HALF, TensorShape({1}));
auto Tt = tensor.flat<Eigen::half>();
Tt.setValues({Eigen::half(1.0)});
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
{1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
{1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
{1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
{1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
auto value_or = ConvertTensor(tensor, &b);
TF_EXPECT_OK(value_or.status());
auto attr = value_or.ValueOrDie();
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
{1, 2}, DT_UINT8,
mlir::IntegerType::get(
8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
{1, 2}, DT_UINT16,
mlir::IntegerType::get(
16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
{1, 2}, DT_UINT32,
mlir::IntegerType::get(
32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
{1, 2}, DT_UINT64,
mlir::IntegerType::get(
64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
EXPECT_TRUE(attr.getType().getElementType().isF16());
Tensor out;
TF_ASSERT_OK(ConvertToTensor(attr, &out));
test::ExpectTensorEqual<Eigen::half>(tensor, out);
}
{
// Create the sample tensor to convert.
Tensor tensor(DT_BFLOAT16, TensorShape({2}));
auto Tt = tensor.flat<bfloat16>();
Tt.setValues({bfloat16(1.0), bfloat16(-1.0)});
auto value_or = ConvertTensor(tensor, &b);
TF_EXPECT_OK(value_or.status());
auto attr = value_or.ValueOrDie();
EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
EXPECT_TRUE(attr.getType().getElementType().isBF16());
Tensor out;
TF_ASSERT_OK(ConvertToTensor(attr, &out));
test::ExpectTensorEqual<bfloat16>(tensor, out);
}
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
mlir::ComplexType::get(mlir::FloatType::getF32(&context))));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<double>>(
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128,
mlir::ComplexType::get(mlir::FloatType::getF64(&context))));
}
} // namespace
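The refactor folds the previous per-dtype test blocks into one helper: VerifyConversion builds a rank-1 tensor of the given dtype, converts it with ConvertTensor, compares the attribute's element type against the expected MLIR type, and round-trips it through ConvertToTensor. As a sketch of what those expected types mean for the new cases (not code from the commit; it assumes a converted attribute named attr as in the helper), the checks could equivalently be written against the attribute directly:

    // Assuming `attr` came from converting a DT_UINT32 tensor:
    EXPECT_TRUE(attr.getType().getElementType().isUnsignedInteger(32));      // ui32
    // Assuming `attr` came from converting a DT_COMPLEX64 tensor:
    EXPECT_TRUE(attr.getType().getElementType().isa<mlir::ComplexType>());   // complex<f32>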