Refactor ConvertToTensorProto to avoid some duplication

Share implementation to populate tensors unless the element type requires special handling.

PiperOrigin-RevId: 310820768
Change-Id: Ibdf11da1e9e41b3f2ddb10a43563da244f044b62
This commit is contained in:
A. Unique TensorFlower 2020-05-10 15:27:57 -07:00 committed by TensorFlower Gardener
parent 4ce786df35
commit 02b34b853d

View File

@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/tstring.h"
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
@ -208,11 +207,12 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
// Converts an MLIR dense string elements attribute to a TensorFlow tensor // Converts an MLIR dense string elements attribute to a TensorFlow tensor
// proto. // proto.
void ConvertStringElementsAttr( Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
const DenseStringElementsAttr attr, TensorProto* output_tensor) {
protobuf::RepeatedPtrField<std::string>* output) { for (const auto& val : attr.getRawStringData()) {
for (const auto& val : attr.getRawStringData()) output_tensor->add_string_val(val.data(), val.size());
output->Add({val.data(), val.size()}); }
return Status::OK();
} }
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
@ -226,80 +226,139 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
return InvalidArgument("Unexpected elements attribute type from MLIR."); return InvalidArgument("Unexpected elements attribute type from MLIR.");
} }
// Converts an MLIR elements attribute and adds it to specified repeated field. // Converts an MLIR elements attribute to a TensorFlow tensor proto
template <typename T> // with the double_val field updated.
void ConvertElementsAttr(const mlir::DenseElementsAttr attr, Status ConvertDoubleElementsAttr(const ElementsAttr attr,
protobuf::RepeatedField<T>* output) { TensorProto* output_tensor) {
if (attr.isSplat()) { if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
output->Add(attr.getSplatValue<T>()); if (elts.isSplat()) {
output_tensor->add_double_val(elts.getSplatValue<double>());
} else { } else {
for (auto value : attr.getValues<T>()) output->Add(value); for (auto value : elts.getValues<double>())
output_tensor->add_double_val(value);
} }
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
} }
// Converts an MLIR elements attribute containing half values and adds it to // Converts an MLIR elements attribute to a TensorFlow tensor proto
// specified repeated field. // with the float_val field updated.
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr, Status ConvertFloatElementsAttr(const ElementsAttr attr,
protobuf::RepeatedField<int>* output_tensor) { TensorProto* output_tensor) {
if (attr.isSplat()) { if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue()); if (elts.isSplat()) {
output_tensor->add_float_val(elts.getSplatValue<float>());
} else { } else {
for (const llvm::APFloat value : attr.getFloatValues()) for (auto value : elts.getValues<float>())
output_tensor->Add(value.bitcastToAPInt().getSExtValue()); output_tensor->add_float_val(value);
} }
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
} }
// Converts an MLIR elements attribute containing int values and adds it to // Converts an MLIR elements attribute to a TensorFlow tensor proto
// specified repeated field. // with the half_val field updated.
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr, Status ConvertHalfElementsAttr(const ElementsAttr attr,
protobuf::RepeatedField<int>* output) { TensorProto* output_tensor) {
if (attr.isSplat()) { if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
output->Add((*attr.begin()).getSExtValue()); if (elts.isSplat()) {
output_tensor->add_half_val(
(*elts.begin()).bitcastToAPInt().getSExtValue());
} else { } else {
for (const llvm::APInt val : attr) output->Add(val.getSExtValue()); for (const auto& value : elts.getFloatValues())
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
} }
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
} }
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, // Converts an MLIR elements attribute to a TensorFlow tensor proto
protobuf::RepeatedField<int>* output) { // with the int_val field updated.
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_int_val((*elts.begin()).getSExtValue());
} else {
for (const auto& val : elts)
output_tensor->add_int_val(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
if (!elts) {
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Bfloat16 is internally represented as `double` in MLIR. // Bfloat16 is internally represented as `double` in MLIR.
if (attr.isSplat()) { if (elts.isSplat()) {
double v = attr.getSplatValue<double>(); double v = elts.getSplatValue<double>();
bfloat16 bf16_val = static_cast<bfloat16>(v); bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val)); output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
} else { } else {
for (auto v : attr.getValues<double>()) { for (auto v : elts.getValues<double>()) {
bfloat16 bf16_val = static_cast<bfloat16>(v); bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val)); output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
} }
} }
return Status::OK();
} }
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { // Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the int64_val field updated.
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
} else {
for (const auto& val : elts)
output_tensor->add_int64_val(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with bool_val field updated.
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
for (const auto& val : elts) {
output_tensor->add_bool_val(val.getBoolValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
Status ConvertToTensorProto(const ElementsAttr attr,
TensorProto* output_tensor) {
auto type = attr.getType(); auto type = attr.getType();
auto shape = type.getShape(); auto shape = type.getShape();
DataType output_dtype; DataType output_dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
output->set_dtype(output_dtype); output_tensor->set_dtype(output_dtype);
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
if (attr.isa<OpaqueElementsAttr>())
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
switch (output_dtype) { switch (output_dtype) {
case DT_FLOAT: case DT_FLOAT:
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val()); return ConvertFloatElementsAttr(attr, output_tensor);
break;
case DT_HALF: case DT_HALF:
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(), // Handles both DenseFPElementsAttr and OpaqueElementsAttr.
output->mutable_half_val()); return ConvertHalfElementsAttr(attr, output_tensor);
break;
case DT_DOUBLE: case DT_DOUBLE:
ConvertElementsAttr(dense_attr, output->mutable_double_val()); return ConvertDoubleElementsAttr(attr, output_tensor);
break;
case DT_QUINT8: case DT_QUINT8:
case DT_UINT8: case DT_UINT8:
case DT_INT8: case DT_INT8:
@ -307,28 +366,20 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
case DT_UINT16: case DT_UINT16:
case DT_INT16: case DT_INT16:
case DT_INT32: case DT_INT32:
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(), return ConvertIntElementsAttr(attr, output_tensor);
output->mutable_int_val());
break;
case DT_INT64: case DT_INT64:
ConvertElementsAttr(dense_attr, output->mutable_int64_val()); return ConvertInt64ElementsAttr(attr, output_tensor);
break;
case DT_BOOL: case DT_BOOL:
ConvertElementsAttr(dense_attr, output->mutable_bool_val()); return ConvertBoolElementsAttr(attr, output_tensor);
break;
case DT_BFLOAT16: case DT_BFLOAT16:
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(), return ConvertBfloat16ElementsAttr(attr, output_tensor);
output->mutable_half_val());
break;
case DT_STRING: case DT_STRING:
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(), return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
output->mutable_string_val()); output_tensor);
break;
default: default:
return errors::Unimplemented(absl::StrCat("Unimplemented data type ", return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
DataTypeString(output_dtype))); output_tensor);
} }
return Status::OK();
} }
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {