Convert appropriate composite functions to TFLite custom NMS op.

PiperOrigin-RevId: 328347136
Change-Id: I199e19c04818ef5536dd3ebd7f480f46f37a02b4
This commit is contained in:
Sachin Joglekar 2020-08-25 09:41:17 -07:00 committed by TensorFlower Gardener
parent 03d511ce8f
commit c66e39713e
5 changed files with 320 additions and 53 deletions

View File

@ -280,6 +280,28 @@ cc_library(
], ],
) )
cc_library(
name = "nms_utils",
srcs = [
"utils/nms_utils.cc",
],
hdrs = [
"utils/nms_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/core:framework",
"@flatbuffers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library( cc_library(
name = "tftext_utils", name = "tftext_utils",
srcs = [ srcs = [
@ -373,6 +395,7 @@ cc_library(
deps = [ deps = [
":constant_utils", ":constant_utils",
":lstm_utils", ":lstm_utils",
":nms_utils",
":stateful_ops_utils", ":stateful_ops_utils",
":tensorflow_lite", ":tensorflow_lite",
":tftext_utils", ":tftext_utils",

View File

@ -520,3 +520,42 @@ func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
return %0 : tensor<100xf32> return %0 : tensor<100xf32>
} }
} }
// -----
module {
func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}
// CHECK-LABEL: func @tflite_custom_nms(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x100x4xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x100x91xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} {
// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = opaque<"tfl", "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601"> : tensor<217xi8>} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
// CHECK: return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
// CHECK: }
}
// -----
module {
// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}}
func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}
}

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -59,6 +60,7 @@ namespace {
constexpr char kTFAPIImplements[] = "tf.api_implements"; constexpr char kTFAPIImplements[] = "tf.api_implements";
constexpr char kTFTextAPIPrefix[] = "tftext:"; constexpr char kTFTextAPIPrefix[] = "tftext:";
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
using mlir::TF::FuncAttr; using mlir::TF::FuncAttr;
@ -99,59 +101,6 @@ class ConvertEmbeddedLookupFunc {
FuncOp func_; FuncOp func_;
}; };
// Abstracts the conversion of the padded NMS composite function.
class ConvertNMSPaddedFunc {
public:
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
void RewriteFunc() {
func_.setAttr(kTFImplements,
StringAttr::get(kTfNMSPadded, func_.getContext()));
Value boxes = func_.getArgument(0);
Value scores = func_.getArgument(1);
Value max_output_size = func_.getArgument(2);
Value iou_threshold = func_.getArgument(3);
Value score_threshold = func_.getArgument(4);
auto output_type0 = func_.getType().getResult(0);
auto output_type1 = func_.getType().getResult(1);
OpBuilder builder(func_.getBody());
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
func_.getLoc(), output_type0, output_type1, boxes, scores,
max_output_size, iou_threshold, score_threshold);
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
}
LogicalResult VerifySignature() {
// Verify high-level function signature.
// Relevant argument characteristics are checked by the TFL op definition.
if (func_.getNumArguments() < 5) {
return func_.emitError()
<< "Invalid number of arguments to "
"non_max_suppression_padded_v2 (need atleast 5): "
<< func_.getNumArguments();
}
if (func_.getType().getNumResults() != 2) {
return func_.emitError() << "Invalid number of results from "
"non_max_suppression_padded_v2 (need 2): "
<< func_.getType().getNumResults();
}
// The TFLite fused op does not support batching yet.
// TODO(b/158709815): Add support for batches with padded NMS.
auto boxes_type =
func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
return func_.emitError() << "TFLite does not support batched input for "
"non_max_suppression_padded";
}
return success();
}
private:
FuncOp func_;
};
// This pass uses mechanisms listed in RFC: // This pass uses mechanisms listed in RFC:
// https://github.com/tensorflow/community/pull/113 // https://github.com/tensorflow/community/pull/113
// It prepares composite functions that are attributed to indicate // It prepares composite functions that are attributed to indicate
@ -219,6 +168,12 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
if (failed(ConvertTFTextAPI(func, api_name, attr))) { if (failed(ConvertTFTextAPI(func, api_name, attr))) {
return signalPassFailure(); return signalPassFailure();
} }
} else if (api_name == kCustomSSDPostprocessing) {
ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
if (failed(convert_ssd_postprocess.VerifySignature()) ||
failed(convert_ssd_postprocess.RewriteFunc())) {
return signalPassFailure();
}
} }
} }

View File

