Move tf attribute import to util

PiperOrigin-RevId: 336169024
Change-Id: Iefd595261a9ccf5a99e9ad277bf823676551fbe7
This commit is contained in:
Feng Liu 2020-10-08 14:45:51 -07:00 committed by TensorFlower Gardener
parent 34ee735c2b
commit ebe1d5c6c0
6 changed files with 208 additions and 71 deletions

View File

@ -1045,6 +1045,7 @@ cc_library(
"translate/import_model.h",
],
deps = [
":convert_attr",
":convert_tensor",
":convert_type",
":dump_mlir_util",
@ -1262,6 +1263,24 @@ cc_library(
],
)
cc_library(
name = "convert_attr",
srcs = ["utils/convert_attr.cc"],
hdrs = ["utils/convert_attr.h"],
visibility = [
"//visibility:public",
],
deps = [
":convert_tensor",
":convert_type",
":tensorflow_attributes",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/stream_executor/lib",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "convert_type",
srcs = ["utils/convert_type.cc"],

View File

@ -74,6 +74,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@ -310,21 +311,6 @@ class ImporterBase {
return ::tensorflow::ConvertTensorProto(value, &builder_);
}
// Converts the tensor shape proto into an MLIR shape attribute.
StatusOr<mlir::TF::ShapeAttr> ConvertTensorShapeProto(
const TensorShapeProto& shape) {
if (shape.unknown_rank())
return mlir::TF::ShapeAttr::get(builder_.getContext(), llvm::None);
llvm::SmallVector<int64_t, 4> dims;
dims.reserve(shape.dim().size());
for (const auto& dim : shape.dim()) {
dims.push_back(dim.size());
}
return mlir::TF::ShapeAttr::get(builder_.getContext(),
llvm::makeArrayRef(dims));
}
// Converts func name in graphdef to mlir::SymbolRefAttribute.
StatusOr<mlir::FlatSymbolRefAttr> ConvertFunctionCallName(
const std::string& func_name);
@ -1142,74 +1128,36 @@ StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
const AttrValue& value) {
switch (value.value_case()) {
case AttrValue::kI:
return builder_.getI64IntegerAttr(value.i());
case AttrValue::kS:
return builder_.getStringAttr(value.s());
case AttrValue::kF:
return builder_.getFloatAttr(builder_.getF32Type(), value.f());
case AttrValue::kB:
return builder_.getBoolAttr(value.b());
case AttrValue::kType: {
mlir::Type type;
TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type));
return mlir::TypeAttr::get(type);
}
case AttrValue::kShape:
return ConvertTensorShapeProto(value.shape());
case AttrValue::kTensor:
return ConvertTensorProto(value.tensor());
case AttrValue::kList: {
absl::InlinedVector<mlir::Attribute, 8> attrs;
for (const auto& item : value.list().i())
attrs.push_back(builder_.getI64IntegerAttr(item));
for (const auto& item : value.list().s())
attrs.push_back(builder_.getStringAttr(item));
for (const auto& item : value.list().f())
attrs.push_back(builder_.getFloatAttr(builder_.getF32Type(), item));
for (const auto& item : value.list().b())
attrs.push_back(builder_.getBoolAttr(item));
for (const auto& item : value.list().type()) {
mlir::Type type;
TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder_, &type));
attrs.push_back(mlir::TypeAttr::get(type));
}
for (const auto& item : value.list().shape()) {
TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorShapeProto(item));
attrs.push_back(attr);
}
for (const auto& item : value.list().tensor()) {
TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item));
attrs.push_back(attr);
}
for (const auto& item : value.list().func()) {
TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
if (item.attr_size() != 0)
return errors::Unimplemented(
"func attributes with non-zero attr.size()");
attrs.push_back(attr);
}
return builder_.getArrayAttr(
llvm::makeArrayRef(attrs.begin(), attrs.end()));
}
case AttrValue::kFunc: {
// TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
// Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
// will not use this representation.
NamedAttrList attrs;
for (const auto& func_attr : value.func().attr()) {
TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second));
TF_ASSIGN_OR_RETURN(
auto attr, ImporterBase::ConvertAttributeValue(func_attr.second));
attrs.push_back(builder_.getNamedAttr(func_attr.first, attr));
}
auto func_attrs = builder_.getDictionaryAttr(attrs);
return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs);
}
case AttrValue::VALUE_NOT_SET:
return builder_.getUnitAttr();
// kPlaceholder is not implemented.
case AttrValue::kList: {
if (!value.list().func().empty()) {
absl::InlinedVector<mlir::Attribute, 8> attrs;
for (const auto& item : value.list().func()) {
TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name()));
if (item.attr_size() != 0)
return errors::Unimplemented(
"func attributes with non-zero attr.size()");
attrs.push_back(attr);
}
return builder_.getArrayAttr(
llvm::makeArrayRef(attrs.begin(), attrs.end()));
}
return ConvertNonFuncAttributeValue(value, &builder_);
}
default:
return errors::Unimplemented(
absl::StrCat("Attribute ", value.DebugString()));
return ConvertNonFuncAttributeValue(value, &builder_);
}
}

