diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 1c740731acd..b09021e8689 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -841,6 +841,7 @@ cc_library( "transforms/executor_tpuv1_inline_tpu_island.cc", "transforms/executor_tpuv1_island_coarsening.cc", "transforms/executor_tpuv1_outline_tpu_island.cc", + "transforms/fold_broadcast.cc", "transforms/fold_switch.cc", "transforms/functional_control_flow_to_cfg.cc", "transforms/functional_control_flow_to_regions.cc", @@ -930,6 +931,7 @@ cc_library( ":shape_inference_utils", ":tensorflow", ":tensorflow_analysis", + ":tensorflow_ops", ":tensorflow_optimize_inc_gen", ":tensorflow_types", ":tf_data_optimization", @@ -957,6 +959,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Parser", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir new file mode 100644 index 00000000000..afc9e1e51ed --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/fold-broadcast.mlir @@ -0,0 +1,43 @@ +// RUN: tf-opt -tf-broadcast-fold %s | FileCheck %s + +// CHECK-LABEL: @broadcast_mul0 +func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_mul1 +func @broadcast_mul1(%arg0: tensor<7xf32>, %arg1: tensor<5x7xf32>) 
-> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.Mul"(%0, %arg1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.Mul"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_add_implicit_fold +func @broadcast_add_implicit_fold(%arg0: tensor<5x1xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { + %cst = constant dense<[5, 7]> : tensor<2xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tf.AddV2"(%arg0, %0) : (tensor<5x1xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + return %1 : tensor<5x7xf32> + // CHECK: %[[V0:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<5x1xf32>, tensor<7xf32>) -> tensor<5x7xf32> + // CHECK: %[[V0]] : tensor<5x7xf32> +} + +// CHECK-LABEL: @broadcast_mul_implicit_no_fold +func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32>) -> tensor<3x5x7xf32> { + %cst = constant dense<[3, 5, 7]> : tensor<3xi32> + %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> + return %1 : tensor<3x5x7xf32> + // CHECK: %[[C0:.*]] = constant dense<[3, 5, 7]> : tensor<3xi32> + // CHECK: %[[V0:.*]] = "tf.BroadcastTo"(%arg1, %[[C0]]) : (tensor<5xf32>, tensor<3xi32>) -> tensor<3x5x7xf32> + // CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32> + // CHECK: %[[V1]] : tensor<3x5x7xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc new file mode 100644 index 00000000000..66311101cee --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ 
-0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> + +#include "absl/memory/memory.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace { + +class ConvertResultsBroadcastableShapeOp : public RewritePattern { + public: + ConvertResultsBroadcastableShapeOp() + : RewritePattern(1, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> { + public: + void runOnFunction() override; +}; + +LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure(); + if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1) + return failure(); + + // Check that the result shape is fully defined. 
+ auto result_type = + op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>(); + if (!result_type || !result_type.hasStaticShape()) return failure(); + + for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) { + // Check that the i'th operand is a broadcast. + auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>( + op->getOpOperand(i).get().getDefiningOp()); + if (!broadcast) continue; + + // Check that the operand of the broadcast has fully defined shape. + auto broadcast_arg_type = + broadcast.input().getType().dyn_cast_or_null<RankedTensorType>(); + if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + + // Check that the other argument has fully defined shape. + auto argument_type = op->getOpOperand(1 - i) + .get() + .getType() + .dyn_cast_or_null<RankedTensorType>(); + if (!argument_type || !argument_type.hasStaticShape()) continue; + + // Check that the input of the broadcast and the other operand is broadcast + // compatible. + llvm::SmallVector<int64_t, 4> broadcasted_shape; + if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(), + argument_type.getShape(), + broadcasted_shape)) + continue; + + // Check that an implicit broadcast between the operand of the broadcast and + // the other argument would result in the same type as the result type. + if (broadcasted_shape != result_type.getShape()) continue; + + // Update the operand of the op to be the operand of the broadcast. 
+ rewriter.updateRootInPlace( + op, [&]() { op->getOpOperand(i).set(broadcast.input()); }); + return success(); + } + + return failure(); +} + +void BroadcastFoldPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert<ConvertResultsBroadcastableShapeOp>(); + applyPatternsAndFoldGreedily(func, patterns); +} + +} // namespace + +namespace TF { +std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() { + return absl::make_unique<BroadcastFoldPass>(); +} +} // namespace TF + +static PassRegistration<BroadcastFoldPass> pass( + "tf-broadcast-fold", + "Fold explicit broadcasts into the following operations if they support " + "implicit broadcasting on their operand."); + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index a4ddb713ec0..4a12c80c8d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -84,6 +84,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass(); std::unique_ptr<OperationPass<FuncOp>> CreateTensorDeviceCopyConversionPass(); +// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they +// have built in broadcasting support. +std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass(); + struct LayoutOptimizationPipelineOptions : public PassPipelineOptions<LayoutOptimizationPipelineOptions> { Option<std::string> force_data_format{