From 02b34b853deb653d293e0c1c3d5d2e0b1453445c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 10 May 2020 15:27:57 -0700 Subject: [PATCH] Refactor ConvertToTensorProto to avoid some duplication Share implementation to populate tensors unless the element type requires special handling. PiperOrigin-RevId: 310820768 Change-Id: Ibdf11da1e9e41b3f2ddb10a43563da244f044b62 --- .../mlir/tensorflow/utils/convert_tensor.cc | 191 +++++++++++------- 1 file changed, 121 insertions(+), 70 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index b492945fe8b..fcfef565952 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -208,11 +207,12 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. -void ConvertStringElementsAttr( - const DenseStringElementsAttr attr, - protobuf::RepeatedPtrField* output) { - for (const auto& val : attr.getRawStringData()) - output->Add({val.data(), val.size()}); +Status ConvertStringElementsAttr(const DenseStringElementsAttr attr, + TensorProto* output_tensor) { + for (const auto& val : attr.getRawStringData()) { + output_tensor->add_string_val(val.data(), val.size()); + } + return Status::OK(); } // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. @@ -226,80 +226,139 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr, return InvalidArgument("Unexpected elements attribute type from MLIR."); } -// Converts an MLIR elements attribute and adds it to specified repeated field. -template -void ConvertElementsAttr(const mlir::DenseElementsAttr attr, - protobuf::RepeatedField* output) { - if (attr.isSplat()) { - output->Add(attr.getSplatValue()); - } else { - for (auto value : attr.getValues()) output->Add(value); +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the double_val field updated. +Status ConvertDoubleElementsAttr(const ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_double_val(elts.getSplatValue()); + } else { + for (auto value : elts.getValues()) + output_tensor->add_double_val(value); + } + return Status::OK(); } + return ConvertOpaqueElementsAttr(attr, output_tensor); } -// Converts an MLIR elements attribute containing half values and adds it to -// specified repeated field. -void ConvertHalfElementsAttr(const DenseFPElementsAttr attr, - protobuf::RepeatedField* output_tensor) { - if (attr.isSplat()) { - output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue()); - } else { - for (const llvm::APFloat value : attr.getFloatValues()) - output_tensor->Add(value.bitcastToAPInt().getSExtValue()); +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the float_val field updated. +Status ConvertFloatElementsAttr(const ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_float_val(elts.getSplatValue()); + } else { + for (auto value : elts.getValues()) + output_tensor->add_float_val(value); + } + return Status::OK(); } + return ConvertOpaqueElementsAttr(attr, output_tensor); } -// Converts an MLIR elements attribute containing int values and adds it to -// specified repeated field. -void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr, - protobuf::RepeatedField* output) { - if (attr.isSplat()) { - output->Add((*attr.begin()).getSExtValue()); - } else { - for (const llvm::APInt val : attr) output->Add(val.getSExtValue()); +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the half_val field updated. +Status ConvertHalfElementsAttr(const ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_half_val( + (*elts.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (const auto& value : elts.getFloatValues()) + output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); + } + return Status::OK(); } + return ConvertOpaqueElementsAttr(attr, output_tensor); } -void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, - protobuf::RepeatedField* output) { +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the int_val field updated. +Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_int_val((*elts.begin()).getSExtValue()); + } else { + for (const auto& val : elts) + output_tensor->add_int_val(val.getSExtValue()); + } + return Status::OK(); + } + return ConvertOpaqueElementsAttr(attr, output_tensor); +} + +Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr, + TensorProto* output_tensor) { + auto elts = attr.dyn_cast(); + if (!elts) { + return ConvertOpaqueElementsAttr(attr, output_tensor); + } + // Bfloat16 is internally represented as `double` in MLIR. - if (attr.isSplat()) { - double v = attr.getSplatValue(); + if (elts.isSplat()) { + double v = elts.getSplatValue(); bfloat16 bf16_val = static_cast(v); - output->Add(absl::bit_cast(bf16_val)); + output_tensor->add_half_val(absl::bit_cast(bf16_val)); } else { - for (auto v : attr.getValues()) { + for (auto v : elts.getValues()) { bfloat16 bf16_val = static_cast(v); - output->Add(absl::bit_cast(bf16_val)); + output_tensor->add_half_val(absl::bit_cast(bf16_val)); } } + + return Status::OK(); } -Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the int64_val field updated. +Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_int64_val((*elts.begin()).getSExtValue()); + } else { + for (const auto& val : elts) + output_tensor->add_int64_val(val.getSExtValue()); + } + return Status::OK(); + } + return ConvertOpaqueElementsAttr(attr, output_tensor); +} + +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with bool_val field updated. +Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + for (const auto& val : elts) { + output_tensor->add_bool_val(val.getBoolValue()); + } + return Status::OK(); + } + return ConvertOpaqueElementsAttr(attr, output_tensor); +} + +Status ConvertToTensorProto(const ElementsAttr attr, + TensorProto* output_tensor) { auto type = attr.getType(); auto shape = type.getShape(); DataType output_dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); - output->set_dtype(output_dtype); - ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); - - if (attr.isa()) - return ConvertOpaqueElementsAttr(attr.cast(), output); - - auto dense_attr = attr.dyn_cast(); - if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); + output_tensor->set_dtype(output_dtype); + ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape()); switch (output_dtype) { case DT_FLOAT: - ConvertElementsAttr(dense_attr, output->mutable_float_val()); - break; + return ConvertFloatElementsAttr(attr, output_tensor); case DT_HALF: - ConvertHalfElementsAttr(dense_attr.cast(), - output->mutable_half_val()); - break; + // Handles both DenseFPElementsAttr and OpaqueElementsAttr. + return ConvertHalfElementsAttr(attr, output_tensor); case DT_DOUBLE: - ConvertElementsAttr(dense_attr, output->mutable_double_val()); - break; + return ConvertDoubleElementsAttr(attr, output_tensor); case DT_QUINT8: case DT_UINT8: case DT_INT8: @@ -307,28 +366,20 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { case DT_UINT16: case DT_INT16: case DT_INT32: - ConvertIntElementsAttr(dense_attr.cast(), - output->mutable_int_val()); - break; + return ConvertIntElementsAttr(attr, output_tensor); case DT_INT64: - ConvertElementsAttr(dense_attr, output->mutable_int64_val()); - break; + return ConvertInt64ElementsAttr(attr, output_tensor); case DT_BOOL: - ConvertElementsAttr(dense_attr, output->mutable_bool_val()); - break; + return ConvertBoolElementsAttr(attr, output_tensor); case DT_BFLOAT16: - ConvertBfloat16ElementsAttr(dense_attr.cast(), - output->mutable_half_val()); - break; + return ConvertBfloat16ElementsAttr(attr, output_tensor); case DT_STRING: - ConvertStringElementsAttr(dense_attr.cast(), - output->mutable_string_val()); - break; + return ConvertStringElementsAttr(attr.cast(), + output_tensor); default: - return errors::Unimplemented(absl::StrCat("Unimplemented data type ", - DataTypeString(output_dtype))); + return ConvertOpaqueElementsAttr(attr.cast(), + output_tensor); } - return Status::OK(); } Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {