[TF:XLA] Use RunAsync() in TF/XLA launch ops where applicable.

PiperOrigin-RevId: 261734859
This commit is contained in:
Peter Hawkins 2019-08-05 11:52:54 -07:00 committed by TensorFlower Gardener
parent a3672fefae
commit c4bdfed6e8

View File

@ -366,7 +366,12 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
Env* env = Env::Default(); Env* env = Env::Default();
auto start_time = env->NowMicros(); auto start_time = env->NowMicros();
auto run_result = executable->Run(launch_context.arguments(), run_options); xla::StatusOr<xla::ScopedShapedBuffer> run_result;
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
run_result = executable->Run(launch_context.arguments(), run_options);
} else {
run_result = executable->RunAsync(launch_context.arguments(), run_options);
}
OP_REQUIRES(ctx, run_result.ok(), run_result.status()); OP_REQUIRES(ctx, run_result.ok(), run_result.status());
auto elapsed = env->NowMicros() - start_time; auto elapsed = env->NowMicros() - start_time;
@ -550,8 +555,14 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
Env* env = Env::Default(); Env* env = Env::Default();
auto start_time = env->NowMicros(); auto start_time = env->NowMicros();
auto run_result = xla::StatusOr<xla::ScopedShapedBuffer> run_result;
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
run_result =
closure.executable()->Run(launch_context.arguments(), run_options); closure.executable()->Run(launch_context.arguments(), run_options);
} else {
run_result =
closure.executable()->RunAsync(launch_context.arguments(), run_options);
}
OP_REQUIRES(ctx, run_result.ok(), run_result.status()); OP_REQUIRES(ctx, run_result.ok(), run_result.status());
auto elapsed = env->NowMicros() - start_time; auto elapsed = env->NowMicros() - start_time;