Refactor ConvertToTensorProto to avoid some duplication
Share implementation to populate tensors unless the element type requires special handling. PiperOrigin-RevId: 310820768 Change-Id: Ibdf11da1e9e41b3f2ddb10a43563da244f044b62
This commit is contained in:
parent
4ce786df35
commit
02b34b853d
@ -37,7 +37,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/platform/tstring.h"
|
#include "tensorflow/core/platform/tstring.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
@ -208,11 +207,12 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
|
|||||||
|
|
||||||
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
||||||
// proto.
|
// proto.
|
||||||
void ConvertStringElementsAttr(
|
Status ConvertStringElementsAttr(const DenseStringElementsAttr attr,
|
||||||
const DenseStringElementsAttr attr,
|
TensorProto* output_tensor) {
|
||||||
protobuf::RepeatedPtrField<std::string>* output) {
|
for (const auto& val : attr.getRawStringData()) {
|
||||||
for (const auto& val : attr.getRawStringData())
|
output_tensor->add_string_val(val.data(), val.size());
|
||||||
output->Add({val.data(), val.size()});
|
}
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
|
||||||
@ -226,80 +226,139 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
|
|||||||
return InvalidArgument("Unexpected elements attribute type from MLIR.");
|
return InvalidArgument("Unexpected elements attribute type from MLIR.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute and adds it to specified repeated field.
|
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||||
template <typename T>
|
// with the double_val field updated.
|
||||||
void ConvertElementsAttr(const mlir::DenseElementsAttr attr,
|
Status ConvertDoubleElementsAttr(const ElementsAttr attr,
|
||||||
protobuf::RepeatedField<T>* output) {
|
TensorProto* output_tensor) {
|
||||||
if (attr.isSplat()) {
|
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||||
output->Add(attr.getSplatValue<T>());
|
if (elts.isSplat()) {
|
||||||
|
output_tensor->add_double_val(elts.getSplatValue<double>());
|
||||||
} else {
|
} else {
|
||||||
for (auto value : attr.getValues<T>()) output->Add(value);
|
for (auto value : elts.getValues<double>())
|
||||||
|
output_tensor->add_double_val(value);
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute containing half values and adds it to
|
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||||
// specified repeated field.
|
// with the float_val field updated.
|
||||||
void ConvertHalfElementsAttr(const DenseFPElementsAttr attr,
|
Status ConvertFloatElementsAttr(const ElementsAttr attr,
|
||||||
protobuf::RepeatedField<int>* output_tensor) {
|
TensorProto* output_tensor) {
|
||||||
if (attr.isSplat()) {
|
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||||
output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue());
|
if (elts.isSplat()) {
|
||||||
|
output_tensor->add_float_val(elts.getSplatValue<float>());
|
||||||
} else {
|
} else {
|
||||||
for (const llvm::APFloat value : attr.getFloatValues())
|
for (auto value : elts.getValues<float>())
|
||||||
output_tensor->Add(value.bitcastToAPInt().getSExtValue());
|
output_tensor->add_float_val(value);
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an MLIR elements attribute containing int values and adds it to
|
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||||
// specified repeated field.
|
// with the half_val field updated.
|
||||||
void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr,
|
Status ConvertHalfElementsAttr(const ElementsAttr attr,
|
||||||
protobuf::RepeatedField<int>* output) {
|
TensorProto* output_tensor) {
|
||||||
if (attr.isSplat()) {
|
if (auto elts = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||||
output->Add((*attr.begin()).getSExtValue());
|
if (elts.isSplat()) {
|
||||||
|
output_tensor->add_half_val(
|
||||||
|
(*elts.begin()).bitcastToAPInt().getSExtValue());
|
||||||
} else {
|
} else {
|
||||||
for (const llvm::APInt val : attr) output->Add(val.getSExtValue());
|
for (const auto& value : elts.getFloatValues())
|
||||||
|
output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue());
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr,
|
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||||
protobuf::RepeatedField<int>* output) {
|
// with the int_val field updated.
|
||||||
|
Status ConvertIntElementsAttr(const mlir::ElementsAttr attr,
|
||||||
|
TensorProto* output_tensor) {
|
||||||
|
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||||
|
if (elts.isSplat()) {
|
||||||
|
output_tensor->add_int_val((*elts.begin()).getSExtValue());
|
||||||
|
} else {
|
||||||
|
for (const auto& val : elts)
|
||||||
|
output_tensor->add_int_val(val.getSExtValue());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr,
|
||||||
|
TensorProto* output_tensor) {
|
||||||
|
auto elts = attr.dyn_cast<DenseFPElementsAttr>();
|
||||||
|
if (!elts) {
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
// Bfloat16 is internally represented as `double` in MLIR.
|
// Bfloat16 is internally represented as `double` in MLIR.
|
||||||
if (attr.isSplat()) {
|
if (elts.isSplat()) {
|
||||||
double v = attr.getSplatValue<double>();
|
double v = elts.getSplatValue<double>();
|
||||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
||||||
} else {
|
} else {
|
||||||
for (auto v : attr.getValues<double>()) {
|
for (auto v : elts.getValues<double>()) {
|
||||||
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
bfloat16 bf16_val = static_cast<bfloat16>(v);
|
||||||
output->Add(absl::bit_cast<int16>(bf16_val));
|
output_tensor->add_half_val(absl::bit_cast<int16>(bf16_val));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||||
|
// with the int64_val field updated.
|
||||||
|
Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr,
|
||||||
|
TensorProto* output_tensor) {
|
||||||
|
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||||
|
if (elts.isSplat()) {
|
||||||
|
output_tensor->add_int64_val((*elts.begin()).getSExtValue());
|
||||||
|
} else {
|
||||||
|
for (const auto& val : elts)
|
||||||
|
output_tensor->add_int64_val(val.getSExtValue());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts an MLIR elements attribute to a TensorFlow tensor proto
|
||||||
|
// with bool_val field updated.
|
||||||
|
Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr,
|
||||||
|
TensorProto* output_tensor) {
|
||||||
|
if (auto elts = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||||
|
for (const auto& val : elts) {
|
||||||
|
output_tensor->add_bool_val(val.getBoolValue());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
return ConvertOpaqueElementsAttr(attr, output_tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ConvertToTensorProto(const ElementsAttr attr,
|
||||||
|
TensorProto* output_tensor) {
|
||||||
auto type = attr.getType();
|
auto type = attr.getType();
|
||||||
auto shape = type.getShape();
|
auto shape = type.getShape();
|
||||||
DataType output_dtype;
|
DataType output_dtype;
|
||||||
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
|
TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype));
|
||||||
output->set_dtype(output_dtype);
|
output_tensor->set_dtype(output_dtype);
|
||||||
ConvertToTensorShapeProto(shape, output->mutable_tensor_shape());
|
ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape());
|
||||||
|
|
||||||
if (attr.isa<OpaqueElementsAttr>())
|
|
||||||
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(), output);
|
|
||||||
|
|
||||||
auto dense_attr = attr.dyn_cast<mlir::DenseElementsAttr>();
|
|
||||||
if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr");
|
|
||||||
|
|
||||||
switch (output_dtype) {
|
switch (output_dtype) {
|
||||||
case DT_FLOAT:
|
case DT_FLOAT:
|
||||||
ConvertElementsAttr<float>(dense_attr, output->mutable_float_val());
|
return ConvertFloatElementsAttr(attr, output_tensor);
|
||||||
break;
|
|
||||||
case DT_HALF:
|
case DT_HALF:
|
||||||
ConvertHalfElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
// Handles both DenseFPElementsAttr and OpaqueElementsAttr.
|
||||||
output->mutable_half_val());
|
return ConvertHalfElementsAttr(attr, output_tensor);
|
||||||
break;
|
|
||||||
case DT_DOUBLE:
|
case DT_DOUBLE:
|
||||||
ConvertElementsAttr(dense_attr, output->mutable_double_val());
|
return ConvertDoubleElementsAttr(attr, output_tensor);
|
||||||
break;
|
|
||||||
case DT_QUINT8:
|
case DT_QUINT8:
|
||||||
case DT_UINT8:
|
case DT_UINT8:
|
||||||
case DT_INT8:
|
case DT_INT8:
|
||||||
@ -307,28 +366,20 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
|
|||||||
case DT_UINT16:
|
case DT_UINT16:
|
||||||
case DT_INT16:
|
case DT_INT16:
|
||||||
case DT_INT32:
|
case DT_INT32:
|
||||||
ConvertIntElementsAttr(dense_attr.cast<DenseIntElementsAttr>(),
|
return ConvertIntElementsAttr(attr, output_tensor);
|
||||||
output->mutable_int_val());
|
|
||||||
break;
|
|
||||||
case DT_INT64:
|
case DT_INT64:
|
||||||
ConvertElementsAttr(dense_attr, output->mutable_int64_val());
|
return ConvertInt64ElementsAttr(attr, output_tensor);
|
||||||
break;
|
|
||||||
case DT_BOOL:
|
case DT_BOOL:
|
||||||
ConvertElementsAttr(dense_attr, output->mutable_bool_val());
|
return ConvertBoolElementsAttr(attr, output_tensor);
|
||||||
break;
|
|
||||||
case DT_BFLOAT16:
|
case DT_BFLOAT16:
|
||||||
ConvertBfloat16ElementsAttr(dense_attr.cast<DenseFPElementsAttr>(),
|
return ConvertBfloat16ElementsAttr(attr, output_tensor);
|
||||||
output->mutable_half_val());
|
|
||||||
break;
|
|
||||||
case DT_STRING:
|
case DT_STRING:
|
||||||
ConvertStringElementsAttr(dense_attr.cast<DenseStringElementsAttr>(),
|
return ConvertStringElementsAttr(attr.cast<DenseStringElementsAttr>(),
|
||||||
output->mutable_string_val());
|
output_tensor);
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
return errors::Unimplemented(absl::StrCat("Unimplemented data type ",
|
return ConvertOpaqueElementsAttr(attr.cast<OpaqueElementsAttr>(),
|
||||||
DataTypeString(output_dtype)));
|
output_tensor);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
|
Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) {
|
||||||
|
Loading…
Reference in New Issue
Block a user