Add pass that removes Identity/IdentityN ops from the TPU computation.

Identity/IdentityN ops are not special when legalized to HLO. To reduce forwarding of values as new values, Identity/IdentityN ops are removed and their operands are forwarded to their results. This is extracted from resource op lifting.

PiperOrigin-RevId: 326675870
Change-Id: Ic934f982087a252737334deb8e43c92d56575f12
This commit is contained in:
Andy Ly 2020-08-14 10:00:48 -07:00 committed by TensorFlower Gardener
parent 6297d314fb
commit f4bae1839c
4 changed files with 210 additions and 0 deletions

View File

@ -788,6 +788,7 @@ cc_library(
"transforms/tpu_extract_head_tail_outside_compilation.cc",
"transforms/tpu_extract_outside_compilation.cc",
"transforms/tpu_host_computation_expansion.cc",
"transforms/tpu_identity_pruning.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_outside_compilation_cluster.cc",
"transforms/tpu_rewrite_pass.cc",

View File

@ -0,0 +1,93 @@
// RUN: tf-opt %s -tf-tpu-identity-pruning | FileCheck %s --dump-input=always
// Tests Identity op in cluster is pruned away.
// CHECK-LABEL: func @testIdentity
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testIdentity(%arg0: tensor<i32>) {
// CHECK-NOT: "tf.Identity"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: tf_device.return [[ARG0]]
%0 = "tf_device.cluster"() ( {
%1 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %1 : tensor<i32>
}) : () -> tensor<i32>
return
}
// Tests IdentityN op in cluster is pruned away.
// CHECK-LABEL: func @testIdentityN
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>)
func @testIdentityN(%arg0: tensor<i32>, %arg1: tensor<f32>) {
// CHECK-NOT: "tf.IdentityN"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: tf_device.return [[ARG0]], [[ARG1]]
%0:2 = "tf_device.cluster"() ( {
%1:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
tf_device.return %1#0, %1#1 : tensor<i32>, tensor<f32>
}) : () -> (tensor<i32>, tensor<f32>)
return
}
// Tests transitive Identity ops reachable from the cluster are pruned away.
// CHECK-LABEL: func @testTransitiveIdentity
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testTransitiveIdentity(%arg0: tensor<i32>) {
// CHECK: "tf_device.cluster"
// CHECK: "tf.PartitionedCall"([[ARG0]])
// CHECK-SAME: f = @callee0
%0 = "tf_device.cluster"() ( {
%1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee0} : (tensor<i32>) -> tensor<i32>
tf_device.return %1 : tensor<i32>
}) : () -> tensor<i32>
return
}
// CHECK-LABEL: func @callee0
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee0(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-NOT: "tf.Identity"
// CHECK: "tf.PartitionedCall"([[ARG0]])
// CHECK-SAME: f = @callee1
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee1} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// CHECK-LABEL: func @callee1
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee1(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK-NOT: "tf.Identity"
// CHECK: return [[ARG0]]
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
// Tests Identity ops not reachable from the cluster are not pruned away.
// CHECK-LABEL: func @testIdentityOutsideCluster
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testIdentityOutsideCluster(%arg0: tensor<i32>) {
// CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
// CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"
// CHECK-NEXT: tf_device.return [[IDENTITY]]
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%1 = "tf_device.cluster"() ( {
tf_device.return %0 : tensor<i32>
}) : () -> tensor<i32>
// CHECK: "tf.PartitionedCall"([[CLUSTER]])
// CHECK-SAME: f = @callee2
%2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @callee2} : (tensor<i32>) -> tensor<i32>
return
}
// CHECK-LABEL: func @callee2
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee2(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
%0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
// CHECK: return [[IDENTITY]]
return %0 : tensor<i32>
}

View File

@ -271,6 +271,9 @@ namespace TFTPU {
// `_tpu_replicate` attribute.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();
// Creates a pass that removes Identity/IdentityN ops from a cluster.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();
// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();

View File

@ -0,0 +1,113 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <tuple>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFTPU {
namespace {
// This pass removes Identity/IdentityN ops from the TPU computation and
// reachable functions.
// TODO(lyandy): Remove this pass once resource op lifting is migrated to use
// resource alias analysis and support region based control flow. Removing
// Identity ops may remove `_XlaSharding` annotation attribute if Identity ops
// are used to propagate such information.
struct TPUIdentityPruning
: public PassWrapper<TPUIdentityPruning, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
// Collects all reachable functions (via call ops) from a given region.
SmallVector<FuncOp, 4> CollectReachableFunctions(Region& region) {
llvm::SmallPtrSet<FuncOp, 4> reachable_funcs;
auto collect_reachable_funcs =
[&reachable_funcs](Region& src, SmallVectorImpl<FuncOp>& funcs_to_visit) {
src.walk([&reachable_funcs, &funcs_to_visit](CallOpInterface call_op) {
auto func = dyn_cast_or_null<FuncOp>(call_op.resolveCallable());
if (func && reachable_funcs.insert(func).second)
funcs_to_visit.push_back(func);
});
};
SmallVector<FuncOp, 4> funcs_to_visit;
collect_reachable_funcs(region, funcs_to_visit);
while (!funcs_to_visit.empty()) {
SmallVector<FuncOp, 4> new_funcs_to_visit;
for (FuncOp func_to_visit : funcs_to_visit) {
if (!func_to_visit.getCallableRegion()) continue;
collect_reachable_funcs(*func_to_visit.getCallableRegion(),
new_funcs_to_visit);
}
funcs_to_visit.swap(new_funcs_to_visit);
}
return llvm::to_vector<4>(reachable_funcs);
}
// Removes Identity/IdentityN ops from a region and forwards its operands to its
// results.
void RemoveIdentityFromRegion(Region& region) {
region.walk([](Operation* op) {
if (isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
op->replaceAllUsesWith(op->getOperands());
op->erase();
}
});
}
void TPUIdentityPruning::runOnOperation() {
SmallVector<tf_device::ClusterOp, 4> clusters;
getOperation().walk(
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
for (tf_device::ClusterOp cluster : clusters) {
RemoveIdentityFromRegion(cluster.body());
auto reachable_funcs = CollectReachableFunctions(cluster.body());
for (FuncOp reachable_func : reachable_funcs)
RemoveIdentityFromRegion(*reachable_func.getCallableRegion());
}
}
} // anonymous namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass() {
return std::make_unique<TPUIdentityPruning>();
}
static PassRegistration<TPUIdentityPruning> pass(
"tf-tpu-identity-pruning",
"Removes Identity/IdentityN ops from the TPU computation");
} // namespace TFTPU
} // namespace mlir