Add tfjs optimization pass with the first optimization as transforming tf.keras prelu op sets into tfjs.Prelu op.

PiperOrigin-RevId: 308186660
Change-Id: I9ecf0133761a8fff1edfaa3b27a6b8770bd4436c
This commit is contained in:
A. Unique TensorFlower 2020-04-23 20:46:27 -07:00 committed by TensorFlower Gardener
parent a2ffc4c67a
commit 3e074e62e2
9 changed files with 320 additions and 0 deletions

View File

@ -76,6 +76,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
],
)

View File

@ -77,3 +77,64 @@ cc_library(
],
alwayslink = 1,
)
gentbl(
name = "tfjs_optimize_inc_gen",
tbl_outs = [
(
"-gen-rewriters",
"transforms/generated_optimize.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/optimize_pattern.td",
td_srcs = [
":tfjs_ops_td_files",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
cc_library(
name = "tfjs_optimize",
srcs = [
"transforms/generated_optimize.inc",
"transforms/optimize.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_js",
":tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "tensorflow_js_passes",
srcs = ["tf_tfjs_passes.cc"],
hdrs = [
"tf_tfjs_passes.h",
],
deps = [
":tfjs_optimize",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -15,5 +15,6 @@ filegroup(
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,29 @@
// Run optimize pass only and check the results.
// RUN: tf-opt %s -tfjs-optimize | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: prelu_fusion
func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<-0.2> : tensor<3xf32>
%0 = "tf.Relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tf.Neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = "tf.Relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tf.Mul"(%alpha, %2) : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tf.AddV2"(%0, %3) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %4 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = tfjs.Prelu
}
// CHECK-LABEL: prelu_not_fused
// Rank of alpha should be one less than input for PReLU, which is not the case.
func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<-0.2> : tensor<f32>
%0 = "tf.Relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tf.Neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = "tf.Relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tf.Mul"(%alpha, %2) : (tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tf.AddV2"(%0, %3) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %4 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tf.Relu"
}

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
#include <memory>
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
namespace mlir {
/// Create a pass to convert from the TFExecutor to the TF control dialect.
std::unique_ptr<OperationPass<FuncOp>>
CreateTFExecutorToControlDialectConversion();
} // namespace mlir
namespace tensorflow {
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
// Then we pass the MLIR module through the TF standard pipeline, which for
mlir::TF::StandardPipelineOptions tf_options;
tf_options.enable_inliner = true;
mlir::TF::CreateTFStandardPipeline(*pm, tf_options);
// freeze global tensors.
pm->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
// TFJS dialect passes.
pm->addPass(mlir::tfjs::CreateOptimizePass());
// Canonicalize, CSE etc.
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
}
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
namespace tensorflow {
// Add the TF to TFJS passes into a pass_manager.
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass takes operations in TensorFlow dialect and
// optimizes them to resulting operations in TensorFlow.js dialect.
#include <memory>
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
namespace mlir {
namespace tfjs {
//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
// Optimize TFJS operations in functions.
struct Optimize : public PassWrapper<Optimize, FunctionPass> {
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/tfjs/transforms/generated_optimize.inc"
void Optimize::runOnFunction() {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
auto func = getFunction();
populateWithGenerated(ctx, &patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
// Creates an instance of the TensorFlow.js dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
return std::make_unique<Optimize>();
}
static PassRegistration<Optimize> pass(
"tfjs-optimize", "Optimize within the TensorFlow.js dialect");
} // namespace tfjs
} // namespace mlir

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the optimization pattern definition file for TensorFlow.js.
include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Constraint that makes sure both operands are the same operands.
// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed.
def EqualOperands : Constraint<CPred<"$0 == $1">>;
// Checks if the operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() == "
"$1.getType().cast<ShapedType>().getRank() - 1">>;
// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def : Pat<(TF_AddV2Op
(TF_ReluOp:$relu_out $input1),
(TF_MulOp:$mul_out
(TF_ReluOp (TF_NegOp:$input_neg_out $input2)),
$neg_alpha)),
(TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)),
[(EqualOperands $input1, $input2),
(PReluAlphaRankCheck $neg_alpha, $input1),
(HasOneUse $relu_out),
(HasOneUse $mul_out),
(HasOneUse $input_neg_out)
]>;

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_
#include <memory>
namespace mlir {
class FuncOp;
template <typename T>
class OperationPass;
namespace tfjs {
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
} // namespace tfjs
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_