Refactor ConvertToTensorProto to avoid some duplication

Share implementation to populate tensors unless the element type requires special handling.

PiperOrigin-RevId: 310818801
Change-Id: I27b4d9111578e9ecbec663853aad9ed85e46defc
This commit is contained in:
Smit Hinsu 2020-05-10 14:54:10 -07:00 committed by TensorFlower Gardener
parent 28492e5b5b
commit 4ce786df35

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@ -207,12 +208,11 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
// proto.
Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
TensorProto* output_tensor) {
for (const auto& val : attr.getRawStringData()) {
output_tensor->add_string_val(val.data(), val.size());
}
return Status::OK();
void ConvertStringElementsAttr(
const DenseStringElementsAttr attr,
protobuf::RepeatedPtrField<std::string>* output) {
for (const auto& val : attr.getRawStringData())
output->Add({val.data(), val.size()});
}
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
@ -226,139 +226,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
return InvalidArgument("Unexpected elements attribute type from MLIR.");
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the double_val field updated.
Status ConvertDoubleElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_double_val(elts.getSplatValue<double>());
// Converts an MLIR elements attribute and adds it to specified repeated field.
template <typename T>
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
protobuf::RepeatedField<T>* output) {
if (attr.isSplat()) {
output->Add(attr.getSplatValue<T>());
} else {
for (auto value : elts.getValues<double>())
output_tensor->add_double_val(value);
for (auto value : attr.getValues<T>()) output->Add(value);
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the float_val field updated.
Status ConvertFloatElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_float_val(elts.getSplatValue<float>());
// Converts an MLIR elements attribute containing half values and adds it to
// specified repeated field.
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output_tensor) {
if (attr.isSplat()) {
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
} else {
for (auto value : elts.getValues<float>())
output_tensor->add_float_val(value);
for (const llvm::APFloat value : attr.getFloatValues())
output_tensor->Add(value.bitcastToAPInt().getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the half_val field updated.
Status ConvertHalfElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_half_val(
(*elts.begin()).bitcastToAPInt().getSExtValue());
// Converts an MLIR elements attribute containing int values and adds it to
// specified repeated field.
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
protobuf::RepeatedField<int>* output) {
if (attr.isSplat()) {
output->Add((*attr.begin()).getSExtValue());
} else {
for (const auto& value : elts.getFloatValues())
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the int_val field updated.
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_int_val((*elts.begin()).getSExtValue());
} else {
for (const auto& val : elts)
output_tensor->add_int_val(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
if (!elts) {
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output) {
// Bfloat16 is internally represented as `double` in MLIR.
if (elts.isSplat()) {
double v = elts.getSplatValue<double>();
if (attr.isSplat()) {
double v = attr.getSplatValue<double>();
bfloat16 bf16_val = static_cast<bfloat16>(v);
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
output->Add(absl::bit_cast<int16>(bf16_val));
} else {
for (auto v : elts.getValues<double>()) {
for (auto v : attr.getValues<double>()) {
bfloat16 bf16_val = static_cast<bfloat16>(v);
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
output->Add(absl::bit_cast<int16>(bf16_val));
}
}
}
return Status::OK();
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the int64_val field updated.
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
} else {
for (const auto& val : elts)
output_tensor->add_int64_val(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with bool_val field updated.
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
for (const auto& val : elts) {
output_tensor->add_bool_val(val.getBoolValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
Status ConvertToTensorProto(const ElementsAttr attr,
TensorProto* output_tensor) {
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
auto type = attr.getType();
auto shape = type.getShape();
DataType output_dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
output_tensor->set_dtype(output_dtype);
ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
output->set_dtype(output_dtype);
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
if (attr.isa<OpaqueElementsAttr>())
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
switch (output_dtype) {
case DT_FLOAT:
return ConvertFloatElementsAttr(attr, output_tensor);
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
break;
case DT_HALF:
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
return ConvertHalfElementsAttr(attr, output_tensor);
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
output->mutable_half_val());
break;
case DT_DOUBLE:
return ConvertDoubleElementsAttr(attr, output_tensor);
ConvertElementsAttr(dense_attr, output->mutable_double_val());
break;
case DT_QUINT8:
case DT_UINT8:
case DT_INT8:
@ -366,20 +307,28 @@ Status ConvertToTensorProto(const ElementsAttr attr,
case DT_UINT16:
case DT_INT16:
case DT_INT32:
return ConvertIntElementsAttr(attr, output_tensor);
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
output->mutable_int_val());
break;
case DT_INT64:
return ConvertInt64ElementsAttr(attr, output_tensor);
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
break;
case DT_BOOL:
return ConvertBoolElementsAttr(attr, output_tensor);
ConvertElementsAttr(dense_attr, output->mutable_bool_val());
break;
case DT_BFLOAT16:
return ConvertBfloat16ElementsAttr(attr, output_tensor);
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
output->mutable_half_val());
break;
case DT_STRING:
return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
output_tensor);
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
output->mutable_string_val());
break;
default:
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
output_tensor);
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
DataTypeString(output_dtype)));
}
return Status::OK();
}
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {