From 4ce786df351817f47e9574192ef66358e76ccf29 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sun, 10 May 2020 14:54:10 -0700 Subject: [PATCH] Refactor ConvertToTensorProto to avoid some duplication Share implementation to populate tensors unless the element type requires special handling. PiperOrigin-RevId: 310818801 Change-Id: I27b4d9111578e9ecbec663853aad9ed85e46defc --- .../mlir/tensorflow/utils/convert_tensor.cc | 213 +++++++----------- 1 file changed, 81 insertions(+), 132 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index fcfef565952..b492945fe8b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -207,12 +208,11 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. -Status ConvertStringElementsAttr(const DenseStringElementsAttr attr, - TensorProto* output_tensor) { - for (const auto& val : attr.getRawStringData()) { - output_tensor->add_string_val(val.data(), val.size()); - } - return Status::OK(); +void ConvertStringElementsAttr( + const DenseStringElementsAttr attr, + protobuf::RepeatedPtrField* output) { + for (const auto& val : attr.getRawStringData()) + output->Add({val.data(), val.size()}); } // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. @@ -226,139 +226,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr, return InvalidArgument("Unexpected elements attribute type from MLIR."); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the double_val field updated. -Status ConvertDoubleElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_double_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_double_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the float_val field updated. -Status ConvertFloatElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_float_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_float_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the half_val field updated. -Status ConvertHalfElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_half_val( - (*elts.begin()).bitcastToAPInt().getSExtValue()); - } else { - for (const auto& value : elts.getFloatValues()) - output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int_val field updated. -Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - auto elts = attr.dyn_cast(); - if (!elts) { - return ConvertOpaqueElementsAttr(attr, output_tensor); - } - - // Bfloat16 is internally represented as `double` in MLIR. - if (elts.isSplat()) { - double v = elts.getSplatValue(); - bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); +// Converts an MLIR elements attribute and adds it to specified repeated field. +template +void ConvertElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add(attr.getSplatValue()); } else { - for (auto v : elts.getValues()) { + for (auto value : attr.getValues()) output->Add(value); + } +} + +// Converts an MLIR elements attribute containing half values and adds it to +// specified repeated field. +void ConvertHalfElementsAttr(const DenseFPElementsAttr attr, + protobuf::RepeatedField* output_tensor) { + if (attr.isSplat()) { + output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (const llvm::APFloat value : attr.getFloatValues()) + output_tensor->Add(value.bitcastToAPInt().getSExtValue()); + } +} + +// Converts an MLIR elements attribute containing int values and adds it to +// specified repeated field. +void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add((*attr.begin()).getSExtValue()); + } else { + for (const llvm::APInt val : attr) output->Add(val.getSExtValue()); + } +} + +void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, + protobuf::RepeatedField* output) { + // Bfloat16 is internally represented as `double` in MLIR. + if (attr.isSplat()) { + double v = attr.getSplatValue(); + bfloat16 bf16_val = static_cast(v); + output->Add(absl::bit_cast(bf16_val)); + } else { + for (auto v : attr.getValues()) { bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); + output->Add(absl::bit_cast(bf16_val)); } } - - return Status::OK(); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int64_val field updated. -Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int64_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int64_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with bool_val field updated. -Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - for (const auto& val : elts) { - output_tensor->add_bool_val(val.getBoolValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertToTensorProto(const ElementsAttr attr, - TensorProto* output_tensor) { +Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { auto type = attr.getType(); auto shape = type.getShape(); DataType output_dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); - output_tensor->set_dtype(output_dtype); - ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape()); + output->set_dtype(output_dtype); + ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); + + if (attr.isa()) + return ConvertOpaqueElementsAttr(attr.cast(), output); + + auto dense_attr = attr.dyn_cast(); + if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); switch (output_dtype) { case DT_FLOAT: - return ConvertFloatElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_float_val()); + break; case DT_HALF: - // Handles both DenseFPElementsAttr and OpaqueElementsAttr. - return ConvertHalfElementsAttr(attr, output_tensor); + ConvertHalfElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_DOUBLE: - return ConvertDoubleElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_double_val()); + break; case DT_QUINT8: case DT_UINT8: case DT_INT8: @@ -366,20 +307,28 @@ Status ConvertToTensorProto(const ElementsAttr attr, case DT_UINT16: case DT_INT16: case DT_INT32: - return ConvertIntElementsAttr(attr, output_tensor); + ConvertIntElementsAttr(dense_attr.cast(), + output->mutable_int_val()); + break; case DT_INT64: - return ConvertInt64ElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_int64_val()); + break; case DT_BOOL: - return ConvertBoolElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_bool_val()); + break; case DT_BFLOAT16: - return ConvertBfloat16ElementsAttr(attr, output_tensor); + ConvertBfloat16ElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_STRING: - return ConvertStringElementsAttr(attr.cast(), - output_tensor); + ConvertStringElementsAttr(dense_attr.cast(), + output->mutable_string_val()); + break; default: - return ConvertOpaqueElementsAttr(attr.cast(), - output_tensor); + return errors::Unimplemented(absl::StrCat("Unimplemented data type ", + DataTypeString(output_dtype))); } + return Status::OK(); } Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {