Convert appropriate composite functions to TFLite custom NMS op.
PiperOrigin-RevId: 328347136 Change-Id: I199e19c04818ef5536dd3ebd7f480f46f37a02b4
This commit is contained in:
parent
03d511ce8f
commit
c66e39713e
@ -280,6 +280,28 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "nms_utils",
|
||||||
|
srcs = [
|
||||||
|
"utils/nms_utils.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"utils/nms_utils.h",
|
||||||
|
],
|
||||||
|
copts = ["-std=c++14"],
|
||||||
|
deps = [
|
||||||
|
":tensorflow_lite",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"@flatbuffers",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
"@llvm-project//mlir:Support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tftext_utils",
|
name = "tftext_utils",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -373,6 +395,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":constant_utils",
|
":constant_utils",
|
||||||
":lstm_utils",
|
":lstm_utils",
|
||||||
|
":nms_utils",
|
||||||
":stateful_ops_utils",
|
":stateful_ops_utils",
|
||||||
":tensorflow_lite",
|
":tensorflow_lite",
|
||||||
":tftext_utils",
|
":tftext_utils",
|
||||||
|
@ -520,3 +520,42 @@ func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
|
|||||||
return %0 : tensor<100xf32>
|
return %0 : tensor<100xf32>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
|
||||||
|
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @tflite_custom_nms(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x100x4xf32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x100x91xf32>,
|
||||||
|
// CHECK-SAME: %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} {
|
||||||
|
// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = opaque<"tfl", "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601"> : tensor<217xi8>} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
|
||||||
|
// CHECK: return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||||
|
// CHECK: }
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
|
||||||
|
func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||||
|
|
||||||
|
// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
|
||||||
|
func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||||
|
|
||||||
|
// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}}
|
||||||
|
func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
|
||||||
|
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
%3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -42,6 +42,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
|
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
|
||||||
|
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
@ -59,6 +60,7 @@ namespace {
|
|||||||
|
|
||||||
constexpr char kTFAPIImplements[] = "tf.api_implements";
|
constexpr char kTFAPIImplements[] = "tf.api_implements";
|
||||||
constexpr char kTFTextAPIPrefix[] = "tftext:";
|
constexpr char kTFTextAPIPrefix[] = "tftext:";
|
||||||
|
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
|
||||||
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
|
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
|
||||||
|
|
||||||
using mlir::TF::FuncAttr;
|
using mlir::TF::FuncAttr;
|
||||||
@ -99,59 +101,6 @@ class ConvertEmbeddedLookupFunc {
|
|||||||
FuncOp func_;
|
FuncOp func_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Abstracts the conversion of the padded NMS composite function.
|
|
||||||
class ConvertNMSPaddedFunc {
|
|
||||||
public:
|
|
||||||
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
|
|
||||||
|
|
||||||
void RewriteFunc() {
|
|
||||||
func_.setAttr(kTFImplements,
|
|
||||||
StringAttr::get(kTfNMSPadded, func_.getContext()));
|
|
||||||
Value boxes = func_.getArgument(0);
|
|
||||||
Value scores = func_.getArgument(1);
|
|
||||||
Value max_output_size = func_.getArgument(2);
|
|
||||||
Value iou_threshold = func_.getArgument(3);
|
|
||||||
Value score_threshold = func_.getArgument(4);
|
|
||||||
auto output_type0 = func_.getType().getResult(0);
|
|
||||||
auto output_type1 = func_.getType().getResult(1);
|
|
||||||
|
|
||||||
OpBuilder builder(func_.getBody());
|
|
||||||
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
|
|
||||||
func_.getLoc(), output_type0, output_type1, boxes, scores,
|
|
||||||
max_output_size, iou_threshold, score_threshold);
|
|
||||||
|
|
||||||
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult VerifySignature() {
|
|
||||||
// Verify high-level function signature.
|
|
||||||
// Relevant argument characteristics are checked by the TFL op definition.
|
|
||||||
if (func_.getNumArguments() < 5) {
|
|
||||||
return func_.emitError()
|
|
||||||
<< "Invalid number of arguments to "
|
|
||||||
"non_max_suppression_padded_v2 (need atleast 5): "
|
|
||||||
<< func_.getNumArguments();
|
|
||||||
}
|
|
||||||
if (func_.getType().getNumResults() != 2) {
|
|
||||||
return func_.emitError() << "Invalid number of results from "
|
|
||||||
"non_max_suppression_padded_v2 (need 2): "
|
|
||||||
<< func_.getType().getNumResults();
|
|
||||||
}
|
|
||||||
// The TFLite fused op does not support batching yet.
|
|
||||||
// TODO(b/158709815): Add support for batches with padded NMS.
|
|
||||||
auto boxes_type =
|
|
||||||
func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
|
|
||||||
return func_.emitError() << "TFLite does not support batched input for "
|
|
||||||
"non_max_suppression_padded";
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
FuncOp func_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// This pass uses mechanisms listed in RFC:
|
// This pass uses mechanisms listed in RFC:
|
||||||
// https://github.com/tensorflow/community/pull/113
|
// https://github.com/tensorflow/community/pull/113
|
||||||
// It prepares composite functions that are attributed to indicate
|
// It prepares composite functions that are attributed to indicate
|
||||||
@ -219,6 +168,12 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
|
|||||||
if (failed(ConvertTFTextAPI(func, api_name, attr))) {
|
if (failed(ConvertTFTextAPI(func, api_name, attr))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
} else if (api_name == kCustomSSDPostprocessing) {
|
||||||
|
ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
|
||||||
|
if (failed(convert_ssd_postprocess.VerifySignature()) ||
|
||||||
|
failed(convert_ssd_postprocess.RewriteFunc())) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
174
tensorflow/compiler/mlir/lite/utils/nms_utils.cc
Normal file
174
tensorflow/compiler/mlir/lite/utils/nms_utils.cc
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TFL {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
|
||||||
|
constexpr char kTFImplements[] = "tf._implements";
|
||||||
|
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
|
||||||
|
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
|
||||||
|
|
||||||
|
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
|
||||||
|
const std::string& content) {
|
||||||
|
ShapedType type = RankedTensorType::get(
|
||||||
|
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
||||||
|
return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
|
||||||
|
type,
|
||||||
|
StringRef(content.data(), content.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ConvertNMSPaddedFunc::RewriteFunc() {
|
||||||
|
func_.setAttr(kTFImplements,
|
||||||
|
StringAttr::get(kTfNMSPadded, func_.getContext()));
|
||||||
|
Value boxes = func_.getArgument(0);
|
||||||
|
Value scores = func_.getArgument(1);
|
||||||
|
Value max_output_size = func_.getArgument(2);
|
||||||
|
Value iou_threshold = func_.getArgument(3);
|
||||||
|
Value score_threshold = func_.getArgument(4);
|
||||||
|
auto output_type0 = func_.getType().getResult(0);
|
||||||
|
auto output_type1 = func_.getType().getResult(1);
|
||||||
|
|
||||||
|
OpBuilder builder(func_.getBody());
|
||||||
|
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
|
||||||
|
func_.getLoc(), output_type0, output_type1, boxes, scores,
|
||||||
|
max_output_size, iou_threshold, score_threshold);
|
||||||
|
|
||||||
|
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
|
||||||
|
// Verify high-level function signature.
|
||||||
|
// Relevant argument characteristics are checked by the TFL op definition.
|
||||||
|
if (func_.getNumArguments() < 5) {
|
||||||
|
return func_.emitError()
|
||||||
|
<< "Invalid number of arguments to "
|
||||||
|
"non_max_suppression_padded_v2 (need atleast 5): "
|
||||||
|
<< func_.getNumArguments();
|
||||||
|
}
|
||||||
|
if (func_.getType().getNumResults() != 2) {
|
||||||
|
return func_.emitError() << "Invalid number of results from "
|
||||||
|
"non_max_suppression_padded_v2 (need 2): "
|
||||||
|
<< func_.getType().getNumResults();
|
||||||
|
}
|
||||||
|
// The TFLite fused op does not support batching yet.
|
||||||
|
// TODO(b/158709815): Add support for batches with padded NMS.
|
||||||
|
auto boxes_type = func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
|
||||||
|
return func_.emitError() << "TFLite does not support batched input for "
|
||||||
|
"non_max_suppression_padded";
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
|
||||||
|
func_.eraseBody();
|
||||||
|
func_.addEntryBlock();
|
||||||
|
func_.setAttr(kTFImplements,
|
||||||
|
StringAttr::get(kCustomSSDPostprocessing, func_.getContext()));
|
||||||
|
|
||||||
|
OpBuilder builder(func_.getBody());
|
||||||
|
std::string custom_option_buffer;
|
||||||
|
if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(),
|
||||||
|
custom_option_buffer))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto op = builder.create<CustomOp>(
|
||||||
|
func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
|
||||||
|
kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
|
||||||
|
builder.create<ReturnOp>(func_.getLoc(), op.getResults());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
|
||||||
|
FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
|
||||||
|
flexbuffers::Builder fbb;
|
||||||
|
size_t start_map = fbb.StartMap();
|
||||||
|
|
||||||
|
if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
|
||||||
|
failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
|
||||||
|
failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
|
||||||
|
failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
|
||||||
|
failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
|
||||||
|
failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
|
||||||
|
failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
|
||||||
|
failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
|
||||||
|
failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
|
||||||
|
return failure();
|
||||||
|
auto use_regular_nms =
|
||||||
|
attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
|
||||||
|
if (!use_regular_nms) {
|
||||||
|
return func.emitError()
|
||||||
|
<< "use_regular_nms attribute is not set or not a bool";
|
||||||
|
}
|
||||||
|
fbb.Int("use_regular_nms", use_regular_nms.getValue());
|
||||||
|
|
||||||
|
fbb.EndMap(start_map);
|
||||||
|
fbb.Finish();
|
||||||
|
custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
|
||||||
|
FuncOp func, DictionaryAttr attrs, const std::string& attribute,
|
||||||
|
flexbuffers::Builder* builder) {
|
||||||
|
auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
|
||||||
|
if (!int_attr) {
|
||||||
|
return func.emitError()
|
||||||
|
<< attribute.c_str() << " attribute is not set or not an integer";
|
||||||
|
}
|
||||||
|
builder->Int(attribute.c_str(), int_attr.getInt());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
|
||||||
|
FuncOp func, DictionaryAttr attrs, const std::string& attribute,
|
||||||
|
flexbuffers::Builder* builder) {
|
||||||
|
auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
|
||||||
|
if (!float_attr) {
|
||||||
|
return func.emitError()
|
||||||
|
<< attribute.c_str() << " attribute is not set or not a float";
|
||||||
|
}
|
||||||
|
builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
|
||||||
|
// Verify high-level function signature.
|
||||||
|
if (func_.getNumArguments() != 3) {
|
||||||
|
return func_.emitError()
|
||||||
|
<< "Invalid number of arguments to " << kCustomSSDPostprocessing
|
||||||
|
<< ": " << func_.getNumArguments();
|
||||||
|
}
|
||||||
|
if (func_.getType().getNumResults() != 4) {
|
||||||
|
return func_.emitError()
|
||||||
|
<< "Invalid number of results from " << kCustomSSDPostprocessing
|
||||||
|
<< ": " << func_.getType().getNumResults();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace TFL
|
||||||
|
} // namespace mlir
|
76
tensorflow/compiler/mlir/lite/utils/nms_utils.h
Normal file
76
tensorflow/compiler/mlir/lite/utils/nms_utils.h
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This header file defines common utils used by TFLite transformation
|
||||||
|
// passes to work with NMS ops in TFLite.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Function.h" // from @llvm-project
|
||||||
|
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace TFL {
|
||||||
|
|
||||||
|
// Abstracts the conversion of the padded NMS composite function.
|
||||||
|
class ConvertNMSPaddedFunc {
|
||||||
|
public:
|
||||||
|
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
|
||||||
|
|
||||||
|
void RewriteFunc();
|
||||||
|
|
||||||
|
LogicalResult VerifySignature();
|
||||||
|
|
||||||
|
private:
|
||||||
|
FuncOp func_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Abstracts the conversion of the SSD post-processing composite function to
|
||||||
|
// TFLite.
|
||||||
|
class ConvertSSDPostProcessFunc {
|
||||||
|
public:
|
||||||
|
explicit ConvertSSDPostProcessFunc(FuncOp func, mlir::TF::FuncAttr attr)
|
||||||
|
: func_(func), attr_(attr) {}
|
||||||
|
|
||||||
|
LogicalResult RewriteFunc();
|
||||||
|
|
||||||
|
LogicalResult VerifySignature();
|
||||||
|
|
||||||
|
private:
|
||||||
|
LogicalResult CreateNMSCustomOptions(FuncOp func, DictionaryAttr attrs,
|
||||||
|
std::string& custom_option_buffer);
|
||||||
|
|
||||||
|
LogicalResult AddIntAttr(FuncOp func, DictionaryAttr attrs,
|
||||||
|
const std::string& attribute,
|
||||||
|
flexbuffers::Builder* builder);
|
||||||
|
|
||||||
|
LogicalResult AddFloatAttr(FuncOp func, DictionaryAttr attrs,
|
||||||
|
const std::string& attribute,
|
||||||
|
flexbuffers::Builder* builder);
|
||||||
|
|
||||||
|
FuncOp func_;
|
||||||
|
mlir::TF::FuncAttr attr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace TFL
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
|
Loading…
x
Reference in New Issue
Block a user