Add pass for converting tf.BroadcastTo to implicit broadcasts

Many TF ops support implicit broadcasting. This pass takes advantage of that
by rewriting explicit broadcasts to use implicit broadcasting semantics
instead, so the broadcasted tensor is never materialized.
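For illustration, a minimal sketch of the rewrite (mirroring the new test file
below): tf.Mul carries the ResultsBroadcastableShape trait, so

    %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
    %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>

can be rewritten to consume the un-broadcast value directly:

    %1 = "tf.Mul"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32>

The pass is registered as -tf-broadcast-fold and can be exercised with tf-opt.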

PiperOrigin-RevId: 337820871
Change-Id: Id1f7c1f04da87ebebafe07fa56100e54c489e639
A. Unique TensorFlower 2020-10-19 03:14:46 -07:00 committed by TensorFlower Gardener
parent dc82710fec
commit a477bd308c
4 changed files with 169 additions and 0 deletions

tensorflow/compiler/mlir/tensorflow/BUILD

@@ -841,6 +841,7 @@ cc_library(
"transforms/executor_tpuv1_inline_tpu_island.cc",
"transforms/executor_tpuv1_island_coarsening.cc",
"transforms/executor_tpuv1_outline_tpu_island.cc",
"transforms/fold_broadcast.cc",
"transforms/fold_switch.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/functional_control_flow_to_regions.cc",
@@ -930,6 +931,7 @@ cc_library(
":shape_inference_utils",
":tensorflow",
":tensorflow_analysis",
":tensorflow_ops",
":tensorflow_optimize_inc_gen",
":tensorflow_types",
":tf_data_optimization",
@@ -957,6 +959,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Parser",

tensorflow/compiler/mlir/tensorflow/tests/fold_broadcast.mlir

@@ -0,0 +1,43 @@
// RUN: tf-opt -tf-broadcast-fold %s | FileCheck %s

// CHECK-LABEL: @broadcast_mul0
func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> {
  %cst = constant dense<[5, 7]> : tensor<2xi32>
  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
  %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
  return %1 : tensor<5x7xf32>
  // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32>
  // CHECK: %[[V0]] : tensor<5x7xf32>
}

// CHECK-LABEL: @broadcast_mul1
func @broadcast_mul1(%arg0: tensor<7xf32>, %arg1: tensor<5x7xf32>) -> tensor<5x7xf32> {
  %cst = constant dense<[5, 7]> : tensor<2xi32>
  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
  %1 = "tf.Mul"(%0, %arg1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
  return %1 : tensor<5x7xf32>
  // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
  // CHECK: %[[V0]] : tensor<5x7xf32>
}

// CHECK-LABEL: @broadcast_add_implicit_fold
func @broadcast_add_implicit_fold(%arg0: tensor<5x1xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> {
  %cst = constant dense<[5, 7]> : tensor<2xi32>
  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
  %1 = "tf.AddV2"(%arg0, %0) : (tensor<5x1xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
  return %1 : tensor<5x7xf32>
  // CHECK: %[[V0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<5x1xf32>, tensor<7xf32>) -> tensor<5x7xf32>
  // CHECK: %[[V0]] : tensor<5x7xf32>
}

// CHECK-LABEL: @broadcast_mul_implicit_no_fold
func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32>) -> tensor<3x5x7xf32> {
  %cst = constant dense<[3, 5, 7]> : tensor<3xi32>
  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32>
  %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
  return %1 : tensor<3x5x7xf32>
  // CHECK: %[[C0:.*]] = constant dense<[3, 5, 7]> : tensor<3xi32>
  // CHECK: %[[V0:.*]] = "tf.BroadcastTo"(%arg1, %[[C0]]) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32>
  // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
  // CHECK: %[[V1]] : tensor<3x5x7xf32>
}
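The folding only applies when every shape involved is fully static (see the
hasStaticShape checks in the pass below). A hypothetical example, not part of
the new tests, that the pass would leave untouched because the result shape is
dynamic:

// Sketch only: the function name and %shape operand are made up for illustration.
func @broadcast_mul_dynamic(%arg0: tensor<?x7xf32>, %arg1: tensor<7xf32>, %shape: tensor<2xi32>) -> tensor<?x7xf32> {
  // The result type has a dynamic dimension, so the pattern bails out and
  // the explicit tf.BroadcastTo survives.
  %0 = "tf.BroadcastTo"(%arg1, %shape) : (tensor<7xf32>, tensor<2xi32>) -> tensor<?x7xf32>
  %1 = "tf.Mul"(%arg0, %0) : (tensor<?x7xf32>, tensor<?x7xf32>) -> tensor<?x7xf32>
  return %1 : tensor<?x7xf32>
}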

tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc

@@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Function.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace {

class ConvertResultsBroadcastableShapeOp : public RewritePattern {
 public:
  ConvertResultsBroadcastableShapeOp()
      : RewritePattern(1, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override;
};

class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
 public:
  void runOnFunction() override;
};
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure();

  if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
    return failure();

  // Check that the result shape is fully defined.
  auto result_type =
      op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
  if (!result_type || !result_type.hasStaticShape()) return failure();

  for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
    // Check that the i'th operand is a broadcast.
    auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
        op->getOpOperand(i).get().getDefiningOp());
    if (!broadcast) continue;

    // Check that the operand of the broadcast has a fully defined shape.
    auto broadcast_arg_type =
        broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
    if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;

    // Check that the other argument has a fully defined shape.
    auto argument_type = op->getOpOperand(1 - i)
                             .get()
                             .getType()
                             .dyn_cast_or_null<RankedTensorType>();
    if (!argument_type || !argument_type.hasStaticShape()) continue;

    // Check that the input of the broadcast and the other operand are
    // broadcast compatible.
    llvm::SmallVector<int64_t, 4> broadcasted_shape;
    if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(),
                                            argument_type.getShape(),
                                            broadcasted_shape))
      continue;

    // Check that an implicit broadcast between the operand of the broadcast
    // and the other argument would result in the same type as the result type.
    if (broadcasted_shape != result_type.getShape()) continue;

    // Update the operand of the op to be the operand of the broadcast.
    rewriter.updateRootInPlace(
        op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
    return success();
  }

  return failure();
}
void BroadcastFoldPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertResultsBroadcastableShapeOp>();
  applyPatternsAndFoldGreedily(func, patterns);
}

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() {
  return absl::make_unique<BroadcastFoldPass>();
}
}  // namespace TF

static PassRegistration<BroadcastFoldPass> pass(
    "tf-broadcast-fold",
    "Fold explicit broadcasts into the following operations if they support "
    "implicit broadcasting on their operands.");

}  // namespace mlir

tensorflow/compiler/mlir/tensorflow/transforms/passes.h

@@ -84,6 +84,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
std::unique_ptr<OperationPass<mlir::FuncOp>>
CreateTensorDeviceCopyConversionPass();

// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they
// have built-in broadcasting support.
std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass();
struct LayoutOptimizationPipelineOptions
: public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
Option<std::string> force_data_format{