View File

@ -0,0 +1,113 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
// Converts non func AttrValue proto into an MLIR attribute. Func attribute is
// exclused in this function because the function might be renamed when the
// function definition is imported.
StatusOr<mlir::Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
mlir::Builder* builder) {
switch (value.value_case()) {
case AttrValue::kI:
return builder->getI64IntegerAttr(value.i());
case AttrValue::kS:
return builder->getStringAttr(value.s());
case AttrValue::kF:
return builder->getFloatAttr(builder->getF32Type(), value.f());
case AttrValue::kB:
return builder->getBoolAttr(value.b());
case AttrValue::kType: {
mlir::Type type;
TF_RETURN_IF_ERROR(ConvertDataType(value.type(), *builder, &type));
return mlir::TypeAttr::get(type);
}
case AttrValue::kShape:
return ConvertTensorShapeProto(value.shape(), builder->getContext());
case AttrValue::kTensor:
return ConvertTensorProto(value.tensor(), builder);
case AttrValue::kList: {
absl::InlinedVector<mlir::Attribute, 8> attrs;
for (const auto& item : value.list().i())
attrs.push_back(builder->getI64IntegerAttr(item));
for (const auto& item : value.list().s())
attrs.push_back(builder->getStringAttr(item));
for (const auto& item : value.list().f())
attrs.push_back(builder->getFloatAttr(builder->getF32Type(), item));
for (const auto& item : value.list().b())
attrs.push_back(builder->getBoolAttr(item));
for (const auto& item : value.list().type()) {
mlir::Type type;
TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), *builder, &type));
attrs.push_back(mlir::TypeAttr::get(type));
}
for (const auto& item : value.list().shape()) {
TF_ASSIGN_OR_RETURN(
auto attr, ConvertTensorShapeProto(item, builder->getContext()));
attrs.push_back(attr);
}
for (const auto& item : value.list().tensor()) {
TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder));
attrs.push_back(attr);
}
if (!value.list().func().empty()) {
return tensorflow::errors::Unimplemented(
absl::StrCat("Attribute ", value.DebugString()));
}
return builder->getArrayAttr(
llvm::makeArrayRef(attrs.begin(), attrs.end()));
}
case AttrValue::VALUE_NOT_SET:
return builder->getUnitAttr();
// kPlaceholder is not implemented.
default:
return tensorflow::errors::Unimplemented(
absl::StrCat("Attribute ", value.DebugString()));
}
}
StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value,
mlir::Builder* builder) {
switch (value.value_case()) {
case AttrValue::kFunc: {
// TODO(b/156546237): Unify kFunc/NameAttrList attribute representation.
// Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue
// will not use this representation.
mlir::NamedAttrList attrs;
for (const auto& func_attr : value.func().attr()) {
TF_ASSIGN_OR_RETURN(auto attr,
ConvertAttributeValue(func_attr.second, builder));
attrs.push_back(builder->getNamedAttr(func_attr.first, attr));
}
auto func_attrs = builder->getDictionaryAttr(attrs);
return mlir::TF::FuncAttr::get(builder->getContext(), value.func().name(),
func_attrs);
}
default:
return ConvertNonFuncAttributeValue(value, builder);
}
}
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
using stream_executor::port::StatusOr;
// Converts non func AttrValue proto into an MLIR attribute. Func attribute is
// exclused in this function because the function might be renamed when the
// function definition is imported.
StatusOr<mlir::Attribute> ConvertNonFuncAttributeValue(const AttrValue& value,
mlir::Builder* builder);
// Converts all kinds of AttrValue proto into an MLIR attribute.
StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value,
mlir::Builder* builder);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_

View File

@ -214,6 +214,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) {
return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef<int64_t>());
}
// Converts the tensor shape proto into an MLIR shape attribute.
StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
mlir::MLIRContext* context) {
if (shape.unknown_rank())
return mlir::TF::ShapeAttr::get(context, llvm::None);
llvm::SmallVector<int64_t, 4> dims;
dims.reserve(shape.dim().size());
for (const auto& dim : shape.dim()) {
dims.push_back(dim.size());
}
return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims));
}
// Converts an MLIR dense string elements attribute to a TensorFlow tensor
// proto.
void ConvertStringElementsAttr(

View File

@ -48,6 +48,10 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type);
// Converts an MLIR shaped type to a TensorFlow shape attribute.
mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type);
// Converts a TensorFlow shape attribute to an MLIR shape attribute.
StatusOr<mlir::Attribute> ConvertTensorShapeProto(const TensorShapeProto& shape,
mlir::MLIRContext* context);
// Converts an MLIR elements attribute to a TensorFlow tensor proto.
Status ConvertToTensorProto(mlir::ElementsAttr attr,
TensorProto* output_tensor);