diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 664dfe0e3ba..4b3ed9fa2ad 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -353,6 +353,25 @@ tf_cc_test( ], ) +cc_library( + name = "perception_ops_utils", + srcs = [ + "utils/perception_ops_utils.cc", + ], + hdrs = [ + "utils/perception_ops_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/lite/c:common", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "stateful_ops_utils", srcs = [ @@ -385,6 +404,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "perception_ops_utils_test", + size = "small", + srcs = ["utils/perception_ops_utils_test.cc"], + deps = [ + ":perception_ops_utils", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + ], +) + cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ @@ -414,6 +450,7 @@ cc_library( ":constant_utils", ":lstm_utils", ":nms_utils", + ":perception_ops_utils", ":stateful_ops_utils", ":tensorflow_lite", ":tftext_utils", diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 477efbb7c38..6e845afd2d0 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -583,3 +583,68 @@ func private @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %a return %0, %1, %2, %3 : tensor, tensor, tensor, tensor } } + +// ----- + +module { +func @max_unpooling_2d(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>} { + %0 = "tf.Const"() {value = dense<[4, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<0> : tensor<1x1x2x1xi32>} : () -> tensor<1x1x2x1xi32> + %3 = "tf.Const"() {value = dense<[1, 2, 4, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %4 = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + %5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %6 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %7 = "tf.FloorDiv"(%arg1, %5) {device = ""} : (tensor<1x1x2x1xi32>, tensor) -> tensor<1x1x2x1xi32> + %8 = "tf.FloorMod"(%7, %4) {device = ""} : (tensor<1x1x2x1xi32>, tensor) -> tensor<1x1x2x1xi32> + %9 = "tf.FloorDiv"(%arg1, %4) {device = ""} : (tensor<1x1x2x1xi32>, tensor) -> tensor<1x1x2x1xi32> + %10 = "tf.Pack"(%2, %9, %8, %2) {axis = 0 : i64, device = ""} : (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) -> tensor<4x1x1x2x1xi32> + %11 = "tf.Reshape"(%10, %0) {device = ""} : (tensor<4x1x1x2x1xi32>, tensor<2xi32>) -> tensor<4x2xi32> + %12 = "tf.Transpose"(%11, %6) {device = ""} : (tensor<4x2xi32>, tensor<2xi32>) -> tensor<2x4xi32> + %13 = "tf.Reshape"(%arg0, %1) {device = ""} : (tensor<1x1x2x1xf32>, tensor<1xi32>) -> tensor<2xf32> + %14 = "tf.ScatterNd"(%12, %13, %3) {device = ""} : (tensor<2x4xi32>, tensor<2xf32>, tensor<4xi32>) -> tensor<1x2x4x1xf32> + %15 = "tf.Identity"(%14) {device = ""} : (tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> + return %15 : tensor<1x2x4x1xf32> +} + +// CHECK-LABEL: func @max_unpooling_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x2x1xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = "MaxUnpooling2D"} { +// CHECK-NEXT: %[[VAL_2:.*]] = "tfl.custom"(%[[VAL_0]], %[[VAL_1]]) {custom_code = "MaxUnpooling2D", custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>} : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> +// CHECK-NEXT: return %[[VAL_2]] : tensor<1x2x4x1xf32> +// CHECK-NEXT: } +} + +// ----- + +module { +// expected-error @+1 {{Invalid number of results from MaxUnpooling2D}} +func private @max_unpooling_2d_invalid_results(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> (tensor<1x2x4x1xf32>, tensor<1x2x4x1xi32>) attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>} + +// expected-error @+1 {{Invalid number of arguments to MaxUnpooling2D}} +func private @max_unpooling_2d_invalid_args(%arg0: tensor<1x1x2x1xf32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>} + +// expected-error @+1 {{Padding for MaxUnpooling2D must be 'SAME' or 'VALID'}} +func private @max_unpooling_2d_wrong_padding(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "NO", pool_size = [2, 2], strides = [2, 2]}>} + +// expected-error @+1 {{'pool_size' attribute for MaxUnpooling2D must be set and has size of 2}} +func private @max_unpooling_2d_wrong_filter(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2], strides = [2, 2]}>} + +// expected-error @+1 {{'strides' attribute for MaxUnpooling2D must be set and has size of 2}} +func private @max_unpooling_2d_wrong_strides(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2, 2]}>} + +// expected-error @+1 {{'padding' attribute for MaxUnpooling2D is not set or not a string}} +func private @max_unpooling_2d_no_padding(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {pool_size = [2, 2], strides = [2, 2]}>} + +// expected-error @+1 {{'pool_size' attribute for MaxUnpooling2D must be set and has size of 2}} +func private @max_unpooling_2d_no_filter(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", strides = [2, 2]}>} + +// expected-error @+1 {{'strides' attribute for MaxUnpooling2D must be set and has size of 2}} +func private @max_unpooling_2d_no_strides(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2]}>} + +// expected-error @+1 {{'pool_size' attribute for MaxUnpooling2D does not contain integer values}} +func private @max_unpooling_2d_filter_wrong_type(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = ["a", "b"], strides = [2, 2]}>} + + // expected-error @+1 {{'strides' attribute for MaxUnpooling2D does not contain integer values}} +func private @max_unpooling_2d_strides_wrong_type(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = ["2", "2"]}>} +} diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 49a0a31088f..6450bd400d9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h" #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -61,6 +62,7 @@ constexpr char kTFAPIImplements[] = "tf.api_implements"; constexpr char kTFTextAPIPrefix[] = "tftext:"; constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess"; constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; +constexpr char kCustomMaxUnpooling[] = "addons:MaxUnpooling2D"; using mlir::TF::FuncAttr; @@ -294,6 +296,12 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes( failed(convert_ssd_postprocess.RewriteFunc())) { return signalPassFailure(); } + } else if (api_name == kCustomMaxUnpooling) { + ConvertMaxUnpoolingFunc max_unpooling(func, attr); + if (failed(max_unpooling.VerifySignature()) || + failed(max_unpooling.RewriteFunc())) { + return signalPassFailure(); + } } } diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc new file mode 100644 index 00000000000..41cce577f20 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h" + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/lite/c/builtin_op_data.h" + +namespace mlir { +namespace TFL { + +namespace { + +constexpr char kTFImplements[] = "tf._implements"; +constexpr char kMaxUnpooling[] = "MaxUnpooling2D"; + +inline OpaqueElementsAttr CustomOption(OpBuilder* builder, + const std::string& content) { + ShapedType type = RankedTensorType::get( + {static_cast(content.size())}, builder->getIntegerType(8)); + return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"), + type, + StringRef(content.data(), content.size())); +} + +inline LogicalResult GetIntegerArraySafe( + FuncOp* func, const DictionaryAttr& attrs, const std::string& attr_name, + llvm::SmallVectorImpl* results, int N) { + ArrayAttr array_attr = attrs.get(attr_name).dyn_cast_or_null(); + if (array_attr == nullptr || array_attr.size() != N) { + return func->emitError() + << "'" << attr_name << "' attribute for " << kMaxUnpooling + << " must be set and has size of " << N; + } + results->reserve(N); + + for (Attribute integer_attr : array_attr.getValue()) { + IntegerAttr value = integer_attr.dyn_cast(); + if (!value) { + return func->emitError() + << "'" << attr_name << "' attribute for " << kMaxUnpooling + << " does not contain integer values"; + } + results->push_back(value.getInt()); + } + return success(); +} + +} // namespace + +LogicalResult ConvertMaxUnpoolingFunc::RewriteFunc() { + func_.eraseBody(); + func_.addEntryBlock(); + func_.setAttr(kTFImplements, + StringAttr::get(kMaxUnpooling, func_.getContext())); + + OpBuilder builder(func_.getBody()); + std::string custom_option_buffer; + if (failed(CreateCustomOptions(custom_option_buffer))) { + return failure(); + } + auto op = builder.create( + func_.getLoc(), func_.getType().getResults(), func_.getArguments(), + kMaxUnpooling, CustomOption(&builder, custom_option_buffer)); + builder.create(func_.getLoc(), op.getResults()); + + return success(); +} + +LogicalResult ConvertMaxUnpoolingFunc::VerifySignature() { + // Verify high-level function signature. + if (func_.getNumArguments() != 2) { + return func_.emitError() + << "Invalid number of arguments to " << kMaxUnpooling << ": " + << func_.getNumArguments(); + } + if (func_.getType().getNumResults() != 1) { + return func_.emitError() + << "Invalid number of results from " << kMaxUnpooling << ": " + << func_.getType().getNumResults(); + } + return success(); +} + +LogicalResult ConvertMaxUnpoolingFunc::CreateCustomOptions( + std::string& custom_option_buffer) { + auto attrs = attr_.GetAttrs(); + TfLitePoolParams pool_params; + + llvm::SmallVector pool_size; + if (failed(GetIntegerArraySafe(&func_, attrs, "pool_size", &pool_size, 2))) { + return failure(); + } + pool_params.filter_height = pool_size[0]; + pool_params.filter_width = pool_size[1]; + + // Retrieve strides. + llvm::SmallVector strides; + if (failed(GetIntegerArraySafe(&func_, attrs, "strides", &strides, 2))) { + return failure(); + } + pool_params.stride_height = strides[0]; + pool_params.stride_width = strides[1]; + + // Retrieves padding. + auto padding = attrs.get("padding").dyn_cast_or_null(); + if (!padding) { + return func_.emitError() << "'padding' attribute for " << kMaxUnpooling + << " is not set or not a string"; + } + if (padding.getValue().equals("VALID")) { + pool_params.padding = kTfLitePaddingValid; + } else if (padding.getValue().equals("SAME")) { + pool_params.padding = kTfLitePaddingSame; + } else { + return func_.emitError() + << "Padding for " << kMaxUnpooling << " must be 'SAME' or 'VALID'"; + } + + pool_params.activation = kTfLiteActNone; + pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; + + custom_option_buffer.assign(reinterpret_cast(&pool_params), + sizeof(TfLitePoolParams)); + return success(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h new file mode 100644 index 00000000000..e82c77957b7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" + +namespace mlir { +namespace TFL { + +// Fuse MaxUnpooling2D ops annotated by tf.function to a TFLite custom op. +class ConvertMaxUnpoolingFunc { + public: + explicit ConvertMaxUnpoolingFunc(FuncOp func, mlir::TF::FuncAttr attr) + : func_(func), attr_(attr) {} + + LogicalResult RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + LogicalResult CreateCustomOptions(std::string& custom_option_buffer); + + FuncOp func_; + mlir::TF::FuncAttr attr_; +}; + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc new file mode 100644 index 00000000000..19a2b81325c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc @@ -0,0 +1,196 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h" + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir { +namespace TFL { +namespace { + +template +FuncOp createMaxUnpoolingFunc( + mlir::Builder* builder, const SmallVector& input_types, + const SmallVector& output_types) { + auto func_type = builder->getFunctionType(input_types, output_types); + auto func = + FuncOp::create(mlir::NameLoc::get(builder->getIdentifier("fused_func"), + builder->getContext()), + "fused_func", func_type, {}); + + func.addEntryBlock(); + mlir::StringAttr attr_value = builder->getStringAttr("MaxUnpooling2D"); + func.setAttr("tf._implements", attr_value); + return func; +} + +FuncOp createMaxUnpoolingFunc(mlir::Builder* builder, + const SmallVector& input_shape, + const SmallVector& output_shape) { + auto input_type = RankedTensorType::get(input_shape, builder->getF32Type()); + auto indices_type = RankedTensorType::get(input_shape, builder->getI64Type()); + auto output_type = RankedTensorType::get(output_shape, builder->getF32Type()); + SmallVector input_types{input_type, indices_type}; + SmallVector output_types{output_type}; + return createMaxUnpoolingFunc<2, 1>(builder, input_types, output_types); +} + +template +ArrayAttr createInt32Array(mlir::Builder* builder, mlir::MLIRContext* context, + const SmallVector& values) { + SmallVector ret; + for (int32_t value : values) { + ret.push_back(builder->getI32IntegerAttr(value)); + } + return ArrayAttr::get(ret, context); +} + +template +ArrayAttr createInt64Array(mlir::Builder* builder, mlir::MLIRContext* context, + const SmallVector& values) { + SmallVector ret; + for (int64_t value : values) { + ret.push_back(builder->getI64IntegerAttr(value)); + } + return ArrayAttr::get(ret, context); +} + +mlir::TF::FuncAttr createMaxUnpoolingAttr(mlir::MLIRContext* context, + const std::string& padding, + const ArrayAttr& pool_size, + const ArrayAttr& strides) { + SmallVector<::mlir::NamedAttribute, 3> fields; + + auto padding_id = ::mlir::Identifier::get("padding", context); + fields.emplace_back(padding_id, StringAttr::get(padding, context)); + + auto pool_size_id = ::mlir::Identifier::get("pool_size", context); + fields.emplace_back(pool_size_id, pool_size); + + auto strides_id = ::mlir::Identifier::get("strides", context); + fields.emplace_back(strides_id, strides); + + DictionaryAttr dict = DictionaryAttr::get(fields, context); + return TF::FuncAttr::get(context, "MaxUnpooling2D", dict); +} + +} // namespace + +class PerceptionUtilsTest : public ::testing::Test { + protected: + PerceptionUtilsTest() {} + + void SetUp() override { + context_ = std::make_unique(); + context_->loadDialect(); + builder_ = std::unique_ptr(new Builder(context_.get())); + + fused_max_unpooling_func_ = + createMaxUnpoolingFunc(builder_.get(), {2, 4, 4, 2}, {2, 2, 2, 2}); + + func_attr_ = createMaxUnpoolingAttr( + context_.get(), "SAME", + createInt32Array<2>(builder_.get(), context_.get(), {2, 2}), + createInt32Array<2>(builder_.get(), context_.get(), {2, 2})); + } + + void TearDown() override { + fused_max_unpooling_func_.erase(); + builder_.reset(); + } + + FuncOp fused_max_unpooling_func_; + mlir::TF::FuncAttr func_attr_; + std::unique_ptr context_; + std::unique_ptr builder_; +}; + +TEST_F(PerceptionUtilsTest, VerifySignatureValid) { + mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_, + func_attr_); + + EXPECT_FALSE(failed(convert.VerifySignature())); +} + +TEST_F(PerceptionUtilsTest, VerifySignatureInvalid) { + auto input_type = RankedTensorType::get({1, 2, 2, 1}, builder_->getF32Type()); + auto output_type = + RankedTensorType::get({1, 2, 1, 1}, builder_->getF32Type()); + SmallVector input_types{input_type}; + SmallVector output_types{output_type}; + + auto max_unpooling_func = + createMaxUnpoolingFunc<1, 1>(builder_.get(), input_types, output_types); + mlir::TFL::ConvertMaxUnpoolingFunc convert(max_unpooling_func, func_attr_); + + EXPECT_TRUE(failed(convert.VerifySignature())); + max_unpooling_func->erase(); +} + +TEST_F(PerceptionUtilsTest, RewriteValid) { + mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_, + func_attr_); + + EXPECT_FALSE(failed(convert.RewriteFunc())); +} + +TEST_F(PerceptionUtilsTest, RewriteWrongPadding) { + auto func_attr = createMaxUnpoolingAttr( + context_.get(), "INVALID", + createInt32Array<2>(builder_.get(), context_.get(), {2, 2}), + createInt32Array<2>(builder_.get(), context_.get(), {2, 2})); + mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_, + func_attr); + + EXPECT_TRUE(failed(convert.RewriteFunc())); +} + +TEST_F(PerceptionUtilsTest, RewriteWrongFilter) { + auto func_attr = createMaxUnpoolingAttr( + context_.get(), "VALID", + createInt32Array<2>(builder_.get(), context_.get(), {2, 2, 2}), + createInt32Array<2>(builder_.get(), context_.get(), {2, 2})); + mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_, + func_attr); + + EXPECT_TRUE(failed(convert.RewriteFunc())); +} + +TEST_F(PerceptionUtilsTest, RewriteWrongStrides) { + auto func_attr = createMaxUnpoolingAttr( + context_.get(), "VALID", + createInt32Array<2>(builder_.get(), context_.get(), {2, 2}), + createInt32Array<2>(builder_.get(), context_.get(), {2, 2, 0})); + mlir::TFL::ConvertMaxUnpoolingFunc convert(fused_max_unpooling_func_, + func_attr); + + EXPECT_TRUE(failed(convert.RewriteFunc())); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/lite/kernels/perception/BUILD b/tensorflow/lite/kernels/perception/BUILD index 7c18db40fe5..0cead4052b7 100644 --- a/tensorflow/lite/kernels/perception/BUILD +++ b/tensorflow/lite/kernels/perception/BUILD @@ -1,5 +1,9 @@ +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "pybind_extension") + package( default_visibility = [ "//visibility:public", @@ -48,3 +52,20 @@ cc_test( "@flatbuffers", ], ) + +pybind_extension( + name = "pywrap_perception_ops", + srcs = [ + "perception_ops_wrapper.cc", + ], + hdrs = ["perception_ops.h"], + additional_exported_symbols = ["PerceptionOpsRegisterer"], + link_in_framework = True, + module_name = "pywrap_perception_ops", + deps = [ + ":perception_ops", + "//tensorflow/lite:mutable_op_resolver", + "//third_party/python_runtime:headers", + "@pybind11", + ], +) diff --git a/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc b/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc new file mode 100644 index 00000000000..7fd1282e604 --- /dev/null +++ b/tensorflow/lite/kernels/perception/perception_ops_wrapper.cc @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "tensorflow/lite/kernels/perception/perception_ops.h" + +PYBIND11_MODULE(pywrap_perception_ops, m) { + m.doc() = R"pbdoc( + pywrap_perception_ops + ----- + )pbdoc"; + m.def( + "PerceptionOpsRegisterer", + [](uintptr_t resolver) { + tflite::ops::custom::AddPerceptionOps( + reinterpret_cast(resolver)); + }, + R"pbdoc( + Perception op registerer function with the correct signature. Registers + Perception custom ops. + )pbdoc"); +} diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py index 6b3aff1d233..580e47ef567 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -28,6 +28,7 @@ from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.lite.python import convert_saved_model as _convert_saved_model +from tensorflow.lite.python import interpreter as _interpreter from tensorflow.lite.python import lite as _lite from tensorflow.lite.python import util as _util from tensorflow.python.client import session as _session @@ -170,7 +171,9 @@ def _check_model_quantized_to_16x8(tflite_model): raise ValueError("Could not find int16 activations.") -def _get_tflite_interpreter(tflite_model, input_shapes_resize=None): +def _get_tflite_interpreter(tflite_model, + input_shapes_resize=None, + custom_op_registerers=None): """Creates a TFLite interpreter with resized input tensors. Args: @@ -178,11 +181,15 @@ def _get_tflite_interpreter(tflite_model, input_shapes_resize=None): input_shapes_resize: A map where the key is the input tensor name and the value is the shape of the input tensor. This resize happens after model conversion, prior to calling allocate tensors. (default None) + custom_op_registerers: Op registerers for custom ops. Returns: lite.Interpreter """ - interpreter = _lite.Interpreter(model_content=tflite_model) + if custom_op_registerers is None: + custom_op_registerers = [] + interpreter = _interpreter.InterpreterWithCustomOps( + model_content=tflite_model, custom_op_registerers=custom_op_registerers) if input_shapes_resize: input_details = interpreter.get_input_details() input_details_map = { @@ -194,17 +201,19 @@ def _get_tflite_interpreter(tflite_model, input_shapes_resize=None): return interpreter -def _get_input_data_map(tflite_model, input_data): +def _get_input_data_map(tflite_model, input_data, custom_op_registerers=None): """Generates a map of input data based on the TFLite model. Args: tflite_model: Serialized TensorFlow Lite model. input_data: List of np.ndarray. + custom_op_registerers: Op registerers for custom ops. Returns: {str: [np.ndarray]}. """ - interpreter = _get_tflite_interpreter(tflite_model) + interpreter = _get_tflite_interpreter( + tflite_model, custom_op_registerers=custom_op_registerers) interpreter.allocate_tensors() input_details = interpreter.get_input_details() return { @@ -216,7 +225,8 @@ def _get_input_data_map(tflite_model, input_data): def _generate_random_input_data(tflite_model, seed=None, input_data_range=None, - input_shapes_resize=None): + input_shapes_resize=None, + custom_op_registerers=None): """Generates input data based on the input tensors in the TFLite model. Args: @@ -230,11 +240,15 @@ def _generate_random_input_data(tflite_model, input_shapes_resize: A map where the key is the input tensor name and the value is the shape of the input tensor. This resize happens after model conversion, prior to calling allocate tensors. (default None) + custom_op_registerers: Op registerers for custom ops. Returns: ([np.ndarray], {str : [np.ndarray]}). """ - interpreter = _get_tflite_interpreter(tflite_model, input_shapes_resize) + interpreter = _get_tflite_interpreter( + tflite_model, + input_shapes_resize, + custom_op_registerers=custom_op_registerers) interpreter.allocate_tensors() input_details = interpreter.get_input_details() @@ -254,11 +268,15 @@ def _generate_random_input_data(tflite_model, ) * val + input_data_range[input_tensor["name"]][0] input_data.append(np.array(val, dtype=input_tensor["dtype"])) - input_data_map = _get_input_data_map(tflite_model, input_data) + input_data_map = _get_input_data_map( + tflite_model, input_data, custom_op_registerers=custom_op_registerers) return input_data, input_data_map -def _evaluate_tflite_model(tflite_model, input_data, input_shapes_resize=None): +def _evaluate_tflite_model(tflite_model, + input_data, + input_shapes_resize=None, + custom_op_registerers=None): """Returns evaluation of input data on TFLite model. Args: @@ -267,11 +285,15 @@ def _evaluate_tflite_model(tflite_model, input_data, input_shapes_resize=None): input_shapes_resize: A map where the key is the input tensor name and the value is the shape of the input tensor. This resize happens after model conversion, prior to calling allocate tensors. (default None) + custom_op_registerers: Op registerers for custom ops. Returns: List of np.ndarray. """ - interpreter = _get_tflite_interpreter(tflite_model, input_shapes_resize) + interpreter = _get_tflite_interpreter( + tflite_model, + input_shapes_resize, + custom_op_registerers=custom_op_registerers) interpreter.allocate_tensors() input_details = interpreter.get_input_details() @@ -403,6 +425,31 @@ def compare_models(tflite_model, np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) +def _compare_tf_tflite_results(tf_results, + tflite_results, + tflite_labels, + tolerance=5): + """Compare the result of TF and TFLite model. + + Args: + tf_results: results returned by the TF model. + tflite_results: results returned by the TFLite model. + tflite_labels: names of the output tensors in the TFlite model. + tolerance: Decimal place to check accuracy to. (default 5). + """ + # Convert the output TensorFlow results into an ordered list. + if isinstance(tf_results, dict): + if len(tf_results) == 1: + tf_results = [tf_results[list(tf_results.keys())[0]]] + else: + tf_results = [tf_results[tflite_label] for tflite_label in tflite_labels] + else: + tf_results = [tf_results] + + for tf_result, tflite_result in zip(tf_results, tflite_results): + np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) + + def compare_models_v2(tflite_model, tf_eval_func, input_data=None, @@ -444,17 +491,49 @@ def compare_models_v2(tflite_model, tflite_results, tflite_labels = _evaluate_tflite_model( tflite_model, input_data) - # Convert the output TensorFlow results into an ordered list. - if isinstance(tf_results, dict): - if len(tf_results) == 1: - tf_results = [tf_results[list(tf_results.keys())[0]]] - else: - tf_results = [tf_results[tflite_label] for tflite_label in tflite_labels] - else: - tf_results = [tf_results] + _compare_tf_tflite_results(tf_results, tflite_results, tflite_labels, + tolerance) - for tf_result, tflite_result in zip(tf_results, tflite_results): - np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) + +def compare_tflite_keras_models_v2(tflite_model, + keras_model, + input_data=None, + input_data_range=None, + tolerance=5, + custom_op_registerers=None): + """Similar to compare_models_v2 but accept Keras model. + + Unless the input data is provided, the models are compared with random data. + Currently only 1 input and 1 output are supported by this function. + + Args: + tflite_model: Serialized TensorFlow Lite model. + keras_model: Keras model to evaluate. + input_data: np.ndarray to pass into models during inference. (default None). + input_data_range: A map where the key is the input tensor name and the value + is a tuple (min_val, max_val) which specifies the value range of + the corresponding input tensor. For example, '{'input1': (1, 5)}' means to + generate a random value for tensor `input1` within range [1.0, 5.0) + (half-inclusive). (default None) + tolerance: Decimal place to check accuracy to. (default 5) + custom_op_registerers: Op registerers for custom ops. + """ + # Generate random input data if not provided. + if input_data is None: + input_data, _ = _generate_random_input_data( + tflite_model=tflite_model, + input_data_range=input_data_range, + custom_op_registerers=custom_op_registerers) + + if len(input_data) > 1: + tf_results = keras_model.predict(input_data) + else: + tf_results = keras_model.predict(input_data[0]) + tflite_results, tflite_labels = _evaluate_tflite_model( + tflite_model, input_data, custom_op_registerers=custom_op_registerers) + + _compare_tf_tflite_results(tf_results, tflite_results, tflite_labels, + tolerance) def compare_model_golden(tflite_model,