From 32b5152f3543b891e3027912483eee7e2b0d319d Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Wed, 3 Mar 2021 19:26:27 -0800
Subject: [PATCH] Update check for compilation result to consult the correct
 attribute

The _tpu_replicate attribute for the ops being clustered needs to be matched
with the _tpu_compilation_status attribute of the compilation result.

PiperOrigin-RevId: 360812812
Change-Id: I0bcf5c8617d728b3227616214929e56ee203014b
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  2 +-
 .../mlir/tensorflow/tests/tpu_rewrite.mlir    |  6 +-
 .../tensorflow/transforms/tpu_rewrite_pass.cc | 61 ++++++++++++++-----
 3 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index a7b28a83a86..4acde4df91e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -15674,7 +15674,7 @@ Computes the gradient function for function f via backpropagation.
   TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
 }
 
-def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
+def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", []> {
   let summary = "Returns the result of a TPU compilation.";
 
   let description = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index 23583d9d877..85d51681671 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -1364,14 +1364,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
   // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
   // CHECK-NEXT: "tf._TPUCompileMlir"
+  // CHECK: %[[COMPILE_RESULT_0:.*]] = "tf.Identity"(%[[COMPILE_OUTPUT]]#0)
+  // CHECK: %[[COMPILE_RESULT_1:.*]] = "tf.Identity"(%[[COMPILE_RESULT_0]])
   // CHECK: "tf_device.launch"
   // CHECK-NEXT: "tf.TPUCompileSucceededAssert"
   // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
   // CHECK-NEXT: "tf.TPUExecute"
   %1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
 
-  %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
-  %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
+  %compile_result = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>
+  %compile_result2 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>
 
   // CHECK-NOT: "tf.TPUCompilationResult"
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 3027031e8be..81db953c8f5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -556,16 +557,18 @@ tf_device::LaunchOp AssignDevicesToReplicatedExecute( // Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation // status of `compile_op` to check whether compilation is successful. void BuildTPUCompileSucceededAssertOp(Operation* compile_op, + Operation* result_id, llvm::StringRef compilation_device, OpBuilder* builder) { auto assert_op = builder->create( - compile_op->getLoc(), compile_op->getResult(0)); + compile_op->getLoc(), result_id->getResult(0)); WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device); } LogicalResult Rewrite( tf_device::ClusterFuncOp cluster_func, llvm::ArrayRef devices, + ArrayRef compilation_result, OpBuilder* builder) { // Collect `num_replicas` and `num_cores_per_replica` attributes. int num_replicas = 1; @@ -645,16 +648,22 @@ LogicalResult Rewrite( }); } - // After rewrite, find if there is a TPUCompilationResultOp in the block with - // the same _tpu_replicate attribute and replace it with the result of the - // compile op. This op is used as a placeholder to hook during graph creation - // the other ops that are intended to consume the compile result. - Block* block = cluster_func.getOperation()->getBlock(); - for (auto compile_result_op : block->getOps()) - compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0)); + // After rewrite, if there is a TPUCompilationResultOp from the same cluster, + // replace it with the result of the compile op. The TPUCompilationResultOp is + // used as a placeholder to hook during graph creation the other ops that are + // intended to consume the compile result. + Operation* result_id = compile_op; + for (auto res : compilation_result) { + // Build identity op with the same location/name as the original compilation + // result op. + result_id = builder->create( + res.getLoc(), compile_op->getResult(0).getType(), + result_id->getResult(0)); + res.output().replaceAllUsesWith(compile_op->getResult(0)); + } BuildTPUCompileSucceededAssertOp( - compile_op, tpu_device_assignment.compilation_device, builder); + compile_op, result_id, tpu_device_assignment.compilation_device, builder); AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices, builder); @@ -733,26 +742,50 @@ void TPURewritePass::runOnOperation() { if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices))) return signalPassFailure(); + // Collect compilation results. + llvm::DenseMap> + compilation_results; + auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) { + auto cluster_id = op->getAttrOfType("_tpu_compilation_status"); + if (!cluster_id) { + op->emitOpError("missing '_tpu_compilation_status'"); + return WalkResult::interrupt(); + } + compilation_results[cluster_id].push_back(op); + return WalkResult::advance(); + }); + if (result_init.wasInterrupted()) return signalPassFailure(); + llvm::SmallVector to_be_erased; OpBuilder builder(&getContext()); auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) { // Skip non-tpu device cluster_func. 
-    auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
-    if (!replicate_attr) return WalkResult::advance();
+    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_replicate");
+    if (!cluster_id) return WalkResult::advance();
 
-    if (failed(Rewrite(op, devices.device_names(), &builder)))
+    if (failed(Rewrite(op, devices.device_names(),
+                       compilation_results[cluster_id], &builder)))
       return WalkResult::interrupt();
 
     to_be_erased.push_back(op);
     return WalkResult::advance();
   });
-
   if (result.wasInterrupted()) return signalPassFailure();
 
   EraseClusterFuncs(to_be_erased);
 
   // Eliminate TPUCompilationResultOp now that the rewrite is complete.
-  getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
+  for (auto& it : compilation_results) {
+    for (auto op : it.second) {
+      if (!op.use_empty()) {
+        mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
+        for (auto user : op->getUsers())
+          err.attachNote(user->getLoc()) << "remaining user";
+        return signalPassFailure();
+      }
+      op.erase();
+    }
+  }
 
   // TODO(b/139377366): Remove functions that are no longer needed.
 }
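
Note (illustration only, not part of the commit): the pairing this change
relies on is purely attribute-based. A minimal sketch of the input IR, with
the other attributes the pass requires (num_cores_per_replica, sharding
configuration, etc.; see the full set in the tpu_rewrite.mlir test above)
elided for brevity:

  // The cluster op names its TPU cluster via _tpu_replicate ...
  %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func} : (tensor<?xi32>) -> tensor<?xi32>
  // ... and the placeholder names the same cluster via _tpu_compilation_status,
  // so runOnOperation() can hand the matching placeholders to Rewrite().
  %status = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>

After the rewrite, each placeholder's uses are redirected to the compile op's
status output, a tf.Identity carrying the placeholder's location is built, and
the last identity in the chain feeds tf.TPUCompileSucceededAssert, as verified
by the new CHECK lines added to tpu_rewrite.mlir above.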