Update check for compilation result to consult the correct attribute

The _tpu_replicate attribute for the ops being clustered needs to be matched
with the _tpu_compilation_status attribute of the compilation result.

PiperOrigin-RevId: 360812812
Change-Id: I0bcf5c8617d728b3227616214929e56ee203014b
commit 32b5152f35
parent d2d7787e19
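For orientation, a minimal standalone sketch (plain C++; the strings and map are illustrative, not the pass's real data structures) of the matching this change enforces: compilation-result placeholders are grouped by their _tpu_compilation_status value, and a cluster's _tpu_replicate value selects the matching group.

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Placeholders keyed by their _tpu_compilation_status attribute value.
  std::map<std::string, std::vector<std::string>> compilation_results;
  compilation_results["cluster0"] = {"%compile_result", "%compile_result2"};

  // The clustered op carries _tpu_replicate; the same string keys the lookup.
  std::string tpu_replicate = "cluster0";
  for (const std::string& placeholder : compilation_results[tpu_replicate])
    std::cout << placeholder << " <- compile status of " << tpu_replicate
              << "\n";
  return 0;
}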
@@ -15674,7 +15674,7 @@ Computes the gradient function for function f via backpropagation.
   TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
 }
 
-def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> {
+def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", []> {
   let summary = "Returns the result of a TPU compilation.";
 
   let description = [{
@@ -1364,14 +1364,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
   // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
   // CHECK-NEXT: "tf._TPUCompileMlir"
+  // CHECK: %[[COMPILE_RESULT_0:.*]] = "tf.Identity"(%[[COMPILE_OUTPUT]]#0)
+  // CHECK: %[[COMPILE_RESULT_1:.*]] = "tf.Identity"(%[[COMPILE_RESULT_0]])
   // CHECK: "tf_device.launch"
   // CHECK-NEXT: "tf.TPUCompileSucceededAssert"
   // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
   // CHECK-NEXT: "tf.TPUExecute"
   %1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<?xi32>) -> tensor<?xi32>
 
-  %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
-  %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
+  %compile_result = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>
+  %compile_result2 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster0"} : () -> tensor<!tf.string>
 
   // CHECK-NOT: "tf.TPUCompilationResult"
 
@@ -29,6 +29,7 @@ limitations under the License.
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
@@ -556,16 +557,18 @@ tf_device::LaunchOp AssignDevicesToReplicatedExecute(
 // Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
 // status of `compile_op` to check whether compilation is successful.
 void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
+                                      Operation* result_id,
                                       llvm::StringRef compilation_device,
                                       OpBuilder* builder) {
   auto assert_op = builder->create<TF::TPUCompileSucceededAssertOp>(
-      compile_op->getLoc(), compile_op->getResult(0));
+      compile_op->getLoc(), result_id->getResult(0));
   WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
 }
 
 LogicalResult Rewrite(
     tf_device::ClusterFuncOp cluster_func,
     llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
+    ArrayRef<TF::TPUCompilationResultOp> compilation_result,
     OpBuilder* builder) {
   // Collect `num_replicas` and `num_cores_per_replica` attributes.
   int num_replicas = 1;
@@ -645,16 +648,22 @@ LogicalResult Rewrite(
     });
   }
 
-  // After rewrite, find if there is a TPUCompilationResultOp in the block with
-  // the same _tpu_replicate attribute and replace it with the result of the
-  // compile op. This op is used as a placeholder to hook during graph creation
-  // the other ops that are intended to consume the compile result.
-  Block* block = cluster_func.getOperation()->getBlock();
-  for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
-    compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
+  // After rewrite, if there is a TPUCompilationResultOp from the same cluster,
+  // replace it with the result of the compile op. The TPUCompilationResultOp
+  // is used as a placeholder to hook during graph creation the other ops that
+  // are intended to consume the compile result.
+  Operation* result_id = compile_op;
+  for (auto res : compilation_result) {
+    // Build identity op with the same location/name as the original
+    // compilation result op.
+    result_id = builder->create<TF::IdentityOp>(
+        res.getLoc(), compile_op->getResult(0).getType(),
+        result_id->getResult(0));
+    res.output().replaceAllUsesWith(compile_op->getResult(0));
+  }
 
   BuildTPUCompileSucceededAssertOp(
-      compile_op, tpu_device_assignment.compilation_device, builder);
+      compile_op, result_id, tpu_device_assignment.compilation_device, builder);
 
   AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices,
                            builder);
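The hunk above threads a chain of identity ops through every compilation-result placeholder and hands the tail of the chain to the succeeded-assert. A minimal sketch of that chaining pattern (plain C++; Node and the op names are illustrative stand-ins, not MLIR types):

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  const Node* input;  // single data dependency, like tf.Identity's operand
};

int main() {
  Node compile{"tf._TPUCompileMlir", nullptr};
  std::vector<std::string> placeholders = {"compile_result",
                                           "compile_result2"};

  std::vector<Node> chain;
  chain.reserve(placeholders.size());  // keep pointers into `chain` stable
  const Node* tail = &compile;
  for (const std::string& name : placeholders) {
    // Each identity consumes the previous link, one link per placeholder.
    chain.push_back(Node{"tf.Identity(" + name + ")", tail});
    tail = &chain.back();
  }

  // The succeeded-assert consumes the chain's tail instead of the raw
  // compile output, so it depends on every identity in the chain.
  Node assert_op{"tf.TPUCompileSucceededAssert", tail};
  for (const Node* n = &assert_op; n != nullptr; n = n->input)
    std::cout << n->name << "\n";  // prints the dependency chain
  return 0;
}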
@@ -733,26 +742,50 @@ void TPURewritePass::runOnOperation() {
   if (failed(tensorflow::GetDevicesFromOp(getOperation(), &devices)))
     return signalPassFailure();
 
+  // Collect compilation results.
+  llvm::DenseMap<Attribute, SmallVector<TF::TPUCompilationResultOp, 1>>
+      compilation_results;
+  auto result_init = getOperation().walk([&](TF::TPUCompilationResultOp op) {
+    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_compilation_status");
+    if (!cluster_id) {
+      op->emitOpError("missing '_tpu_compilation_status'");
+      return WalkResult::interrupt();
+    }
+    compilation_results[cluster_id].push_back(op);
+    return WalkResult::advance();
+  });
+  if (result_init.wasInterrupted()) return signalPassFailure();
+
   llvm::SmallVector<tf_device::ClusterFuncOp> to_be_erased;
   OpBuilder builder(&getContext());
   auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) {
     // Skip non-tpu device cluster_func.
-    auto replicate_attr = op->getAttrOfType<StringAttr>("_tpu_replicate");
-    if (!replicate_attr) return WalkResult::advance();
+    auto cluster_id = op->getAttrOfType<StringAttr>("_tpu_replicate");
+    if (!cluster_id) return WalkResult::advance();
 
-    if (failed(Rewrite(op, devices.device_names(), &builder)))
+    if (failed(Rewrite(op, devices.device_names(),
+                       compilation_results[cluster_id], &builder)))
       return WalkResult::interrupt();
 
     to_be_erased.push_back(op);
     return WalkResult::advance();
   });
 
   if (result.wasInterrupted()) return signalPassFailure();
 
   EraseClusterFuncs(to_be_erased);
 
   // Eliminate TPUCompilationResultOp now that the rewrite is complete.
-  getOperation().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
+  for (auto& it : compilation_results) {
+    for (auto op : it.second) {
+      if (!op.use_empty()) {
+        mlir::InFlightDiagnostic err = op.emitError("uses remain post rewrite");
+        for (auto user : op->getUsers())
+          err.attachNote(user->getLoc()) << "remaining user";
+        return signalPassFailure();
+      }
+      op.erase();
+    }
+  }
 
   // TODO(b/139377366): Remove functions that are no longer needed.
 }
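The pass body above relies on interrupt-style traversal: the first visit that fails aborts the whole walk, and the caller turns that into a pass failure. A minimal stand-in for that pattern (plain C++; Walk, WalkResult, and the integer "ops" are hypothetical, not the MLIR API):

#include <functional>
#include <iostream>
#include <vector>

enum class WalkResult { advance, interrupt };

// Visits ops in order until a callback interrupts, mirroring how the pass
// bails out when a placeholder is missing its attribute.
WalkResult Walk(const std::vector<int>& ops,
                const std::function<WalkResult(int)>& fn) {
  for (int op : ops)
    if (fn(op) == WalkResult::interrupt) return WalkResult::interrupt;
  return WalkResult::advance;
}

int main() {
  std::vector<int> ops = {1, 2, -1, 3};  // -1 models an op missing its attr
  WalkResult r = Walk(ops, [](int op) {
    if (op < 0) {
      std::cerr << "missing '_tpu_compilation_status'\n";
      return WalkResult::interrupt;
    }
    return WalkResult::advance;
  });
  if (r == WalkResult::interrupt) std::cerr << "pass failure\n";
  return 0;
}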