Refactor ConvertToTensorProto to avoid some duplication

Share implementation to populate tensors unless the element type requires special handling.

PiperOrigin-RevId: 310820768
Change-Id: Ibdf11da1e9e41b3f2ddb10a43563da244f044b62
This commit is contained in:
A. Unique TensorFlower 2020-05-10 15:27:57 -07:00 committed by TensorFlower Gardener
parent 4ce786df35
commit 02b34b853d

View File

@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@ -208,11 +207,12 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
// proto.
void ConvertStringElementsAttr(
const DenseStringElementsAttr attr,
protobuf::RepeatedPtrField<std::string>* output) {
for (const auto& val : attr.getRawStringData())
output->Add({val.data(), val.size()});
Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
TensorProto* output_tensor) {
for (const auto& val : attr.getRawStringData()) {
output_tensor->add_string_val(val.data(), val.size());
}
return Status::OK();
}
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
@ -226,80 +226,139 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
return InvalidArgument("Unexpected elements attribute type from MLIR.");
}
// Converts an MLIR elements attribute and adds it to specified repeated field.
template <typename T>
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
protobuf::RepeatedField<T>* output) {
if (attr.isSplat()) {
output->Add(attr.getSplatValue<T>());
} else {
for (auto value : attr.getValues<T>()) output->Add(value);
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the double_val field updated.
Status ConvertDoubleElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_double_val(elts.getSplatValue<double>());
} else {
for (auto value : elts.getValues<double>())
output_tensor->add_double_val(value);
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute containing half values and adds it to
// specified repeated field.
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output_tensor) {
if (attr.isSplat()) {
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
} else {
for (const llvm::APFloat value : attr.getFloatValues())
output_tensor->Add(value.bitcastToAPInt().getSExtValue());
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the float_val field updated.
Status ConvertFloatElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_float_val(elts.getSplatValue<float>());
} else {
for (auto value : elts.getValues<float>())
output_tensor->add_float_val(value);
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute containing int values and adds it to
// specified repeated field.
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
protobuf::RepeatedField<int>* output) {
if (attr.isSplat()) {
output->Add((*attr.begin()).getSExtValue());
} else {
for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the half_val field updated.
Status ConvertHalfElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_half_val(
(*elts.begin()).bitcastToAPInt().getSExtValue());
} else {
for (const auto& value : elts.getFloatValues())
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
protobuf::RepeatedField<int>* output) {
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the int_val field updated.
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_int_val((*elts.begin()).getSExtValue());
} else {
for (const auto& val : elts)
output_tensor->add_int_val(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
if (!elts) {
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Bfloat16 is internally represented as `double` in MLIR.
if (attr.isSplat()) {
double v = attr.getSplatValue<double>();
if (elts.isSplat()) {
double v = elts.getSplatValue<double>();
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
} else {
for (auto v : attr.getValues<double>()) {
for (auto v : elts.getValues<double>()) {
bfloat16 bf16_val = static_cast<bfloat16>(v);
output->Add(absl::bit_cast<int16>(bf16_val));
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
}
}
return Status::OK();
}
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with the int64_val field updated.
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
if (elts.isSplat()) {
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
} else {
for (const auto& val : elts)
output_tensor->add_int64_val(val.getSExtValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
// Converts an MLIR elements attribute to a TensorFlow tensor proto
// with bool_val field updated.
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
TensorProto* output_tensor) {
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
for (const auto& val : elts) {
output_tensor->add_bool_val(val.getBoolValue());
}
return Status::OK();
}
return ConvertOpaqueElementsAttr(attr, output_tensor);
}
Status ConvertToTensorProto(const ElementsAttr attr,
TensorProto* output_tensor) {
auto type = attr.getType();
auto shape = type.getShape();
DataType output_dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
output->set_dtype(output_dtype);
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
if (attr.isa<OpaqueElementsAttr>())
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
output_tensor->set_dtype(output_dtype);
ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
switch (output_dtype) {
case DT_FLOAT:
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
break;
return ConvertFloatElementsAttr(attr, output_tensor);
case DT_HALF:
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
output->mutable_half_val());
break;
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
return ConvertHalfElementsAttr(attr, output_tensor);
case DT_DOUBLE:
ConvertElementsAttr(dense_attr, output->mutable_double_val());
break;
return ConvertDoubleElementsAttr(attr, output_tensor);
case DT_QUINT8:
case DT_UINT8:
case DT_INT8:
@ -307,28 +366,20 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
case DT_UINT16:
case DT_INT16:
case DT_INT32:
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
output->mutable_int_val());
break;
return ConvertIntElementsAttr(attr, output_tensor);
case DT_INT64:
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
break;
return ConvertInt64ElementsAttr(attr, output_tensor);
case DT_BOOL:
ConvertElementsAttr(dense_attr, output->mutable_bool_val());
break;
return ConvertBoolElementsAttr(attr, output_tensor);
case DT_BFLOAT16:
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
output->mutable_half_val());
break;
return ConvertBfloat16ElementsAttr(attr, output_tensor);
case DT_STRING:
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
output->mutable_string_val());
break;
return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
output_tensor);
default:
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
DataTypeString(output_dtype)));
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
output_tensor);
}
return Status::OK();
}
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {