[XLA] Make HLO snapshot dumping work on the LocalClient::RunAsync path.
PiperOrigin-RevId: 259956061
This commit is contained in:
parent
6710b20acc
commit
18f1467496
@ -189,8 +189,49 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
|
||||
ExecutableRunOptions run_options) {
|
||||
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
||||
RunHelper(arguments, run_options));
|
||||
return executable_->ExecuteAsyncOnStream(&options_and_stream.first,
|
||||
arguments);
|
||||
se::Stream* stream = run_options.stream();
|
||||
|
||||
std::shared_ptr<HloSnapshot> snapshot;
|
||||
if (executable_->dumping_snapshot()) {
|
||||
snapshot = std::make_shared<HloSnapshot>();
|
||||
snapshot->set_execution_platform(backend_->platform()->Name());
|
||||
*snapshot->mutable_hlo() = *executable_->hlo_proto();
|
||||
for (const ShapedBuffer* arg : arguments) {
|
||||
auto literal = std::make_shared<Literal>(arg->on_host_shape());
|
||||
backend_->transfer_manager()->TransferLiteralFromDevice(
|
||||
stream, *arg, literal.get(), [snapshot, literal](Status status) {
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
|
||||
"failed: "
|
||||
<< status;
|
||||
return;
|
||||
}
|
||||
*snapshot->add_arguments() = literal->ToProto();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer outputs,
|
||||
executable_->ExecuteAsyncOnStream(&options_and_stream.first, arguments));
|
||||
|
||||
// Transfer the outputs and save the snapshot to disk.
|
||||
if (snapshot) {
|
||||
auto literal = std::make_shared<Literal>(outputs.on_host_shape());
|
||||
backend_->transfer_manager()->TransferLiteralFromDevice(
|
||||
stream, outputs, literal.get(), [snapshot, literal](Status status) {
|
||||
if (status.ok()) {
|
||||
*snapshot->mutable_result() = literal->ToProto();
|
||||
} else {
|
||||
LOG(ERROR)
|
||||
<< "TransferLiteralFromDevice for HLO snapshot outputs failed: "
|
||||
<< status;
|
||||
}
|
||||
DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags());
|
||||
});
|
||||
}
|
||||
|
||||
return std::move(outputs);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
|
||||
|
Loading…
Reference in New Issue
Block a user