Update check for compilation result to consult the correct attribute

The _tpu_replicate attribute for the ops being clustered needs to be matched with the _tpu_compilation_status attribute of the compilation result.

PiperOrigin-RevId: 360812812
Change-Id: I0bcf5c8617d728b3227616214929e56ee203014b
This commit is contained in:
Jacques Pienaar 2021-03-03 19:26:27 -08:00 committed by TensorFlower Gardener
parent d2d7787e19
commit 32b5152f35
3 changed files with 52 additions and 17 deletions

View File

@@ -15674,7 +15674,7 @@ Computes the gradient function for function f via backpropagation.
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", []> {
let summary = "Returns the result of a TPU compilation.";
let description = [{

View File

@@ -1364,14 +1364,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: %[[COMPILE_RESULT_0:.*]] = "tf.Identity"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[COMPILE_RESULT_1:.*]] = "tf.Identity"(%[[COMPILE_RESULT_0]])
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
%1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
%compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
%compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
%compile_result = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>
%compile_result2 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>
// CHECK-NOT: "tf.TPUCompilationResult"

View File

@@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@@ -556,16 +557,18 @@ tf_device::LaunchOp AssignDevicesToReplicatedExecute(
// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
Operation* result_id,
llvm::StringRef compilation_device,
OpBuilder* builder) {
auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
compile_op->getLoc(), compile_op->getResult(0));
compile_op->getLoc(), result_id->getResult(0));
WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}
LogicalResult Rewrite(
tf_device::ClusterFuncOp cluster_func,
llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
ArrayRef<TF::TPUCompilationResultOp> compilation_result,
OpBuilder* builder) {
// Collect `num_replicas` and `num_cores_per_replica` attributes.
int num_replicas = 1;
@@ -645,16 +648,22 @@ LogicalResult Rewrite(
});
}
// After rewrite, find if there is a TPUCompilationResultOp in the block with
// the same _tpu_replicate attribute and replace it with the result of the
// compile op. This op is used as a placeholder to hook during graph creation
// the other ops that are intended to consume the compile result.
Block* block = cluster_func.getOperation()->getBlock();
for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
// After rewrite, if there is a TPUCompilationResultOp from the same cluster,
// replace it with the result of the compile op. The TPUCompilationResultOp is
// used as a placeholder to hook during graph creation the other ops that are
// intended to consume the compile result.
Operation* result_id = compile_op;
for (auto res : compilation_result) {
// Build identity op with the same location/name as the original compilation
// result op.
result_id = builder->create<TF::IdentityOp>(
res.getLoc(), compile_op->getResult(0).getType(),
result_id->getResult(0));
res.output().replaceAllUsesWith(compile_op->getResult(0));
}
BuildTPUCompileSucceededAssertOp(
compile_op, tpu_device_assignment.compilation_device, builder);
compile_op, result_id, tpu_device_assignment.compilation_device, builder);
AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
builder);
@@ -733,26 +742,50 @@ void TPURewritePass::runOnOperation() {
if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
return signalPassFailure();
// Collect compilation results.
llvm::DenseMap<Attribute, SmallVector<TF::TPUCompilationResultOp, 1>>
compilation_results;
auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) {
auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_compilation_status");
if (!cluster_id) {
op->emitOpError("missing '_tpu_compilation_status'");
return WalkResult::interrupt();
}
compilation_results[cluster_id].push_back(op);
return WalkResult::advance();
});
if (result_init.wasInterrupted()) return signalPassFailure();
llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
OpBuilder builder(&getContext());
auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
// Skip non-tpu device cluster_func.
auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
if (!replicate_attr) return WalkResult::advance();
auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_replicate");
if (!cluster_id) return WalkResult::advance();
if (failed(Rewrite(op, devices.device_names(), &builder)))
if (failed(Rewrite(op, devices.device_names(),
compilation_results[cluster_id], &builder)))
return WalkResult::interrupt();
to_be_erased.push_back(op);
return WalkResult::advance();
});
if (result.wasInterrupted()) return signalPassFailure();
EraseClusterFuncs(to_be_erased);
// Eliminate TPUCompilationResultOp now that the rewrite is complete.
getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
for (auto& it : compilation_results) {
for (auto op : it.second) {
if (!op.use_empty()) {
mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
for (auto user : op->getUsers())
err.attachNote(user->getLoc()) << "remaining user";
return signalPassFailure();
}
op.erase();
}
}
// TODO(b/139377366): Remove functions that are no longer needed.
}