Convert complex and unsigned integer tensors to and from dense elements attr
PiperOrigin-RevId: 311022341 Change-Id: Ib1f5bd9f3a0a857e60c50cf999d4371c987c091d
This commit is contained in:
parent
1bfde45145
commit
abaffb8ad1
@ -823,6 +823,7 @@ cc_library(
|
||||
":mangling_util",
|
||||
":tensorflow_attributes",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
||||
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
@ -132,13 +133,21 @@ StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
|
||||
case DTYPE: \
|
||||
return ConvertFlatTensor<CTYPE>(input_tensor, type);
|
||||
|
||||
// TODO(fengliuai): customize the conversions for more types.
|
||||
// TODO(fengliuai): customize the conversions for quantized and string types.
|
||||
switch (input_dtype) {
|
||||
CONVERT_FLAT(DT_BOOL, bool)
|
||||
CONVERT_FLAT(DT_FLOAT, float)
|
||||
CONVERT_FLAT(DT_DOUBLE, double)
|
||||
CONVERT_FLAT(DT_INT8, int8)
|
||||
CONVERT_FLAT(DT_INT16, int16)
|
||||
CONVERT_FLAT(DT_INT32, int32)
|
||||
CONVERT_FLAT(DT_INT64, int64)
|
||||
CONVERT_FLAT(DT_UINT8, uint8)
|
||||
CONVERT_FLAT(DT_UINT16, uint16)
|
||||
CONVERT_FLAT(DT_UINT32, uint32)
|
||||
CONVERT_FLAT(DT_UINT64, uint64)
|
||||
CONVERT_FLAT(DT_COMPLEX64, std::complex<float>)
|
||||
CONVERT_FLAT(DT_COMPLEX128, std::complex<double>)
|
||||
|
||||
// BFLOAT16 is a special case that it needs to be cast to double type to
|
||||
// match its storage type.
|
||||
@ -215,6 +224,15 @@ void ConvertStringElementsAttr(
|
||||
output->Add({val.data(), val.size()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr,
|
||||
protobuf::RepeatedField<T>* output) {
|
||||
for (const auto& val : attr.getValues<std::complex<T>>()) {
|
||||
output->Add(val.real());
|
||||
output->Add(val.imag());
|
||||
}
|
||||
}
|
||||
|
||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
@ -310,6 +328,12 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
||||
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
|
||||
output->mutable_int_val());
|
||||
break;
|
||||
case DT_UINT32:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_uint32_val());
|
||||
break;
|
||||
case DT_UINT64:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_uint64_val());
|
||||
break;
|
||||
case DT_INT64:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
|
||||
break;
|
||||
@ -324,6 +348,12 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
||||
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
|
||||
output->mutable_string_val());
|
||||
break;
|
||||
case DT_COMPLEX64:
|
||||
ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val());
|
||||
break;
|
||||
case DT_COMPLEX128:
|
||||
ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val());
|
||||
break;
|
||||
default:
|
||||
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
|
||||
DataTypeString(output_dtype)));
|
||||
|
||||
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <initializer_list>
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
@ -99,48 +100,74 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
|
||||
EXPECT_EQ(string_values[3], mlir::StringRef("four"));
|
||||
}
|
||||
|
||||
TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) {
|
||||
class ConvertTensorTest : public ::testing::Test {
|
||||
protected:
|
||||
template <typename T>
|
||||
void VerifyConversion(std::initializer_list<T> values, DataType dtype,
|
||||
mlir::Type expected_ty) {
|
||||
mlir::Builder b(expected_ty.getContext());
|
||||
Tensor tensor(dtype, TensorShape({static_cast<int64>(values.size())}));
|
||||
tensor.flat<T>().setValues(values);
|
||||
|
||||
auto value_or = ConvertTensor(tensor, &b);
|
||||
TF_ASSERT_OK(value_or.status());
|
||||
auto attr = value_or.ValueOrDie();
|
||||
|
||||
EXPECT_EQ(attr.getType().getElementType(), expected_ty);
|
||||
|
||||
Tensor out;
|
||||
TF_ASSERT_OK(ConvertToTensor(attr, &out));
|
||||
|
||||
test::ExpectTensorEqual<T>(tensor, out);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ConvertTensorTest, Simple) {
|
||||
RegisterDialects();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder b(&context);
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
|
||||
{Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(
|
||||
VerifyConversion<bfloat16>({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16,
|
||||
mlir::FloatType::getBF16(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<float>(
|
||||
{1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<double>(
|
||||
{1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context)));
|
||||
|
||||
{
|
||||
// Create the sample tensor to convert.
|
||||
Tensor tensor(DT_HALF, TensorShape({1}));
|
||||
auto Tt = tensor.flat<Eigen::half>();
|
||||
Tt.setValues({Eigen::half(1.0)});
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int8>(
|
||||
{1, -1}, DT_INT8, mlir::IntegerType::get(8, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int16>(
|
||||
{1, -1}, DT_INT16, mlir::IntegerType::get(16, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int32>(
|
||||
{1, -1}, DT_INT32, mlir::IntegerType::get(32, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<int64>(
|
||||
{1, -1}, DT_INT64, mlir::IntegerType::get(64, &context)));
|
||||
|
||||
auto value_or = ConvertTensor(tensor, &b);
|
||||
TF_EXPECT_OK(value_or.status());
|
||||
auto attr = value_or.ValueOrDie();
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint8>(
|
||||
{1, 2}, DT_UINT8,
|
||||
mlir::IntegerType::get(
|
||||
8, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint16>(
|
||||
{1, 2}, DT_UINT16,
|
||||
mlir::IntegerType::get(
|
||||
16, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint32>(
|
||||
{1, 2}, DT_UINT32,
|
||||
mlir::IntegerType::get(
|
||||
32, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<uint64>(
|
||||
{1, 2}, DT_UINT64,
|
||||
mlir::IntegerType::get(
|
||||
64, mlir::IntegerType::SignednessSemantics::Unsigned, &context)));
|
||||
|
||||
EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
|
||||
EXPECT_TRUE(attr.getType().getElementType().isF16());
|
||||
|
||||
Tensor out;
|
||||
TF_ASSERT_OK(ConvertToTensor(attr, &out));
|
||||
|
||||
test::ExpectTensorEqual<Eigen::half>(tensor, out);
|
||||
}
|
||||
|
||||
{
|
||||
// Create the sample tensor to convert.
|
||||
Tensor tensor(DT_BFLOAT16, TensorShape({2}));
|
||||
auto Tt = tensor.flat<bfloat16>();
|
||||
Tt.setValues({bfloat16(1.0), bfloat16(-1.0)});
|
||||
|
||||
auto value_or = ConvertTensor(tensor, &b);
|
||||
TF_EXPECT_OK(value_or.status());
|
||||
auto attr = value_or.ValueOrDie();
|
||||
|
||||
EXPECT_TRUE(attr.isa<mlir::DenseFPElementsAttr>());
|
||||
EXPECT_TRUE(attr.getType().getElementType().isBF16());
|
||||
|
||||
Tensor out;
|
||||
TF_ASSERT_OK(ConvertToTensor(attr, &out));
|
||||
|
||||
test::ExpectTensorEqual<bfloat16>(tensor, out);
|
||||
}
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<float>>(
|
||||
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64,
|
||||
mlir::ComplexType::get(mlir::FloatType::getF32(&context))));
|
||||
ASSERT_NO_FATAL_FAILURE(VerifyConversion<std::complex<double>>(
|
||||
{{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128,
|
||||
mlir::ComplexType::get(mlir::FloatType::getF64(&context))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user