From a477bd308c024490f2841e7de0a854445db96a92 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 19 Oct 2020 03:14:46 -0700 Subject: [PATCH] Add pass for converting tf.BroadcastTo to implicit broadcasts Many TF op supports implicit broadcasting. This pass tries to take advantage of this by rewriting explicit broadcasts to use the implicit broadcasting semantic instead what will mean that the broadcasted tensor won't be materialised. PiperOrigin-RevId: 337820871 Change-Id: Id1f7c1f04da87ebebafe07fa56100e54c489e639 --- tensorflow/compiler/mlir/tensorflow/BUILD | 3 + .../mlir/tensorflow/tests/fold-broadcast.mlir | 43 +++++++ .../tensorflow/transforms/fold_broadcast.cc | 119 ++++++++++++++++++ .../mlir/tensorflow/transforms/passes.h | 4 + 4 files changed, 169 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 1c740731acd..b09021e8689 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -841,6 +841,7 @@ cc_library( "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", "transforms/executor_tpuv1_outline_tpu_island.cc", + "transforms/fold_broadcast.cc", "transforms/fold_switch.cc", "transforms/functional_control_flow_to_cfg.cc", "transforms/functional_control_flow_to_regions.cc", @@ -930,6 +931,7 @@ cc_library( ":shape_inference_utils", ":tensorflow", ":tensorflow_analysis", + ":tensorflow_ops", ":tensorflow_optimize_inc_gen", ":tensorflow_types", ":tf_data_optimization", @@ -957,6 +959,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Parser", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir new file mode 100644 index 00000000000..afc9e1e51ed --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir @@ -0,0 +1,43 @@ +// RUN: tf-opt -tf-broadcast-fold %s | FileCheck %s + +// CHECK-LABEL: @broadcast_mul0 +func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_mul1 +func @broadcast_mul1(%arg0: tensor<7xf32>, %arg1: tensor<5x7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Mul"(%0, %arg1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_add_implicit_fold +func @broadcast_add_implicit_fold(%arg0: tensor<5x1xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.AddV2"(%arg0, %0) : (tensor<5x1xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<5x1xf32>, tensor<7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_mul_implicit_no_fold +func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32>) -> tensor<3x5x7xf32> { + %cst = constant dense<[3, 5, 7]> : tensor<3xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> + return %1 : tensor<3x5x7xf32> + // CHECK: %[[C0:.*]] = constant dense<[3, 5, 7]> : tensor<3xi32> + // CHECK: %[[V0:.*]] = "tf.BroadcastTo"(%arg1, %[[C0]]) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32> + // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> + // CHECK: %[[V1]] : tensor<3x5x7xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc new file mode 100644 index 00000000000..66311101cee --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace { + +class ConvertResultsBroadcastableShapeOp : public RewritePattern { + public: + ConvertResultsBroadcastableShapeOp() + : RewritePattern(1, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +class BroadcastFoldPass : public PassWrapper { + public: + void runOnFunction() override; +}; + +LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + if (!op->hasTrait()) return failure(); + if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1) + return failure(); + + // Check that the result shape is fully defined. + auto result_type = + op->getResultTypes().front().dyn_cast_or_null(); + if (!result_type || !result_type.hasStaticShape()) return failure(); + + for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { + // Check that the i'th operand is a broadcast. + auto broadcast = llvm::dyn_cast_or_null( + op->getOpOperand(i).get().getDefiningOp()); + if (!broadcast) continue; + + // Check that the operand of the broadcast has fully defined shape. + auto broadcast_arg_type = + broadcast.input().getType().dyn_cast_or_null(); + if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + + // Check that the other argument has fully defined shape. + auto argument_type = op->getOpOperand(1 - i) + .get() + .getType() + .dyn_cast_or_null(); + if (!argument_type || !argument_type.hasStaticShape()) continue; + + // Check that the input of the broadcast and the other operand is broadcast + // compatible. + llvm::SmallVector broadcasted_shape; + if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(), + argument_type.getShape(), + broadcasted_shape)) + continue; + + // Check that an implicit broadcast between the operand of the broadcast and + // the other argument would result in the same type as the result type. + if (broadcasted_shape != result_type.getShape()) continue; + + // Update the operand of the op to be the operand of the broadcast. + rewriter.updateRootInPlace( + op, [&]() { op->getOpOperand(i).set(broadcast.input()); }); + return success(); + } + + return failure(); +} + +void BroadcastFoldPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert(); + applyPatternsAndFoldGreedily(func, patterns); +} + +} // namespace + +namespace TF { +std::unique_ptr> CreateBroadcastFoldPass() { + return absl::make_unique(); +} +} // namespace TF + +static PassRegistration pass( + "tf-broadcast-fold", + "Fold explicit broadcasts into the following operations if they support " + "implicit broadcasting on their operand."); + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index a4ddb713ec0..4a12c80c8d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -84,6 +84,10 @@ std::unique_ptr> CreateGpuOpFusionPass(); std::unique_ptr> CreateTensorDeviceCopyConversionPass(); +// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they +// have built in broadcasting support. +std::unique_ptr> CreateBroadcastFoldPass(); + struct LayoutOptimizationPipelineOptions : public PassPipelineOptions { Option force_data_format{