Handle Array of SymbolRef attributes during tf-executor-tpu-v1-island-outlining pass.
PiperOrigin-RevId: 313671791 Change-Id: I088c53bb45df7f1f9a6284f0dc50173c91bf1b98
This commit is contained in:
parent
1e22a99527
commit
b1cb3f12da
@ -0,0 +1,47 @@
|
||||
// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK: func @control_input
|
||||
// CHECK-NOT: func @
|
||||
// CHECK-LABEL: module @_tpu_v1_compat_outlined
|
||||
// CHECK: @_tpu_v1_compat_outlined_func0
|
||||
// CHECK: func @branch_0
|
||||
// CHECK: func @branch_1
|
||||
// CHECK: func @branch_2
|
||||
// CHECK: func @branch_3
|
||||
// CHECK: func @branch_4
|
||||
module {
|
||||
func @control_input(%arg0: tensor<i1>) -> tensor<i32> {
|
||||
%0 = tf_executor.graph {
|
||||
%output, %control = tf_executor.island {
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> ()
|
||||
%index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i1>) -> tensor<i32>
|
||||
%result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
tf_executor.yield %result : tensor<i32>
|
||||
}
|
||||
tf_executor.fetch %output : tensor<i32>
|
||||
|
||||
}
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @branch_0(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @branch_1(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @branch_2(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @branch_3(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
func @branch_4(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
}
|
@ -49,6 +49,16 @@ struct TPUBridgeExecutorIslandOutlining
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
// Move FuncOp referenced by `symbol_ref` from one symbol table to another.
|
||||
void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from,
|
||||
SymbolTable &to) {
|
||||
if (to.lookup<FuncOp>(symbol_ref.getValue())) return;
|
||||
FuncOp callee = from.lookup<FuncOp>(symbol_ref.getValue());
|
||||
callee.getOperation()->getBlock()->getOperations().remove(
|
||||
callee.getOperation());
|
||||
to.insert(callee);
|
||||
}
|
||||
|
||||
void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
MLIRContext *ctx = &getContext();
|
||||
|
||||
@ -141,14 +151,17 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
for (FuncOp func : outlined_module.getOps<FuncOp>()) {
|
||||
func.walk([&](Operation *op) {
|
||||
for (NamedAttribute attr : op->getAttrs()) {
|
||||
auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
|
||||
if (!symbol_ref) continue;
|
||||
if (outlined_symbol_table.lookup<FuncOp>(symbol_ref.getValue()))
|
||||
if (auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
|
||||
MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
|
||||
continue;
|
||||
FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
|
||||
callee.getOperation()->getBlock()->getOperations().remove(
|
||||
callee.getOperation());
|
||||
outlined_symbol_table.insert(callee);
|
||||
}
|
||||
if (auto array_attr = attr.second.dyn_cast<ArrayAttr>()) {
|
||||
for (const Attribute &attribute : array_attr) {
|
||||
auto symbol_ref = attribute.dyn_cast<FlatSymbolRefAttr>();
|
||||
if (!symbol_ref) continue;
|
||||
MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user