Add tfjs optimization pass with the first optimization as transforming tf.keras prelu op sets into tfjs.Prelu op.
PiperOrigin-RevId: 308186660 Change-Id: I9ecf0133761a8fff1edfaa3b27a6b8770bd4436c
This commit is contained in:
parent
a2ffc4c67a
commit
3e074e62e2
tensorflow/compiler/mlir
@ -76,6 +76,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -77,3 +77,64 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tfjs_optimize_inc_gen",
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-rewriters",
|
||||
"transforms/generated_optimize.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/optimize_pattern.td",
|
||||
td_srcs = [
|
||||
":tfjs_ops_td_files",
|
||||
"@llvm-project//mlir:StdOpsTdFiles",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfjs_optimize",
|
||||
srcs = [
|
||||
"transforms/generated_optimize.inc",
|
||||
"transforms/optimize.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/passes.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_js",
|
||||
":tensorflow_js_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_js_passes",
|
||||
srcs = ["tf_tfjs_passes.cc"],
|
||||
hdrs = [
|
||||
"tf_tfjs_passes.h",
|
||||
],
|
||||
deps = [
|
||||
":tfjs_optimize",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:decode_constant_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
@ -15,5 +15,6 @@ filegroup(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
29
tensorflow/compiler/mlir/tfjs/tests/optimize.mlir
Normal file
29
tensorflow/compiler/mlir/tfjs/tests/optimize.mlir
Normal file
@ -0,0 +1,29 @@
|
||||
// Run optimize pass only and check the results.
|
||||
// RUN: tf-opt %s -tfjs-optimize | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: prelu_fusion
|
||||
func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%alpha = constant dense<-0.2> : tensor<3xf32>
|
||||
%0 = "tf.Relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%1 = "tf.Neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%3 = "tf.Mul"(%alpha, %2) : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%4 = "tf.AddV2"(%0, %3) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %4 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:[0-9].*]] = tfjs.Prelu
|
||||
}
|
||||
|
||||
// CHECK-LABEL: prelu_not_fused
|
||||
// Rank of alpha should be one less than input for PReLU, which is not the case.
|
||||
func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%alpha = constant dense<-0.2> : tensor<f32>
|
||||
%0 = "tf.Relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%1 = "tf.Neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%2 = "tf.Relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%3 = "tf.Mul"(%alpha, %2) : (tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%4 = "tf.AddV2"(%0, %3) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %4 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:[0-9].*]] = "tf.Relu"
|
||||
}
|
52
tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
Normal file
52
tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc
Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Transforms/Passes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
/// Create a pass to convert from the TFExecutor to the TF control dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateTFExecutorToControlDialectConversion();
|
||||
} // namespace mlir
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) {
|
||||
// Then we pass the MLIR module through the TF standard pipeline, which for
|
||||
mlir::TF::StandardPipelineOptions tf_options;
|
||||
tf_options.enable_inliner = true;
|
||||
mlir::TF::CreateTFStandardPipeline(*pm, tf_options);
|
||||
|
||||
// freeze global tensors.
|
||||
pm->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
|
||||
|
||||
// TFJS dialect passes.
|
||||
pm->addPass(mlir::tfjs::CreateOptimizePass());
|
||||
|
||||
// Canonicalize, CSE etc.
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pm->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
28
tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h
Normal file
28
tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Add the TF to TFJS passes into a pass_manager.
|
||||
void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_
|
64
tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
Normal file
64
tensorflow/compiler/mlir/tfjs/transforms/optimize.cc
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass takes operations in TensorFlow dialect and
|
||||
// optimizes them to resulting operations in TensorFlow.js dialect.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tfjs {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The actual Optimize Pass.
|
||||
namespace {
|
||||
|
||||
// Optimize TFJS operations in functions.
|
||||
struct Optimize : public PassWrapper<Optimize, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/tfjs/transforms/generated_optimize.inc"
|
||||
|
||||
void Optimize::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto *ctx = &getContext();
|
||||
auto func = getFunction();
|
||||
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow.js dialect Optimize pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass() {
|
||||
return std::make_unique<Optimize>();
|
||||
}
|
||||
|
||||
static PassRegistration<Optimize> pass(
|
||||
"tfjs-optimize", "Optimize within the TensorFlow.js dialect");
|
||||
|
||||
} // namespace tfjs
|
||||
} // namespace mlir
|
49
tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td
Normal file
49
tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is the optimization pattern definition file for TensorFlow.js.
|
||||
|
||||
include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
// Checks if the value has only one user.
|
||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||
|
||||
// Constraint that makes sure both operands are the same operands.
|
||||
// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed.
|
||||
def EqualOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
// Checks if the operand0's rank is one less than operand1's rank.
|
||||
def PReluAlphaRankCheck : Constraint<
|
||||
CPred<"$0.getType().cast<ShapedType>().getRank() == "
|
||||
"$1.getType().cast<ShapedType>().getRank() - 1">>;
|
||||
|
||||
|
||||
// PReLU pattern from Keras:
|
||||
// f(x) = Relu(x) + (-alpha * Relu(-x))
|
||||
def : Pat<(TF_AddV2Op
|
||||
(TF_ReluOp:$relu_out $input1),
|
||||
(TF_MulOp:$mul_out
|
||||
(TF_ReluOp (TF_NegOp:$input_neg_out $input2)),
|
||||
$neg_alpha)),
|
||||
(TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)),
|
||||
[(EqualOperands $input1, $input2),
|
||||
(PReluAlphaRankCheck $neg_alpha, $input1),
|
||||
(HasOneUse $relu_out),
|
||||
(HasOneUse $mul_out),
|
||||
(HasOneUse $input_neg_out)
|
||||
]>;
|
35
tensorflow/compiler/mlir/tfjs/transforms/passes.h
Normal file
35
tensorflow/compiler/mlir/tfjs/transforms/passes.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class FuncOp;
|
||||
template <typename T>
|
||||
class OperationPass;
|
||||
|
||||
namespace tfjs {
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Optimize pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
|
||||
|
||||
} // namespace tfjs
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_
|
Loading…
Reference in New Issue
Block a user