Add pass for converting tf.BroadcastTo to implicit broadcasts
Many TF op supports implicit broadcasting. This pass tries to take advantage of this by rewriting explicit broadcasts to use the implicit broadcasting semantic instead what will mean that the broadcasted tensor won't be materialised. PiperOrigin-RevId: 337820871 Change-Id: Id1f7c1f04da87ebebafe07fa56100e54c489e639
This commit is contained in:
parent
dc82710fec
commit
a477bd308c
@ -841,6 +841,7 @@ cc_library(
|
||||
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
||||
"transforms/executor_tpuv1_island_coarsening.cc",
|
||||
"transforms/executor_tpuv1_outline_tpu_island.cc",
|
||||
"transforms/fold_broadcast.cc",
|
||||
"transforms/fold_switch.cc",
|
||||
"transforms/functional_control_flow_to_cfg.cc",
|
||||
"transforms/functional_control_flow_to_regions.cc",
|
||||
@ -930,6 +931,7 @@ cc_library(
|
||||
":shape_inference_utils",
|
||||
":tensorflow",
|
||||
":tensorflow_analysis",
|
||||
":tensorflow_ops",
|
||||
":tensorflow_optimize_inc_gen",
|
||||
":tensorflow_types",
|
||||
":tf_data_optimization",
|
||||
@ -957,6 +959,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:Parser",
|
||||
|
@ -0,0 +1,43 @@
|
||||
// RUN: tf-opt -tf-broadcast-fold %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @broadcast_mul0
|
||||
func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> {
|
||||
%cst = constant dense<[5, 7]> : tensor<2xi32>
|
||||
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
|
||||
%1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
|
||||
return %1 : tensor<5x7xf32>
|
||||
// CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32>
|
||||
// CHECK: %[[V0]] : tensor<5x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_mul1
|
||||
func @broadcast_mul1(%arg0: tensor<7xf32>, %arg1: tensor<5x7xf32>) -> tensor<5x7xf32> {
|
||||
%cst = constant dense<[5, 7]> : tensor<2xi32>
|
||||
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
|
||||
%1 = "tf.Mul"(%0, %arg1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
|
||||
return %1 : tensor<5x7xf32>
|
||||
// CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
|
||||
// CHECK: %[[V0]] : tensor<5x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_add_implicit_fold
|
||||
func @broadcast_add_implicit_fold(%arg0: tensor<5x1xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> {
|
||||
%cst = constant dense<[5, 7]> : tensor<2xi32>
|
||||
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) : (tensor<5x1xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
|
||||
return %1 : tensor<5x7xf32>
|
||||
// CHECK: %[[V0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<5x1xf32>, tensor<7xf32>) -> tensor<5x7xf32>
|
||||
// CHECK: %[[V0]] : tensor<5x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_mul_implicit_no_fold
|
||||
func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32>) -> tensor<3x5x7xf32> {
|
||||
%cst = constant dense<[3, 5, 7]> : tensor<3xi32>
|
||||
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32>
|
||||
%1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
|
||||
return %1 : tensor<3x5x7xf32>
|
||||
// CHECK: %[[C0:.*]] = constant dense<[3, 5, 7]> : tensor<3xi32>
|
||||
// CHECK: %[[V0:.*]] = "tf.BroadcastTo"(%arg1, %[[C0]]) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32>
|
||||
// CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
|
||||
// CHECK: %[[V1]] : tensor<3x5x7xf32>
|
||||
}
|
119
tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
Normal file
119
tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc
Normal file
@ -0,0 +1,119 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
class ConvertResultsBroadcastableShapeOp : public RewritePattern {
|
||||
public:
|
||||
ConvertResultsBroadcastableShapeOp()
|
||||
: RewritePattern(1, MatchAnyOpTypeTag()) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation* op,
|
||||
PatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure();
|
||||
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
|
||||
return failure();
|
||||
|
||||
// Check that the result shape is fully defined.
|
||||
auto result_type =
|
||||
op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!result_type || !result_type.hasStaticShape()) return failure();
|
||||
|
||||
for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
|
||||
// Check that the i'th operand is a broadcast.
|
||||
auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
|
||||
op->getOpOperand(i).get().getDefiningOp());
|
||||
if (!broadcast) continue;
|
||||
|
||||
// Check that the operand of the broadcast has fully defined shape.
|
||||
auto broadcast_arg_type =
|
||||
broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
|
||||
|
||||
// Check that the other argument has fully defined shape.
|
||||
auto argument_type = op->getOpOperand(1 - i)
|
||||
.get()
|
||||
.getType()
|
||||
.dyn_cast_or_null<RankedTensorType>();
|
||||
if (!argument_type || !argument_type.hasStaticShape()) continue;
|
||||
|
||||
// Check that the input of the broadcast and the other operand is broadcast
|
||||
// compatible.
|
||||
llvm::SmallVector<int64_t, 4> broadcasted_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(),
|
||||
argument_type.getShape(),
|
||||
broadcasted_shape))
|
||||
continue;
|
||||
|
||||
// Check that an implicit broadcast between the operand of the broadcast and
|
||||
// the other argument would result in the same type as the result type.
|
||||
if (broadcasted_shape != result_type.getShape()) continue;
|
||||
|
||||
// Update the operand of the op to be the operand of the broadcast.
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
void BroadcastFoldPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
|
||||
patterns.insert<ConvertResultsBroadcastableShapeOp>();
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace TF {
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() {
|
||||
return absl::make_unique<BroadcastFoldPass>();
|
||||
}
|
||||
} // namespace TF
|
||||
|
||||
static PassRegistration<BroadcastFoldPass> pass(
|
||||
"tf-broadcast-fold",
|
||||
"Fold explicit broadcasts into the following operations if they support "
|
||||
"implicit broadcasting on their operand.");
|
||||
|
||||
} // namespace mlir
|
@ -84,6 +84,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
|
||||
std::unique_ptr<OperationPass<mlir::FuncOp>>
|
||||
CreateTensorDeviceCopyConversionPass();
|
||||
|
||||
// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they
|
||||
// have built in broadcasting support.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass();
|
||||
|
||||
struct LayoutOptimizationPipelineOptions
|
||||
: public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
|
||||
Option<std::string> force_data_format{
|
||||
|
Loading…
x
Reference in New Issue
Block a user