Replace _TPUCompileMlir placeholder ops with correct compile op in parallel_execute regions.
When adding parallel_execute regions for outside compilation, _TPUCompileMlir placeholder ops are generated since the _TPUCompileMlir op is not created until this pass. This change replaces those placeholder ops with the newly created _TPUCompileMlir op.

PiperOrigin-RevId: 312521930
Change-Id: I2136c9569ea853875397a83dc40eebb7db004a4d
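For illustration, a minimal before/after sketch of the rewrite inside a host region of tf_device.parallel_execute (attributes and surrounding ops are elided or assumed for illustration; this is not the test case below):

    // Before this pass: a placeholder compile op feeds the host-side op.
    %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
    "tf.D"(%program) : (tensor<!tf.string>) -> ()

    // After this pass: the placeholder is erased and its uses are rewired to
    // the compile op materialized by the pass, e.g. a
    // %compile:2 = "tf_device.launch" ... "tf._TPUCompileMlir" ... defined above.
    "tf.D"(%compile#1) : (tensor<!tf.string>) -> ()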
@@ -1234,16 +1234,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
    // CHECK: "tf._TPUCompileMlir"
    // CHECK: "tf.TPUCompileSucceededAssert"
    // CHECK: "tf_device.parallel_execute"
    // CHECK-NOT: "tf._TPUCompileMlir"
    // CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1
    // CHECK: "tf.TPUExecute"
    // CHECK-NOT: "tf._TPUCompileMlir"
    // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1
    %3 = "tf_device.parallel_execute"() ( {
      %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
      "tf.D"(%program) : (tensor<!tf.string>) -> ()
      tf_device.return
    }, {
      %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }, {
      %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
      "tf.E"(%program) : (tensor<!tf.string>) -> ()
      tf_device.return
    }) : () -> (tensor<?xi32>)
    tf_device.return %3 : tensor<?xi32>
  }
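For reference, FileCheck tests like this are driven by a RUN line at the top of the file, which falls outside this hunk; a hedged sketch of what that invocation typically looks like for this pass (the exact flags here are an assumption):

    // RUN: tf-opt %s -split-input-file -tf-tpu-rewrite | FileCheck %s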
@@ -701,6 +701,19 @@ LogicalResult Rewrite(
      std::move(tpu_device_assignment.xla_device_assignment), builder);
  if (!compile_op) return failure();

  // This replaces the _TPUCompileMlir placeholder ops that are required by
  // the XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
  // TODO(b/157054714): When a better abstraction is used instead of
  // _TPUCompileMlirOp, _XlaRecvAtHostOp, and _XlaSendFromHostOp, update to a
  // more structured lowering.
  if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
          cluster_func.getParentOp())) {
    parallel_op.walk([&](TF::_TPUCompileMlirOp parallel_compile_op) {
      parallel_compile_op.replaceAllUsesWith(compile_op);
      parallel_compile_op.erase();
    });
  }

  // After rewrite, find if there is a TPUCompilationResultOp in the block with
  // the same _tpu_replicate attribute and replace it with the result of the
  // compile op. This op is used as a placeholder to hook during graph creation
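The replacement above is the standard MLIR walk-and-replace-all-uses pattern. Below is a self-contained sketch of that pattern over a generic op name; the helper name and the collect-then-erase ordering are illustrative choices (the pass itself erases during the walk), not the pass's actual code:

    // Sketch: replace every op named `placeholder_name` under `root` with
    // `real_op`, then erase the placeholders. Assumes both ops produce the
    // same number of results (Operation::replaceAllUsesWith asserts this).
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/StringRef.h"
    #include "mlir/IR/Operation.h"

    void ReplacePlaceholders(mlir::Operation *root,
                             llvm::StringRef placeholder_name,
                             mlir::Operation *real_op) {
      // Collect first so erasing ops does not interfere with the traversal.
      llvm::SmallVector<mlir::Operation *, 4> placeholders;
      root->walk([&](mlir::Operation *op) {
        // Skip the real op itself in case it shares the placeholder's name.
        if (op != real_op && op->getName().getStringRef() == placeholder_name)
          placeholders.push_back(op);
      });
      for (mlir::Operation *op : placeholders) {
        op->replaceAllUsesWith(real_op);  // Rewire all result uses.
        op->erase();                      // Remove the now-dead placeholder.
      }
    }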