Add pass that removes Identity/IdentityN ops from the TPU computation.
Identity/IdentityN ops are not treated specially when legalized to HLO. To avoid forwarding values as new values, Identity/IdentityN ops are removed and their operands are forwarded directly to their results. This is extracted from resource op lifting.

PiperOrigin-RevId: 326675870
Change-Id: Ic934f982087a252737334deb8e43c92d56575f12
parent 6297d314fb
commit f4bae1839c
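For orientation, a minimal sketch of the core rewrite the pass applies to each Identity/IdentityN op it finds; it mirrors the RemoveIdentityFromRegion helper added in the new pass below, and ForwardAndErase is a hypothetical standalone name used only for illustration:

#include "mlir/IR/Operation.h" // from @llvm-project

// Each result of Identity/IdentityN mirrors the operand at the same index, so
// uses of the results can be rewired to the operands and the op erased. The
// full pass below additionally prunes functions reachable from the cluster
// through call ops.
void ForwardAndErase(mlir::Operation* op) {
  op->replaceAllUsesWith(op->getOperands());
  op->erase();
}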
@@ -788,6 +788,7 @@ cc_library(
        "transforms/tpu_extract_head_tail_outside_compilation.cc",
        "transforms/tpu_extract_outside_compilation.cc",
        "transforms/tpu_host_computation_expansion.cc",
        "transforms/tpu_identity_pruning.cc",
        "transforms/tpu_merge_variables_with_execute.cc",
        "transforms/tpu_outside_compilation_cluster.cc",
        "transforms/tpu_rewrite_pass.cc",
@@ -0,0 +1,93 @@
// RUN: tf-opt %s -tf-tpu-identity-pruning | FileCheck %s --dump-input=always

// Tests Identity op in cluster is pruned away.

// CHECK-LABEL: func @testIdentity
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testIdentity(%arg0: tensor<i32>) {
  // CHECK-NOT: "tf.Identity"
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: tf_device.return [[ARG0]]
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %1 : tensor<i32>
  }) : () -> tensor<i32>
  return
}

// Tests IdentityN op in cluster is pruned away.

// CHECK-LABEL: func @testIdentityN
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>)
func @testIdentityN(%arg0: tensor<i32>, %arg1: tensor<f32>) {
  // CHECK-NOT: "tf.IdentityN"
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG1]]
  %0:2 = "tf_device.cluster"() ( {
    %1:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
    tf_device.return %1#0, %1#1 : tensor<i32>, tensor<f32>
  }) : () -> (tensor<i32>, tensor<f32>)
  return
}

// Tests transitive Identity ops reachable from the cluster are pruned away.

// CHECK-LABEL: func @testTransitiveIdentity
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testTransitiveIdentity(%arg0: tensor<i32>) {
  // CHECK: "tf_device.cluster"
  // CHECK: "tf.PartitionedCall"([[ARG0]])
  // CHECK-SAME: f = @callee0
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee0} : (tensor<i32>) -> tensor<i32>
    tf_device.return %1 : tensor<i32>
  }) : () -> tensor<i32>
  return
}

// CHECK-LABEL: func @callee0
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee0(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK-NOT: "tf.Identity"
  // CHECK: "tf.PartitionedCall"([[ARG0]])
  // CHECK-SAME: f = @callee1
  %0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  %1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee1} : (tensor<i32>) -> tensor<i32>
  return %1 : tensor<i32>
}

// CHECK-LABEL: func @callee1
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee1(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK-NOT: "tf.Identity"
  // CHECK: return [[ARG0]]
  %0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %0 : tensor<i32>
}

// Tests Identity ops not reachable from the cluster are not pruned away.

// CHECK-LABEL: func @testIdentityOutsideCluster
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @testIdentityOutsideCluster(%arg0: tensor<i32>) {
  // CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
  // CHECK: [[CLUSTER:%.*]] = "tf_device.cluster"
  // CHECK-NEXT: tf_device.return [[IDENTITY]]
  %0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  %1 = "tf_device.cluster"() ( {
    tf_device.return %0 : tensor<i32>
  }) : () -> tensor<i32>
  // CHECK: "tf.PartitionedCall"([[CLUSTER]])
  // CHECK-SAME: f = @callee2
  %2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @callee2} : (tensor<i32>) -> tensor<i32>
  return
}

// CHECK-LABEL: func @callee2
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>)
func @callee2(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]])
  %0 = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  // CHECK: return [[IDENTITY]]
  return %0 : tensor<i32>
}
@@ -271,6 +271,9 @@ namespace TFTPU {
// `_tpu_replicate` attribute.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass();

// Creates a pass that removes Identity/IdentityN ops from a cluster.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass();

// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
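A hedged usage sketch for the new factory function: the commit only declares and registers the pass, so where it runs in the TPU bridge pipeline is not shown here, and AddIdentityPruning is a hypothetical helper used only for illustration.

#include "mlir/Pass/PassManager.h" // from @llvm-project
// CreateTPUIdentityPruningPass() is declared in the header modified above.

// Hypothetical wiring: append the pruning pass to an existing module-level
// pass pipeline.
void AddIdentityPruning(mlir::OpPassManager& pm) {
  pm.addPass(mlir::TFTPU::CreateTPUIdentityPruningPass());
}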
@@ -0,0 +1,113 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFTPU {

namespace {

// This pass removes Identity/IdentityN ops from the TPU computation and
// reachable functions.
// TODO(lyandy): Remove this pass once resource op lifting is migrated to use
// resource alias analysis and support region based control flow. Removing
// Identity ops may remove `_XlaSharding` annotation attribute if Identity ops
// are used to propagate such information.

struct TPUIdentityPruning
    : public PassWrapper<TPUIdentityPruning, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Collects all reachable functions (via call ops) from a given region.
SmallVector<FuncOp, 4> CollectReachableFunctions(Region& region) {
  llvm::SmallPtrSet<FuncOp, 4> reachable_funcs;

  auto collect_reachable_funcs =
      [&reachable_funcs](Region& src, SmallVectorImpl<FuncOp>& funcs_to_visit) {
        src.walk([&reachable_funcs, &funcs_to_visit](CallOpInterface call_op) {
          auto func = dyn_cast_or_null<FuncOp>(call_op.resolveCallable());
          if (func && reachable_funcs.insert(func).second)
            funcs_to_visit.push_back(func);
        });
      };

  SmallVector<FuncOp, 4> funcs_to_visit;
  collect_reachable_funcs(region, funcs_to_visit);

  while (!funcs_to_visit.empty()) {
    SmallVector<FuncOp, 4> new_funcs_to_visit;
    for (FuncOp func_to_visit : funcs_to_visit) {
      if (!func_to_visit.getCallableRegion()) continue;
      collect_reachable_funcs(*func_to_visit.getCallableRegion(),
                              new_funcs_to_visit);
    }
    funcs_to_visit.swap(new_funcs_to_visit);
  }

  return llvm::to_vector<4>(reachable_funcs);
}

// Removes Identity/IdentityN ops from a region and forwards their operands to
// their results.
void RemoveIdentityFromRegion(Region& region) {
  region.walk([](Operation* op) {
    if (isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
      op->replaceAllUsesWith(op->getOperands());
      op->erase();
    }
  });
}

void TPUIdentityPruning::runOnOperation() {
  SmallVector<tf_device::ClusterOp, 4> clusters;
  getOperation().walk(
      [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });

  for (tf_device::ClusterOp cluster : clusters) {
    RemoveIdentityFromRegion(cluster.body());
    auto reachable_funcs = CollectReachableFunctions(cluster.body());
    for (FuncOp reachable_func : reachable_funcs)
      RemoveIdentityFromRegion(*reachable_func.getCallableRegion());
  }
}

} // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUIdentityPruningPass() {
  return std::make_unique<TPUIdentityPruning>();
}

static PassRegistration<TPUIdentityPruning> pass(
    "tf-tpu-identity-pruning",
    "Removes Identity/IdentityN ops from the TPU computation");

} // namespace TFTPU
} // namespace mlir