Move tf attribute import to util
PiperOrigin-RevId: 336169024 Change-Id: Iefd595261a9ccf5a99e9ad277bf823676551fbe7
This commit is contained in:
parent
34ee735c2b
commit
ebe1d5c6c0
tensorflow/compiler/mlir/tensorflow
@ -1045,6 +1045,7 @@ cc_library(
|
||||
"translate/import_model.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_attr",
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":dump_mlir_util",
|
||||
@ -1262,6 +1263,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_attr",
|
||||
srcs = ["utils/convert_attr.cc"],
|
||||
hdrs = ["utils/convert_attr.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":tensorflow_attributes",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_type",
|
||||
srcs = ["utils/convert_type.cc"],
|
||||
|
@ -74,6 +74,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
@ -310,21 +311,6 @@ class ImporterBase {
|
||||
return ::tensorflow::ConvertTensorProto(value, &builder_);
|
||||
}
|
||||
|
||||
// Converts the tensor shape proto into an MLIR shape attribute.
|
||||
StatusOr<mlir::TF::ShapeAttr> ConvertTensorShapeProto(
|
||||
const TensorShapeProto& shape) {
|
||||
if (shape.unknown_rank())
|
||||
return mlir::TF::ShapeAttr::get(builder_.getContext(), llvm::None);
|
||||
|
||||
llvm::SmallVector<int64_t, 4> dims;
|
||||
dims.reserve(shape.dim().size());
|
||||
for (const auto& dim : shape.dim()) {
|
||||
dims.push_back(dim.size());
|
||||
}
|
||||
return mlir::TF::ShapeAttr::get(builder_.getContext(),
|
||||
llvm::makeArrayRef(dims));
|
||||
}
|
||||
|
||||
// Converts func name in graphdef to mlir::SymbolRefAttribute.
|
||||
StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
|
||||
const std::string& func_name);
|
||||
@ -1142,74 +1128,36 @@ StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
|
||||
StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
|
||||
const AttrValue& value) {
|
||||
switch (value.value_case()) {
|
||||
case AttrValue::kI:
|
||||
return builder_.getI64IntegerAttr(value.i());
|
||||
case AttrValue::kS:
|
||||
return builder_.getStringAttr(value.s());
|
||||
case AttrValue::kF:
|
||||
return builder_.getFloatAttr(builder_.getF32Type(), value.f());
|
||||
case AttrValue::kB:
|
||||
return builder_.getBoolAttr(value.b());
|
||||
case AttrValue::kType: {
|
||||
mlir::Type type;
|
||||
TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
|
||||
return mlir::TypeAttr::get(type);
|
||||
}
|
||||
case AttrValue::kShape:
|
||||
return ConvertTensorShapeProto(value.shape());
|
||||
case AttrValue::kTensor:
|
||||
return ConvertTensorProto(value.tensor());
|
||||
case AttrValue::kList: {
|
||||
absl::InlinedVector<mlir::Attribute, 8> attrs;
|
||||
for (const auto& item : value.list().i())
|
||||
attrs.push_back(builder_.getI64IntegerAttr(item));
|
||||
for (const auto& item : value.list().s())
|
||||
attrs.push_back(builder_.getStringAttr(item));
|
||||
for (const auto& item : value.list().f())
|
||||
attrs.push_back(builder_.getFloatAttr(builder_.getF32Type(), item));
|
||||
for (const auto& item : value.list().b())
|
||||
attrs.push_back(builder_.getBoolAttr(item));
|
||||
for (const auto& item : value.list().type()) {
|
||||
mlir::Type type;
|
||||
TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder_, &type));
|
||||
attrs.push_back(mlir::TypeAttr::get(type));
|
||||
}
|
||||
for (const auto& item : value.list().shape()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorShapeProto(item));
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
for (const auto& item : value.list().tensor()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item));
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
for (const auto& item : value.list().func()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
|
||||
if (item.attr_size() != 0)
|
||||
return errors::Unimplemented(
|
||||
"func attributes with non-zero attr.size()");
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
return builder_.getArrayAttr(
|
||||
llvm::makeArrayRef(attrs.begin(), attrs.end()));
|
||||
}
|
||||
case AttrValue::kFunc: {
|
||||
// TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
|
||||
// Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
|
||||
// will not use this representation.
|
||||
NamedAttrList attrs;
|
||||
for (const auto& func_attr : value.func().attr()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
|
||||
attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
|
||||
}
|
||||
auto func_attrs = builder_.getDictionaryAttr(attrs);
|
||||
return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs);
|
||||
}
|
||||
case AttrValue::VALUE_NOT_SET:
|
||||
return builder_.getUnitAttr();
|
||||
// kPlaceholder is not implemented.
|
||||
case AttrValue::kList: {
|
||||
if (!value.list().func().empty()) {
|
||||
absl::InlinedVector<mlir::Attribute, 8> attrs;
|
||||
for (const auto& item : value.list().func()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
|
||||
if (item.attr_size() != 0)
|
||||
return errors::Unimplemented(
|
||||
"func attributes with non-zero attr.size()");
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
return builder_.getArrayAttr(
|
||||
llvm::makeArrayRef(attrs.begin(), attrs.end()));
|
||||
}
|
||||
return ConvertNonFuncAttributeValue(value, &builder_);
|
||||
}
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
absl::StrCat("Attribute ", value.DebugString()));
|
||||
return ConvertNonFuncAttributeValue(value, &builder_);
|
||||
}
|
||||
}
|
||||
|
||||
|
113
tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc
Normal file
113
tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc
Normal file
@ -0,0 +1,113 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts non func AttrValue proto into an MLIR attribute. Func attribute is
|
||||
// exclused in this function because the function might be renamed when the
|
||||
// function definition is imported.
|
||||
StatusOr<mlir::Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
|
||||
mlir::Builder* builder) {
|
||||
switch (value.value_case()) {
|
||||
case AttrValue::kI:
|
||||
return builder->getI64IntegerAttr(value.i());
|
||||
case AttrValue::kS:
|
||||
return builder->getStringAttr(value.s());
|
||||
case AttrValue::kF:
|
||||
return builder->getFloatAttr(builder->getF32Type(), value.f());
|
||||
case AttrValue::kB:
|
||||
return builder->getBoolAttr(value.b());
|
||||
case AttrValue::kType: {
|
||||
mlir::Type type;
|
||||
TF_RETURN_IF_ERROR(ConvertDataType(value.type(), *builder, &type));
|
||||
return mlir::TypeAttr::get(type);
|
||||
}
|
||||
case AttrValue::kShape:
|
||||
return ConvertTensorShapeProto(value.shape(), builder->getContext());
|
||||
case AttrValue::kTensor:
|
||||
return ConvertTensorProto(value.tensor(), builder);
|
||||
case AttrValue::kList: {
|
||||
absl::InlinedVector<mlir::Attribute, 8> attrs;
|
||||
for (const auto& item : value.list().i())
|
||||
attrs.push_back(builder->getI64IntegerAttr(item));
|
||||
for (const auto& item : value.list().s())
|
||||
attrs.push_back(builder->getStringAttr(item));
|
||||
for (const auto& item : value.list().f())
|
||||
attrs.push_back(builder->getFloatAttr(builder->getF32Type(), item));
|
||||
for (const auto& item : value.list().b())
|
||||
attrs.push_back(builder->getBoolAttr(item));
|
||||
for (const auto& item : value.list().type()) {
|
||||
mlir::Type type;
|
||||
TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), *builder, &type));
|
||||
attrs.push_back(mlir::TypeAttr::get(type));
|
||||
}
|
||||
for (const auto& item : value.list().shape()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto attr, ConvertTensorShapeProto(item, builder->getContext()));
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
for (const auto& item : value.list().tensor()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder));
|
||||
attrs.push_back(attr);
|
||||
}
|
||||
if (!value.list().func().empty()) {
|
||||
return tensorflow::errors::Unimplemented(
|
||||
absl::StrCat("Attribute ", value.DebugString()));
|
||||
}
|
||||
return builder->getArrayAttr(
|
||||
llvm::makeArrayRef(attrs.begin(), attrs.end()));
|
||||
}
|
||||
case AttrValue::VALUE_NOT_SET:
|
||||
return builder->getUnitAttr();
|
||||
// kPlaceholder is not implemented.
|
||||
default:
|
||||
return tensorflow::errors::Unimplemented(
|
||||
absl::StrCat("Attribute ", value.DebugString()));
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value,
|
||||
mlir::Builder* builder) {
|
||||
switch (value.value_case()) {
|
||||
case AttrValue::kFunc: {
|
||||
// TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
|
||||
// Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
|
||||
// will not use this representation.
|
||||
mlir::NamedAttrList attrs;
|
||||
for (const auto& func_attr : value.func().attr()) {
|
||||
TF_ASSIGN_OR_RETURN(auto attr,
|
||||
ConvertAttributeValue(func_attr.second, builder));
|
||||
attrs.push_back(builder->getNamedAttr(func_attr.first, attr));
|
||||
}
|
||||
auto func_attrs = builder->getDictionaryAttr(attrs);
|
||||
return mlir::TF::FuncAttr::get(builder->getContext(), value.func().name(),
|
||||
func_attrs);
|
||||
}
|
||||
default:
|
||||
return ConvertNonFuncAttributeValue(value, builder);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
39
tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h
Normal file
39
tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
// Converts non func AttrValue proto into an MLIR attribute. Func attribute is
|
||||
// exclused in this function because the function might be renamed when the
|
||||
// function definition is imported.
|
||||
StatusOr<mlir::Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
|
||||
mlir::Builder* builder);
|
||||
|
||||
// Converts all kinds of AttrValue proto into an MLIR attribute.
|
||||
StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value,
|
||||
mlir::Builder* builder);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
|
@ -214,6 +214,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
|
||||
return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef<int64_t>());
|
||||
}
|
||||
|
||||
// Converts the tensor shape proto into an MLIR shape attribute.
|
||||
StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
|
||||
mlir::MLIRContext* context) {
|
||||
if (shape.unknown_rank())
|
||||
return mlir::TF::ShapeAttr::get(context, llvm::None);
|
||||
|
||||
llvm::SmallVector<int64_t, 4> dims;
|
||||
dims.reserve(shape.dim().size());
|
||||
for (const auto& dim : shape.dim()) {
|
||||
dims.push_back(dim.size());
|
||||
}
|
||||
return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
|
||||
}
|
||||
|
||||
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
|
||||
// proto.
|
||||
void ConvertStringElementsAttr(
|
||||
|
@ -48,6 +48,10 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type);
|
||||
// Converts an MLIR shaped type to a TensorFlow shape attribute.
|
||||
mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type);
|
||||
|
||||
// Converts a TensorFlow shape attribute to an MLIR shape attribute.
|
||||
StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Converts an MLIR elements attribute to a TensorFlow tensor proto.
|
||||
Status ConvertToTensorProto(mlir::ElementsAttr attr,
|
||||
TensorProto* output_tensor);
|
||||
|
Loading…
Reference in New Issue
Block a user