[XLA] Create default ExecutionOptions in Client::Execute* APIs when applicable.

Clients that call the XLA Client API directly (not through TF) and do not pass
execution_options explicitly will be bitten by bad default values of flags
without this change.

PiperOrigin-RevId: 161097315
This commit is contained in:
Eli Bendersky 2017-07-06 10:30:16 -07:00 committed by TensorFlower Gardener
parent 583497f167
commit ad0ba9e6ea
2 changed files with 9 additions and 2 deletions

View File

@ -62,6 +62,7 @@ cc_library(
deps = [
":computation",
":global_data",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:service_interface",
"//tensorflow/compiler/xla:status_macros",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@ -198,7 +199,10 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
ExecutionProfile* execution_profile) {
ExecuteRequest request;
*request.mutable_computation() = computation.handle();
if (execution_options != nullptr) {
if (execution_options == nullptr) {
*request.mutable_execution_options() = CreateDefaultExecutionOptions();
} else {
*request.mutable_execution_options() = *execution_options;
}
for (GlobalData* argument : arguments) {
@ -299,7 +303,9 @@ StatusOr<ExecutionHandle> Client::ExecuteAsync(
for (GlobalData* argument : arguments) {
*request.add_arguments() = argument->handle();
}
if (execution_options != nullptr) {
if (execution_options == nullptr) {
*request.mutable_execution_options() = CreateDefaultExecutionOptions();
} else {
*request.mutable_execution_options() = *execution_options;
}