Refactor ConvertToTensorProto to avoid some duplication
Share implementation to populate tensors unless the element type requires special handling. PiperOrigin-RevId: 310820768 Change-Id: Ibdf11da1e9e41b3f2ddb10a43563da244f044b62
This commit is contained in:
parent
4ce786df35
commit
02b34b853d
@ -37,7 +37,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
@ -208,11 +207,12 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
|
||||
|
||||
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
||||
// proto.
|
||||
void ConvertStringElementsAttr(
|
||||
const DenseStringElementsAttr attr,
|
||||
protobuf::RepeatedPtrField<std::string>* output) {
|
||||
for (const auto& val : attr.getRawStringData())
|
||||
output->Add({val.data(), val.size()});
|
||||
Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
for (const auto& val : attr.getRawStringData()) {
|
||||
output_tensor->add_string_val(val.data(), val.size());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||
@ -226,80 +226,139 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
||||
return InvalidArgument("Unexpected elements attribute type from MLIR.");
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute and adds it to specified repeated field.
|
||||
template <typename T>
|
||||
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
|
||||
protobuf::RepeatedField<T>* output) {
|
||||
if (attr.isSplat()) {
|
||||
output->Add(attr.getSplatValue<T>());
|
||||
} else {
|
||||
for (auto value : attr.getValues<T>()) output->Add(value);
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the double_val field updated.
|
||||
Status ConvertDoubleElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_double_val(elts.getSplatValue<double>());
|
||||
} else {
|
||||
for (auto value : elts.getValues<double>())
|
||||
output_tensor->add_double_val(value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute containing half values and adds it to
|
||||
// specified repeated field.
|
||||
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output_tensor) {
|
||||
if (attr.isSplat()) {
|
||||
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
|
||||
} else {
|
||||
for (const llvm::APFloat value : attr.getFloatValues())
|
||||
output_tensor->Add(value.bitcastToAPInt().getSExtValue());
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the float_val field updated.
|
||||
Status ConvertFloatElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_float_val(elts.getSplatValue<float>());
|
||||
} else {
|
||||
for (auto value : elts.getValues<float>())
|
||||
output_tensor->add_float_val(value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute containing int values and adds it to
|
||||
// specified repeated field.
|
||||
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output) {
|
||||
if (attr.isSplat()) {
|
||||
output->Add((*attr.begin()).getSExtValue());
|
||||
} else {
|
||||
for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the half_val field updated.
|
||||
Status ConvertHalfElementsAttr(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_half_val(
|
||||
(*elts.begin()).bitcastToAPInt().getSExtValue());
|
||||
} else {
|
||||
for (const auto& value : elts.getFloatValues())
|
||||
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
||||
protobuf::RepeatedField<int>* output) {
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the int_val field updated.
|
||||
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_int_val((*elts.begin()).getSExtValue());
|
||||
} else {
|
||||
for (const auto& val : elts)
|
||||
output_tensor->add_int_val(val.getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
|
||||
if (!elts) {
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Bfloat16 is internally represented as `double` in MLIR.
|
||||
if (attr.isSplat()) {
|
||||
double v = attr.getSplatValue<double>();
|
||||
if (elts.isSplat()) {
|
||||
double v = elts.getSplatValue<double>();
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
||||
} else {
|
||||
for (auto v : attr.getValues<double>()) {
|
||||
for (auto v : elts.getValues<double>()) {
|
||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
||||
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with the int64_val field updated.
|
||||
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (elts.isSplat()) {
|
||||
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
|
||||
} else {
|
||||
for (const auto& val : elts)
|
||||
output_tensor->add_int64_val(val.getSExtValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||
// with bool_val field updated.
|
||||
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
for (const auto& val : elts) {
|
||||
output_tensor->add_bool_val(val.getBoolValue());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||
}
|
||||
|
||||
Status ConvertToTensorProto(const ElementsAttr attr,
|
||||
TensorProto* output_tensor) {
|
||||
auto type = attr.getType();
|
||||
auto shape = type.getShape();
|
||||
DataType output_dtype;
|
||||
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
|
||||
output->set_dtype(output_dtype);
|
||||
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
|
||||
|
||||
if (attr.isa<OpaqueElementsAttr>())
|
||||
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
|
||||
|
||||
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
|
||||
output_tensor->set_dtype(output_dtype);
|
||||
ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
|
||||
|
||||
switch (output_dtype) {
|
||||
case DT_FLOAT:
|
||||
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
|
||||
break;
|
||||
return ConvertFloatElementsAttr(attr, output_tensor);
|
||||
case DT_HALF:
|
||||
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
||||
output->mutable_half_val());
|
||||
break;
|
||||
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
|
||||
return ConvertHalfElementsAttr(attr, output_tensor);
|
||||
case DT_DOUBLE:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_double_val());
|
||||
break;
|
||||
return ConvertDoubleElementsAttr(attr, output_tensor);
|
||||
case DT_QUINT8:
|
||||
case DT_UINT8:
|
||||
case DT_INT8:
|
||||
@ -307,28 +366,20 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
||||
case DT_UINT16:
|
||||
case DT_INT16:
|
||||
case DT_INT32:
|
||||
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
|
||||
output->mutable_int_val());
|
||||
break;
|
||||
return ConvertIntElementsAttr(attr, output_tensor);
|
||||
case DT_INT64:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
|
||||
break;
|
||||
return ConvertInt64ElementsAttr(attr, output_tensor);
|
||||
case DT_BOOL:
|
||||
ConvertElementsAttr(dense_attr, output->mutable_bool_val());
|
||||
break;
|
||||
return ConvertBoolElementsAttr(attr, output_tensor);
|
||||
case DT_BFLOAT16:
|
||||
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
||||
output->mutable_half_val());
|
||||
break;
|
||||
return ConvertBfloat16ElementsAttr(attr, output_tensor);
|
||||
case DT_STRING:
|
||||
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
|
||||
output->mutable_string_val());
|
||||
break;
|
||||
return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
|
||||
output_tensor);
|
||||
default:
|
||||
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
|
||||
DataTypeString(output_dtype)));
|
||||
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
|
||||
output_tensor);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
|
||||
|
Loading…
Reference in New Issue
Block a user