Convert appropriate composite functions to TFLite custom NMS op.
PiperOrigin-RevId: 328347136 Change-Id: I199e19c04818ef5536dd3ebd7f480f46f37a02b4
This commit is contained in:
parent
03d511ce8f
commit
c66e39713e
@ -280,6 +280,28 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nms_utils",
|
||||
srcs = [
|
||||
"utils/nms_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/nms_utils.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||
"//tensorflow/core:framework",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tftext_utils",
|
||||
srcs = [
|
||||
@ -373,6 +395,7 @@ cc_library(
|
||||
deps = [
|
||||
":constant_utils",
|
||||
":lstm_utils",
|
||||
":nms_utils",
|
||||
":stateful_ops_utils",
|
||||
":tensorflow_lite",
|
||||
":tftext_utils",
|
||||
|
@ -520,3 +520,42 @@ func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
|
||||
return %0 : tensor<100xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
|
||||
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tflite_custom_nms(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x100x4xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x100x91xf32>,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} {
|
||||
// CHECK: %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {custom_code = "TFLite_Detection_PostProcess", custom_option = opaque<"tfl", "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601"> : tensor<217xi8>} : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
|
||||
// CHECK: return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// expected-error @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
|
||||
func @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||
|
||||
// expected-error @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
|
||||
func @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}
|
||||
|
||||
// expected-error @+1 {{max_classes_per_detection attribute is not set or not an integer}}
|
||||
func @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = #tf.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
|
||||
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||
}
|
||||
}
|
||||
|
@ -42,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
@ -59,6 +60,7 @@ namespace {
|
||||
|
||||
constexpr char kTFAPIImplements[] = "tf.api_implements";
|
||||
constexpr char kTFTextAPIPrefix[] = "tftext:";
|
||||
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
|
||||
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
|
||||
|
||||
using mlir::TF::FuncAttr;
|
||||
@ -99,59 +101,6 @@ class ConvertEmbeddedLookupFunc {
|
||||
FuncOp func_;
|
||||
};
|
||||
|
||||
// Abstracts the conversion of the padded NMS composite function.
|
||||
class ConvertNMSPaddedFunc {
|
||||
public:
|
||||
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
|
||||
|
||||
void RewriteFunc() {
|
||||
func_.setAttr(kTFImplements,
|
||||
StringAttr::get(kTfNMSPadded, func_.getContext()));
|
||||
Value boxes = func_.getArgument(0);
|
||||
Value scores = func_.getArgument(1);
|
||||
Value max_output_size = func_.getArgument(2);
|
||||
Value iou_threshold = func_.getArgument(3);
|
||||
Value score_threshold = func_.getArgument(4);
|
||||
auto output_type0 = func_.getType().getResult(0);
|
||||
auto output_type1 = func_.getType().getResult(1);
|
||||
|
||||
OpBuilder builder(func_.getBody());
|
||||
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
|
||||
func_.getLoc(), output_type0, output_type1, boxes, scores,
|
||||
max_output_size, iou_threshold, score_threshold);
|
||||
|
||||
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
|
||||
}
|
||||
|
||||
LogicalResult VerifySignature() {
|
||||
// Verify high-level function signature.
|
||||
// Relevant argument characteristics are checked by the TFL op definition.
|
||||
if (func_.getNumArguments() < 5) {
|
||||
return func_.emitError()
|
||||
<< "Invalid number of arguments to "
|
||||
"non_max_suppression_padded_v2 (need atleast 5): "
|
||||
<< func_.getNumArguments();
|
||||
}
|
||||
if (func_.getType().getNumResults() != 2) {
|
||||
return func_.emitError() << "Invalid number of results from "
|
||||
"non_max_suppression_padded_v2 (need 2): "
|
||||
<< func_.getType().getNumResults();
|
||||
}
|
||||
// The TFLite fused op does not support batching yet.
|
||||
// TODO(b/158709815): Add support for batches with padded NMS.
|
||||
auto boxes_type =
|
||||
func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
|
||||
return func_.emitError() << "TFLite does not support batched input for "
|
||||
"non_max_suppression_padded";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
FuncOp func_;
|
||||
};
|
||||
|
||||
// This pass uses mechanisms listed in RFC:
|
||||
// https://github.com/tensorflow/community/pull/113
|
||||
// It prepares composite functions that are attributed to indicate
|
||||
@ -219,6 +168,12 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes(
|
||||
if (failed(ConvertTFTextAPI(func, api_name, attr))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
} else if (api_name == kCustomSSDPostprocessing) {
|
||||
ConvertSSDPostProcessFunc convert_ssd_postprocess(func, attr);
|
||||
if (failed(convert_ssd_postprocess.VerifySignature()) ||
|
||||
failed(convert_ssd_postprocess.RewriteFunc())) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
174
tensorflow/compiler/mlir/lite/utils/nms_utils.cc
Normal file
174
tensorflow/compiler/mlir/lite/utils/nms_utils.cc
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
|
||||
constexpr char kTFImplements[] = "tf._implements";
|
||||
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
|
||||
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
|
||||
|
||||
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
|
||||
const std::string& content) {
|
||||
ShapedType type = RankedTensorType::get(
|
||||
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
||||
return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
|
||||
type,
|
||||
StringRef(content.data(), content.size()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ConvertNMSPaddedFunc::RewriteFunc() {
|
||||
func_.setAttr(kTFImplements,
|
||||
StringAttr::get(kTfNMSPadded, func_.getContext()));
|
||||
Value boxes = func_.getArgument(0);
|
||||
Value scores = func_.getArgument(1);
|
||||
Value max_output_size = func_.getArgument(2);
|
||||
Value iou_threshold = func_.getArgument(3);
|
||||
Value score_threshold = func_.getArgument(4);
|
||||
auto output_type0 = func_.getType().getResult(0);
|
||||
auto output_type1 = func_.getType().getResult(1);
|
||||
|
||||
OpBuilder builder(func_.getBody());
|
||||
auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
|
||||
func_.getLoc(), output_type0, output_type1, boxes, scores,
|
||||
max_output_size, iou_threshold, score_threshold);
|
||||
|
||||
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
|
||||
}
|
||||
|
||||
LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
|
||||
// Verify high-level function signature.
|
||||
// Relevant argument characteristics are checked by the TFL op definition.
|
||||
if (func_.getNumArguments() < 5) {
|
||||
return func_.emitError()
|
||||
<< "Invalid number of arguments to "
|
||||
"non_max_suppression_padded_v2 (need atleast 5): "
|
||||
<< func_.getNumArguments();
|
||||
}
|
||||
if (func_.getType().getNumResults() != 2) {
|
||||
return func_.emitError() << "Invalid number of results from "
|
||||
"non_max_suppression_padded_v2 (need 2): "
|
||||
<< func_.getType().getNumResults();
|
||||
}
|
||||
// The TFLite fused op does not support batching yet.
|
||||
// TODO(b/158709815): Add support for batches with padded NMS.
|
||||
auto boxes_type = func_.getArgument(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!boxes_type.hasRank() || boxes_type.getRank() != 2) {
|
||||
return func_.emitError() << "TFLite does not support batched input for "
|
||||
"non_max_suppression_padded";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
|
||||
func_.eraseBody();
|
||||
func_.addEntryBlock();
|
||||
func_.setAttr(kTFImplements,
|
||||
StringAttr::get(kCustomSSDPostprocessing, func_.getContext()));
|
||||
|
||||
OpBuilder builder(func_.getBody());
|
||||
std::string custom_option_buffer;
|
||||
if (failed(CreateNMSCustomOptions(func_, attr_.GetAttrs(),
|
||||
custom_option_buffer))) {
|
||||
return failure();
|
||||
}
|
||||
auto op = builder.create<CustomOp>(
|
||||
func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
|
||||
kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
|
||||
builder.create<ReturnOp>(func_.getLoc(), op.getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
|
||||
FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
|
||||
flexbuffers::Builder fbb;
|
||||
size_t start_map = fbb.StartMap();
|
||||
|
||||
if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
|
||||
failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
|
||||
failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
|
||||
failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
|
||||
failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
|
||||
failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
|
||||
failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
|
||||
failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
|
||||
failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
|
||||
return failure();
|
||||
auto use_regular_nms =
|
||||
attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
|
||||
if (!use_regular_nms) {
|
||||
return func.emitError()
|
||||
<< "use_regular_nms attribute is not set or not a bool";
|
||||
}
|
||||
fbb.Int("use_regular_nms", use_regular_nms.getValue());
|
||||
|
||||
fbb.EndMap(start_map);
|
||||
fbb.Finish();
|
||||
custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
|
||||
FuncOp func, DictionaryAttr attrs, const std::string& attribute,
|
||||
flexbuffers::Builder* builder) {
|
||||
auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
|
||||
if (!int_attr) {
|
||||
return func.emitError()
|
||||
<< attribute.c_str() << " attribute is not set or not an integer";
|
||||
}
|
||||
builder->Int(attribute.c_str(), int_attr.getInt());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
|
||||
FuncOp func, DictionaryAttr attrs, const std::string& attribute,
|
||||
flexbuffers::Builder* builder) {
|
||||
auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
|
||||
if (!float_attr) {
|
||||
return func.emitError()
|
||||
<< attribute.c_str() << " attribute is not set or not a float";
|
||||
}
|
||||
builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
|
||||
// Verify high-level function signature.
|
||||
if (func_.getNumArguments() != 3) {
|
||||
return func_.emitError()
|
||||
<< "Invalid number of arguments to " << kCustomSSDPostprocessing
|
||||
<< ": " << func_.getNumArguments();
|
||||
}
|
||||
if (func_.getType().getNumResults() != 4) {
|
||||
return func_.emitError()
|
||||
<< "Invalid number of results from " << kCustomSSDPostprocessing
|
||||
<< ": " << func_.getType().getNumResults();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
76
tensorflow/compiler/mlir/lite/utils/nms_utils.h
Normal file
76
tensorflow/compiler/mlir/lite/utils/nms_utils.h
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header file defines common utils used by TFLite transformation
|
||||
// passes to work with NMS ops in TFLite.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Abstracts the conversion of the padded NMS composite function.
|
||||
class ConvertNMSPaddedFunc {
|
||||
public:
|
||||
explicit ConvertNMSPaddedFunc(FuncOp func) : func_(func) {}
|
||||
|
||||
void RewriteFunc();
|
||||
|
||||
LogicalResult VerifySignature();
|
||||
|
||||
private:
|
||||
FuncOp func_;
|
||||
};
|
||||
|
||||
// Abstracts the conversion of the SSD post-processing composite function to
|
||||
// TFLite.
|
||||
class ConvertSSDPostProcessFunc {
|
||||
public:
|
||||
explicit ConvertSSDPostProcessFunc(FuncOp func, mlir::TF::FuncAttr attr)
|
||||
: func_(func), attr_(attr) {}
|
||||
|
||||
LogicalResult RewriteFunc();
|
||||
|
||||
LogicalResult VerifySignature();
|
||||
|
||||
private:
|
||||
LogicalResult CreateNMSCustomOptions(FuncOp func, DictionaryAttr attrs,
|
||||
std::string& custom_option_buffer);
|
||||
|
||||
LogicalResult AddIntAttr(FuncOp func, DictionaryAttr attrs,
|
||||
const std::string& attribute,
|
||||
flexbuffers::Builder* builder);
|
||||
|
||||
LogicalResult AddFloatAttr(FuncOp func, DictionaryAttr attrs,
|
||||
const std::string& attribute,
|
||||
flexbuffers::Builder* builder);
|
||||
|
||||
FuncOp func_;
|
||||
mlir::TF::FuncAttr attr_;
|
||||
};
|
||||
|
||||
} // end namespace TFL
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
|
Loading…
x
Reference in New Issue
Block a user