diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 289004575d3..a998b91c89d 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -62,6 +62,7 @@ cc_library( deps = [ ":computation", ":global_data", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:service_interface", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 555e3df427a..1799bbd3480 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -198,7 +199,10 @@ StatusOr> Client::Execute( ExecutionProfile* execution_profile) { ExecuteRequest request; *request.mutable_computation() = computation.handle(); - if (execution_options != nullptr) { + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { @@ -299,7 +303,9 @@ StatusOr Client::ExecuteAsync( for (GlobalData* argument : arguments) { *request.add_arguments() = argument->handle(); } - if (execution_options != nullptr) { + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { *request.mutable_execution_options() = *execution_options; }