parent
130a84e59c
commit
3f266b1c8d
@ -196,15 +196,16 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
|
||||
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
const absl::Span<const ShapedBuffer* const> arguments) {
|
||||
executable_->hlo_snapshot()->set_execution_platform(
|
||||
backend_->platform()->Name());
|
||||
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
|
||||
HloSnapshot snapshot;
|
||||
*snapshot.mutable_hlo() = *executable_->hlo_proto();
|
||||
snapshot.set_execution_platform(backend_->platform()->Name());
|
||||
TF_RETURN_IF_ERROR(RecordArguments(arguments, &snapshot));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer result,
|
||||
executable_->ExecuteOnStream(run_options, arguments,
|
||||
/*hlo_execution_profile=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
|
||||
DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot());
|
||||
TF_RETURN_IF_ERROR(RecordResult(&result, &snapshot));
|
||||
DumpHloSnapshotIfEnabled(executable_->module(), snapshot);
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
|
@ -224,11 +224,11 @@ class Executable {
|
||||
virtual int64 SizeInBytes();
|
||||
|
||||
// Dumping helpers.
|
||||
void set_hlo_snapshot(std::unique_ptr<xla::HloSnapshot> hlo_snapshot) {
|
||||
hlo_snapshot_ = std::move(hlo_snapshot);
|
||||
void set_hlo_proto(std::unique_ptr<xla::HloProto> hlo_proto) {
|
||||
hlo_proto_ = std::move(hlo_proto);
|
||||
}
|
||||
bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; }
|
||||
HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); }
|
||||
bool dumping_snapshot() const { return hlo_proto_ != nullptr; }
|
||||
HloProto const* hlo_proto() const { return hlo_proto_.get(); }
|
||||
|
||||
protected:
|
||||
mutable tensorflow::mutex mutex_;
|
||||
@ -241,8 +241,8 @@ class Executable {
|
||||
// around.
|
||||
const std::shared_ptr<HloModule> hlo_module_;
|
||||
|
||||
// HloSnapshot this was compiled from. Null if not dumping executions.
|
||||
std::unique_ptr<HloSnapshot> hlo_snapshot_;
|
||||
// The serialized HLO proto. Non-null only if dumping snapshots is enabled.
|
||||
std::unique_ptr<HloProto const> hlo_proto_;
|
||||
|
||||
// Execution count, used to generate a unique filename for each dumped
|
||||
// execution.
|
||||
|
@ -351,11 +351,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
VLOG(1) << StrFormat("BuildExecutable on service %p", this);
|
||||
|
||||
// Dump computation proto state if flag is set.
|
||||
std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots;
|
||||
std::vector<std::unique_ptr<HloProto>> hlo_protos;
|
||||
for (int64 i = 0; i < module_protos.size(); ++i) {
|
||||
auto hlo_snapshot = absl::make_unique<HloSnapshot>();
|
||||
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i];
|
||||
hlo_snapshots.push_back(std::move(hlo_snapshot));
|
||||
auto hlo_proto = absl::make_unique<HloProto>();
|
||||
*hlo_proto->mutable_hlo_module() = *module_protos[i];
|
||||
hlo_protos.push_back(std::move(hlo_proto));
|
||||
}
|
||||
|
||||
VLOG(1) << "Computations:";
|
||||
@ -383,7 +383,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
const auto& debug_opts = module_configs[i]->debug_options();
|
||||
if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) &&
|
||||
debug_opts.xla_dump_hlo_snapshots()) {
|
||||
executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i]));
|
||||
executables[i]->set_hlo_proto(std::move(hlo_protos[i]));
|
||||
}
|
||||
}
|
||||
|
||||
@ -692,14 +692,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
|
||||
executable_ptrs.push_back(executable.get());
|
||||
}
|
||||
|
||||
std::vector<HloSnapshot> snapshots;
|
||||
snapshots.resize(executable_ptrs.size());
|
||||
for (int i = 0; i < executable_ptrs.size(); i++) {
|
||||
if (executable_ptrs[i]->dumping_snapshot()) {
|
||||
*snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
|
||||
TF_ASSIGN_OR_RETURN(auto stream,
|
||||
execute_backend_->BorrowStream(
|
||||
all_executors[i][0]->device_ordinal()));
|
||||
TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
|
||||
execute_backend_->transfer_manager(),
|
||||
executable_ptrs[i]->hlo_snapshot()));
|
||||
&snapshots[i]));
|
||||
}
|
||||
}
|
||||
|
||||
@ -746,9 +749,8 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
|
||||
execute_backend_->BorrowStream(all_executors[i][0]));
|
||||
TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
|
||||
execute_backend_->transfer_manager(),
|
||||
executable->hlo_snapshot()));
|
||||
DumpHloSnapshotIfEnabled(executable->module(),
|
||||
*executable->hlo_snapshot());
|
||||
&snapshots[i]));
|
||||
DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -803,9 +805,9 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
const auto& debug_opts = module_config->debug_options();
|
||||
if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) &&
|
||||
debug_opts.xla_dump_hlo_snapshots()) {
|
||||
auto hlo_snapshot = absl::make_unique<HloSnapshot>();
|
||||
*hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto;
|
||||
executable->set_hlo_snapshot(std::move(hlo_snapshot));
|
||||
auto hlo_proto = absl::make_unique<HloProto>();
|
||||
*hlo_proto->mutable_hlo_module() = module_proto;
|
||||
executable->set_hlo_proto(std::move(hlo_proto));
|
||||
}
|
||||
|
||||
return std::move(executable);
|
||||
@ -891,12 +893,13 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
|
||||
TF_ASSIGN_OR_RETURN(auto stream,
|
||||
execute_backend_->BorrowStream(
|
||||
execute_backend_->default_stream_executor()));
|
||||
HloSnapshot snapshot;
|
||||
if (executable->dumping_snapshot()) {
|
||||
executable->hlo_snapshot()->set_execution_platform(
|
||||
execute_backend_->platform()->Name());
|
||||
TF_RETURN_IF_ERROR(RecordArguments(
|
||||
replicated_arguments.front(), stream.get(),
|
||||
execute_backend_->transfer_manager(), executable->hlo_snapshot()));
|
||||
*snapshot.mutable_hlo() = *executable->hlo_proto();
|
||||
snapshot.set_execution_platform(execute_backend_->platform()->Name());
|
||||
TF_RETURN_IF_ERROR(
|
||||
RecordArguments(replicated_arguments.front(), stream.get(),
|
||||
execute_backend_->transfer_manager(), &snapshot));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -913,8 +916,8 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
|
||||
allocation_tracker_.ResolveForReplica(result->output(), 0));
|
||||
TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
|
||||
execute_backend_->transfer_manager(),
|
||||
executable->hlo_snapshot()));
|
||||
DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot());
|
||||
&snapshot));
|
||||
DumpHloSnapshotIfEnabled(executable->module(), snapshot);
|
||||
}
|
||||
|
||||
VLOG(1) << "successfully completed 'execute' request";
|
||||
|
Loading…
Reference in New Issue
Block a user