Handle Array of SymbolRef attributes during tf-executor-tpu-v1-island-outlining pass.
PiperOrigin-RevId: 313671791 Change-Id: I088c53bb45df7f1f9a6284f0dc50173c91bf1b98
This commit is contained in:
parent
1e22a99527
commit
b1cb3f12da
@ -0,0 +1,47 @@
|
|||||||
|
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK: func @control_input
|
||||||
|
// CHECK-NOT: func @
|
||||||
|
// CHECK-LABEL: module @_tpu_v1_compat_outlined
|
||||||
|
// CHECK: @_tpu_v1_compat_outlined_func0
|
||||||
|
// CHECK: func @branch_0
|
||||||
|
// CHECK: func @branch_1
|
||||||
|
// CHECK: func @branch_2
|
||||||
|
// CHECK: func @branch_3
|
||||||
|
// CHECK: func @branch_4
|
||||||
|
module {
|
||||||
|
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
|
||||||
|
%0 = tf_executor.graph {
|
||||||
|
%output, %control = tf_executor.island {
|
||||||
|
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
|
||||||
|
%index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||||
|
%input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||||
|
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||||
|
tf_executor.yield %result : tensor<i32>
|
||||||
|
}
|
||||||
|
tf_executor.fetch %output : tensor<i32>
|
||||||
|
|
||||||
|
}
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
func @branch_0(%arg0: tensor<i32>) -> tensor<i32> {
|
||||||
|
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
func @branch_1(%arg0: tensor<i32>) -> tensor<i32> {
|
||||||
|
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
func @branch_2(%arg0: tensor<i32>) -> tensor<i32> {
|
||||||
|
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
func @branch_3(%arg0: tensor<i32>) -> tensor<i32> {
|
||||||
|
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
func @branch_4(%arg0: tensor<i32>) -> tensor<i32> {
|
||||||
|
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
}
|
@ -49,6 +49,16 @@ struct TPUBridgeExecutorIslandOutlining
|
|||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Move FuncOp referenced by `symbol_ref` from one symbol table to another.
|
||||||
|
void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from,
|
||||||
|
SymbolTable &to) {
|
||||||
|
if (to.lookup<FuncOp>(symbol_ref.getValue())) return;
|
||||||
|
FuncOp callee = from.lookup<FuncOp>(symbol_ref.getValue());
|
||||||
|
callee.getOperation()->getBlock()->getOperations().remove(
|
||||||
|
callee.getOperation());
|
||||||
|
to.insert(callee);
|
||||||
|
}
|
||||||
|
|
||||||
void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||||
MLIRContext *ctx = &getContext();
|
MLIRContext *ctx = &getContext();
|
||||||
|
|
||||||
@ -141,14 +151,17 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
|||||||
for (FuncOp func : outlined_module.getOps<FuncOp>()) {
|
for (FuncOp func : outlined_module.getOps<FuncOp>()) {
|
||||||
func.walk([&](Operation *op) {
|
func.walk([&](Operation *op) {
|
||||||
for (NamedAttribute attr : op->getAttrs()) {
|
for (NamedAttribute attr : op->getAttrs()) {
|
||||||
auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
|
if (auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
|
||||||
if (!symbol_ref) continue;
|
MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
|
||||||
if (outlined_symbol_table.lookup<FuncOp>(symbol_ref.getValue()))
|
|
||||||
continue;
|
continue;
|
||||||
FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
|
}
|
||||||
callee.getOperation()->getBlock()->getOperations().remove(
|
if (auto array_attr = attr.second.dyn_cast<ArrayAttr>()) {
|
||||||
callee.getOperation());
|
for (const Attribute &attribute : array_attr) {
|
||||||
outlined_symbol_table.insert(callee);
|
auto symbol_ref = attribute.dyn_cast<FlatSymbolRefAttr>();
|
||||||
|
if (!symbol_ref) continue;
|
||||||
|
MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user