Replace _TPUCompileMlir placeholder ops with correct compile op in parallel_execute regions.

When parallel_execute regions are added for outside compilation, placeholder _TPUCompileMlir ops are generated, because the real _TPUCompileMlir op is not created until this pass runs. This change replaces those placeholder ops with the newly created _TPUCompileMlir op.

PiperOrigin-RevId: 312521930
Change-Id: I2136c9569ea853875397a83dc40eebb7db004a4d
Ken Franko 2020-05-20 11:34:25 -07:00 committed by TensorFlower Gardener
parent 0992a65a5d
commit 8d51ed5895
2 changed files with 24 additions and 1 deletion


@@ -1234,16 +1234,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
   %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
   // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
   %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+    // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
+    // CHECK: "tf._TPUCompileMlir"
+    // CHECK: "tf.TPUCompileSucceededAssert"
     // CHECK: "tf_device.parallel_execute"
+    // CHECK-NOT:"tf._TPUCompileMlir"
+    // CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1
     // CHECK: "tf.TPUExecute"
+    // CHECK-NOT:"tf._TPUCompileMlir"
+    // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1
     %3 = "tf_device.parallel_execute"() ( {
-      "tf.D"() : () -> ()
+      %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      "tf.D"(%program) : (tensor<!tf.string>) -> ()
       tf_device.return
     }, {
       %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
       tf_device.return %4 : tensor<?xi32>
     }, {
+      %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
+      "tf.E"(%program) : (tensor<!tf.string>) -> ()
       tf_device.return
     }) : () -> (tensor<?xi32>)
     tf_device.return %3 : tensor<?xi32>
   }
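The implementation hunk that follows relies on a standard MLIR idiom: walk the parallel_execute op, redirect every use of each placeholder to the real compile op, and erase the placeholder. The sketch below restates that idiom in isolation; it is a minimal sketch, not TensorFlow's actual code — the free function ReplacePlaceholders, the string-based op matching, and the deferred-erase loop are assumptions for illustration.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"

// Redirects every use of each placeholder compile op to the real compile op,
// then deletes the placeholders. replaceAllUsesWith is only valid because the
// placeholder and the real op produce the same results (status, program).
void ReplacePlaceholders(mlir::Operation* root, mlir::Operation* compile_op) {
  llvm::SmallVector<mlir::Operation*, 4> placeholders;
  root->walk([&](mlir::Operation* op) {
    // Match by op name here; the pass matches the typed TF::_TPUCompileMlirOp.
    if (op->getName().getStringRef() == "tf._TPUCompileMlir")
      placeholders.push_back(op);
  });
  for (mlir::Operation* op : placeholders) {
    op->replaceAllUsesWith(compile_op);  // rewires both results at once
    op->erase();                         // safe: no uses remain
  }
}

Erasing inside the walk callback, as the pass itself does, is also fine: MLIR walks post-order by default, so the op being visited may be erased once its uses are rewired.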


@@ -701,6 +701,19 @@ LogicalResult Rewrite(
       std::move(tpu_device_assignment.xla_device_assignment), builder);
   if (!compile_op) return failure();
 
+  // This replaces the _TPUCompileMlir placeholder ops that are required
+  // by the XlaRecvAtHost and XlaSendFromHost ops added in an earlier pass.
+  // TODO(b/157054714): When a better abstraction than _TPUCompileMlirOp,
+  // _XlaRecvAtHostOp, and _XlaSendFromHostOp is used, update to a more
+  // structured lowering.
+  if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
+          cluster_func.getParentOp())) {
+    parallel_op.walk([&](TF::_TPUCompileMlirOp parallel_compile_op) {
+      parallel_compile_op.replaceAllUsesWith(compile_op);
+      parallel_compile_op.erase();
+    });
+  }
+
   // After rewrite, find if there is a TPUCompilationResultOp in the block with
   // the same _tpu_replicate attribute and replace it with the result of the
   // compile op. This op is used as a placeholder to hook during graph creation