@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
namespace {
// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
constexpr char kTFImplements[] = "tf._implements";
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
const std::string& content) {
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
type,
StringRef(content.data(), content.size()));
}
} // namespace
void ConvertNMSPaddedFunc::RewriteFunc() {
func_.setAttr(kTFImplements,
StringAttr::get(kTfNMSPadded, func_.getContext()));
Value boxes = func_.getArgument(0);
Value scores = func_.getArgument(1);
Value max_output_size = func_.getArgument(2);
Value iou_threshold = func_.getArgument(3);
Value score_threshold = func_.getArgument(4);
auto output_type0 = func_.getType().getResult(0);
auto output_type1 = func_.getType().getResult(1);
OpBuilder builder(func_.getBody());
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
func_.getLoc(), output_type0, output_type1, boxes, scores,
max_output_size, iou_threshold, score_threshold);
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
}
LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
// Verify high-level function signature.
// Relevant argument characteristics are checked by the TFL op definition.
if (func_.getNumArguments() < 5) {
return func_.emitError()
<< "Invalid number of arguments to "
"non_max_suppression_padded_v2 (need atleast 5): "
<< func_.getNumArguments();
}
if (func_.getType().getNumResults() != 2) {
return func_.emitError() << "Invalid number of results from "
"non_max_suppression_padded_v2 (need 2): "
<< func_.getType().getNumResults();
}
// The TFLite fused op does not support batching yet.
// TODO(b/158709815): Add support for batches with padded NMS.
auto boxes_type = func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
return func_.emitError() << "TFLite does not support batched input for "
"non_max_suppression_padded";
}
return success();
}
LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
func_.eraseBody();
func_.addEntryBlock();
func_.setAttr(kTFImplements,
StringAttr::get(kCustomSSDPostprocessing, func_.getContext()));
OpBuilder builder(func_.getBody());
std::string custom_option_buffer;
if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(),
custom_option_buffer))) {
return failure();
}
auto op = builder.create<CustomOp>(
func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
builder.create<ReturnOp>(func_.getLoc(), op.getResults());
return success();
}
LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
flexbuffers::Builder fbb;
size_t start_map = fbb.StartMap();
if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
return failure();
auto use_regular_nms =
attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
if (!use_regular_nms) {
return func.emitError()
<< "use_regular_nms attribute is not set or not a bool";
}
fbb.Int("use_regular_nms", use_regular_nms.getValue());
fbb.EndMap(start_map);
fbb.Finish();
custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
return success();
}
LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
FuncOp func, DictionaryAttr attrs, const std::string& attribute,
flexbuffers::Builder* builder) {
auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
if (!int_attr) {
return func.emitError()
<< attribute.c_str() << " attribute is not set or not an integer";
}
builder->Int(attribute.c_str(), int_attr.getInt());
return success();
}
LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
FuncOp func, DictionaryAttr attrs, const std::string& attribute,
flexbuffers::Builder* builder) {
auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
if (!float_attr) {
return func.emitError()
<< attribute.c_str() << " attribute is not set or not a float";
}
builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
return success();
}
LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
// Verify high-level function signature.
if (func_.getNumArguments() != 3) {
return func_.emitError()
<< "Invalid number of arguments to " << kCustomSSDPostprocessing
<< ": " << func_.getNumArguments();
}
if (func_.getType().getNumResults() != 4) {
return func_.emitError()
<< "Invalid number of results from " << kCustomSSDPostprocessing
<< ": " << func_.getType().getNumResults();
}
return success();
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common utils used by TFLite transformation
// passes to work with NMS ops in TFLite.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
#include <string>
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
namespace mlir {
namespace TFL {
// Abstracts the conversion of the padded NMS composite function.
class ConvertNMSPaddedFunc {
public:
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
void RewriteFunc();
LogicalResult VerifySignature();
private:
FuncOp func_;
};
// Abstracts the conversion of the SSD post-processing composite function to
// TFLite.
class ConvertSSDPostProcessFunc {
public:
explicit ConvertSSDPostProcessFunc(FuncOp func, mlir::TF::FuncAttr attr)
: func_(func), attr_(attr) {}
LogicalResult RewriteFunc();
LogicalResult VerifySignature();
private:
LogicalResult CreateNMSCustomOptions(FuncOp func, DictionaryAttr attrs,
std::string& custom_option_buffer);
LogicalResult AddIntAttr(FuncOp func, DictionaryAttr attrs,
const std::string& attribute,
flexbuffers::Builder* builder);
LogicalResult AddFloatAttr(FuncOp func, DictionaryAttr attrs,
const std::string& attribute,
flexbuffers::Builder* builder);
FuncOp func_;
mlir::TF::FuncAttr attr_;
};
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_