Change LocalClient::Compile to support returning multiple executables (one per partition).

PiperOrigin-RevId: 291343358
Change-Id: I0550040ddbb67e78e9e4078185e0af6b11b96e35
This commit is contained in:
A. Unique TensorFlower 2020-01-24 03:32:24 -08:00 committed by TensorFlower Gardener
parent c42a05f658
commit a3edaa7235
25 changed files with 136 additions and 222 deletions

View File

@ -163,11 +163,12 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
*executable = std::move(executables[0]);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
return Status::OK();
}

View File

@ -117,10 +117,8 @@ XlaJitCompiledCpuFunction::Compile(
// Compile the executable. The static_cast to the CpuExecutable subclass is
// necessary since the raw function and buffer assignments are only available
// there.
TF_ASSIGN_OR_RETURN(auto executables,
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
client->Compile(computation, arg_shapes, build_options));
TF_RET_CHECK(executables.size() == 1);
std::unique_ptr<xla::LocalExecutable> executable = std::move(executables[0]);
const xla::cpu::CpuExecutable* cpu_executable =
static_cast<xla::cpu::CpuExecutable*>(executable->executable());
XlaCompiledCpuFunction::RawFunction raw_function =

View File

@ -337,7 +337,7 @@ Backend* LocalClient::mutable_backend() {
return local_service_->mutable_backend();
}
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options) {
@ -347,20 +347,12 @@ StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
VLOG(3) << "Set device ordinal to default value of: "
<< updated_options.device_ordinal();
}
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
local_service_->CompileExecutables(
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
local_service_->CompileExecutable(
computation, argument_layouts, updated_options));
std::vector<std::unique_ptr<LocalExecutable>> local_executables;
local_executables.reserve(executables.size());
for (auto& executable : executables) {
local_executables.push_back(absl::make_unique<LocalExecutable>(
std::move(executable), local_service_->mutable_backend(),
updated_options));
}
return std::move(local_executables);
return absl::make_unique<LocalExecutable>(std::move(executable),
local_service_->mutable_backend(),
updated_options);
}
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
@ -111,13 +110,12 @@ class LocalClient : public Client {
LocalClient(const LocalClient&) = delete;
void operator=(const LocalClient&) = delete;
// Build and return LocalExecutable objects (one per partition, as specified
// by the build options). The executable is compiled using the given
// XlaComputation, argument layouts and options.
// Build and return a LocalExecutable object. The executable is compiled using
// the given XlaComputation, argument layouts and options.
//
// The given ExecutableBuildOptions overrides any values from XLA_FLAGS
// environment variable.
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options);

View File

@ -676,27 +676,16 @@ static std::shared_ptr<Device> LookupDevice(const PyLocalClient& client,
}
PyLocalExecutable::PyLocalExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
std::shared_ptr<LocalExecutable> executable,
DeviceAssignment device_assignment, std::shared_ptr<PyLocalClient> client)
: client_(std::move(client)),
executable_(std::move(executable)),
device_assignment_(
std::make_shared<DeviceAssignment>(device_assignment)) {
executables_.reserve(executables.size());
for (auto& executable : executables) {
executables_.emplace_back(std::move(executable));
}
// This must go after `executables_` is initialized.
VLOG(1) << "PyLocalExecutable " << name() << " device_assignment:\n"
<< device_assignment_->ToString();
const int num_replicas = device_assignment_->replica_count();
const int num_partitions = device_assignment_->computation_count();
CHECK_EQ(num_partitions, executables_.size())
<< "Number of executables " << executables_.size()
<< " did not match number of partitions " << num_partitions;
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment_)(replica, partition);
@ -715,7 +704,7 @@ PyLocalExecutable::PyLocalExecutable(
}
const std::string& PyLocalExecutable::name() const {
Executable* executable = executables_[0]->executable();
Executable* executable = executable_->executable();
if (executable->has_module()) {
return executable->module().name();
} else {
@ -788,7 +777,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
options.set_run_id(run_id);
StatusOr<ScopedShapedBuffer> result_buffer_or_status =
executables_[partition]->RunAsync(argument_buffer_ptrs, options);
executable_->RunAsync(argument_buffer_ptrs, options);
VLOG(1) << "Replica " << replica << " partition " << partition
<< " completed; ok=" << result_buffer_or_status.ok();
@ -818,8 +807,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
device_state->ThenRelease(
device_state->compute_stream(),
std::make_tuple(executables_[partition], compute_reservation,
device_assignment_));
std::make_tuple(executable_, compute_reservation, device_assignment_));
return absl::make_unique<PyLocalBuffer>(
result_buffer.on_host_shape(), result_buffer.on_device_shape(),
std::move(out_buffer), client_, device);
@ -1044,14 +1032,13 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
TF_RETURN_IF_ERROR(assign_layouts(&result_layout));
options.set_result_layout(result_layout);
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<LocalExecutable>> local_executables,
client->client()->Compile(computation, argument_layout_pointers,
options));
TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalExecutable> local_executable,
client->client()->Compile(
computation, argument_layout_pointers, options));
return absl::make_unique<PyLocalExecutable>(std::move(local_executables),
std::move(*device_assignment),
std::move(client));
return absl::make_unique<PyLocalExecutable>(
std::shared_ptr<LocalExecutable>(std::move(local_executable)),
std::move(*device_assignment), std::move(client));
}
} // namespace xla

View File

@ -283,8 +283,7 @@ class PyLocalBuffer {
};
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps one or more XLA LocalExecutables (one per
// partition, as specified by the build options).
// device-allocated literals. Wraps an XLA LocalExecutable.
class PyLocalExecutable {
public:
// Compiles a computation to an executable.
@ -295,24 +294,20 @@ class PyLocalExecutable {
std::shared_ptr<PyLocalClient> client,
absl::optional<DeviceAssignment> device_assignment);
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
PyLocalExecutable(std::shared_ptr<LocalExecutable> executable,
DeviceAssignment device_assignment,
std::shared_ptr<PyLocalClient> client);
int num_replicas() const {
return executables_[0]->build_options().num_replicas();
return executable_->build_options().num_replicas();
}
int num_partitions() const {
return executables_[0]->build_options().num_partitions();
return executable_->build_options().num_partitions();
}
int64 SizeOfGeneratedCodeInBytes() const {
int64 size = 0;
for (auto& executable : executables_) {
size += executable->executable()->SizeOfGeneratedCodeInBytes();
}
return size;
return executable_->executable()->SizeOfGeneratedCodeInBytes();
}
const DeviceAssignment& device_assignment() const {
@ -341,7 +336,7 @@ class PyLocalExecutable {
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevices(
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles);
void Delete() { executables_.clear(); }
void Delete() { executable_ = nullptr; }
const string& name() const;
@ -354,8 +349,7 @@ class PyLocalExecutable {
// asynchronous execution, the process being executed can outlive the
// executable itself.
std::shared_ptr<PyLocalClient> const client_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
std::shared_ptr<LocalExecutable> executable_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// The replica and partition indices of device_assignment_ to be run by this

View File

@ -119,8 +119,7 @@ ExecutionOptions CreateExecutionOptions(
} // namespace
StatusOr<std::vector<std::unique_ptr<Executable>>>
LocalService::CompileExecutables(
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options) {
@ -179,29 +178,9 @@ LocalService::CompileExecutables(
se::StreamExecutor * executor,
execute_backend_->stream_executor(build_options.device_ordinal()));
// TODO(cjfj): Investigate why there are a couple of test failures when the
// single partition computations are built using `BuildExecutables`, fix it,
// and remove this special case (provided the performance if similar).
if (build_options.num_partitions() == 1) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
executor, build_options.device_allocator()));
std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable));
return executables;
} else {
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
module_configs.push_back(std::move(module_config));
// BuildExecutables uses the executors length to determine the number of
// cores per module, but otherwise only uses the first executor.
std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
executor);
return BuildExecutables({&proto}, std::move(module_configs),
execute_backend_.get(), {executors},
build_options.device_allocator());
}
return BuildExecutable(proto, std::move(module_config),
execute_backend_.get(), executor,
build_options.device_allocator());
}
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
@ -42,12 +41,12 @@ class LocalService : public Service {
static StatusOr<std::unique_ptr<LocalService>> NewService(
const ServiceOptions& options);
// Builds Executables with the given XlaComputation, argument layouts and
// Builds an Executable with the given XlaComputation, argument layouts and
// options. If result_layout is non-null, then the executable is compiled to
// produce a result of the given layout. If device_allocator is non-null,
// then the compiler may use it to allocate temp space on the device. The
// compiler is responsible for freeing any memory it allocates this way.
StatusOr<std::vector<std::unique_ptr<Executable>>> CompileExecutables(
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options);

View File

@ -452,10 +452,7 @@ xla_test(
name = "while_test",
srcs = ["while_test.cc"],
deps = [
":client_library_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -467,6 +464,9 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@ -480,9 +480,7 @@ xla_test(
"interpreter",
],
deps = [
":client_library_test_base",
":test_macros_header",
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@ -491,6 +489,8 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
@ -950,12 +950,7 @@ xla_test(
"optonly",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
@ -964,6 +959,11 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -984,12 +984,7 @@ xla_test(
"optonly",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
@ -998,6 +993,11 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1048,12 +1048,7 @@ xla_test(
"optonly",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
@ -1062,6 +1057,11 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1383,10 +1383,7 @@ xla_test(
timeout = "moderate",
srcs = ["dynamic_ops_test.cc"],
deps = [
":client_library_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:test_helpers",
@ -1398,6 +1395,9 @@ xla_test(
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@ -2203,11 +2203,7 @@ xla_test(
name = "cpu_gpu_fusion_test",
srcs = ["cpu_gpu_fusion_test.cc"],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -2216,6 +2212,10 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -2293,11 +2293,7 @@ xla_test(
shard_count = 30,
tags = ["optonly"],
deps = [
":literal_test_util",
":local_client_test_base",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@ -2306,12 +2302,16 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:local_client_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@ -2524,13 +2524,13 @@ tf_cc_test(
srcs = ["multiple_devices_on_host_test.cc"],
args = ["--xla_force_host_platform_device_count=4"],
deps = [
":xla_internal_test_main", # fixdeps: keep
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/synchronization",

View File

@ -882,14 +882,13 @@ void BM_ParallelFusion(int num_iters) {
.ConsumeValueOrDie();
// Build executable.
auto executables =
std::unique_ptr<LocalExecutable> executable =
client
->Compile(computation,
{&buffer0.on_host_shape(), &buffer1.on_host_shape(),
&buffer2.on_host_shape()},
ExecutableBuildOptions())
.ConsumeValueOrDie();
auto executable = std::move(executables[0]);
se::Stream stream(executors[device_ordinal]);
stream.Init();

View File

@ -1672,10 +1672,11 @@ void DOT_ReorderContracting(int num_iters) {
client->LiteralToShapedBuffer(input_literal, device_ordinal)
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
ExecutableBuildOptions()));
auto executable = std::move(executables[0]);
std::unique_ptr<LocalExecutable> executable =
client
->Compile(computation, {&buffer0.on_host_shape()},
ExecutableBuildOptions())
.ConsumeValueOrDie();
se::Stream stream(executors[device_ordinal]);
stream.Init();

View File

@ -779,10 +779,9 @@ void BM_DynamicSlice(int num_iters) {
DynamicSlice(input, start_indices, {1, 1, 1, 1});
auto computation = builder.Build().ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
client->Compile(computation, host_shapes, ExecutableBuildOptions()));
auto executable = std::move(executables[0]);
std::unique_ptr<LocalExecutable> executable =
client->Compile(computation, host_shapes, ExecutableBuildOptions())
.ConsumeValueOrDie();
// Run some warm-up executions.
ExecutableRunOptions options;

View File

@ -242,7 +242,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
[&](const Literal* input_literal) { return &input_literal->shape(); });
TF_ASSIGN_OR_RETURN(
auto executables,
auto executable,
client_->Compile(computation, input_shapes, build_opts));
std::vector<ScopedShapedBuffer> input_buffers;
@ -264,7 +264,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
run_opts.set_intra_op_thread_pool(
client_->backend().eigen_intra_op_thread_pool_device());
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
executables[0]->Run(input_buffer_pointers, run_opts));
executable->Run(input_buffer_pointers, run_opts));
TF_ASSIGN_OR_RETURN(Literal result_literal,
client_->ShapedBufferToLiteral(result));

View File

@ -46,13 +46,12 @@ TEST_F(HloMetadataTest, MetadataPropagation) {
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
std::unique_ptr<LocalExecutable> executable,
local_client_->Compile(builder.Build().ValueOrDie(),
{&argument_layout, &argument_layout},
ExecutableBuildOptions()));
auto instruction = executables[0]
->executable()
auto instruction = executable->executable()
->module()
.entry_computation()
->root_instruction();
@ -68,14 +67,15 @@ TEST_F(HloMetadataTest, MetadataClearing) {
BuildAddComputation(&builder);
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
local_client_->Compile(builder.Build().ValueOrDie(),
{&argument_layout, &argument_layout},
ExecutableBuildOptions()));
auto executable_status = local_client_->Compile(
builder.Build().ValueOrDie(), {&argument_layout, &argument_layout},
ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
auto instruction = executables[0]
->executable()
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
auto instruction = executable->executable()
->module()
.entry_computation()
->root_instruction();

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@ -760,17 +759,17 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
Shape argument_layout =
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
auto executable_status =
local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
ExecutableBuildOptions()));
EXPECT_EQ(1, executables.size());
ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
auto x_array =
LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executables[0]
->Run({&x_array}, DefaultExecutableRunOptions())
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
ASSERT_IS_OK(local_client_->mutable_backend()
->BorrowStream(0)
@ -781,31 +780,6 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
{2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
if (local_client_->device_count() < 2) {
GTEST_SKIP_("requires two devices");
}
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
auto z = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
auto r = Add(x, y);
builder.SetSharding(sharding_builder::AssignDevice(1));
Add(r, z);
builder.ClearSharding();
Shape argument_layout =
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
ExecutableBuildOptions build_options;
build_options.set_num_partitions(2);
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
build_options));
EXPECT_EQ(2, executables.size());
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
// Test copying Literals to the device as ShapedBuffers, then copying them
// back again to Literals.
@ -954,10 +928,11 @@ void BM_LocalClientOverhead(int num_iters) {
const int kWarmups = 2;
TF_ASSERT_OK_AND_ASSIGN(
auto executables, client->Compile(computation, {&buffer.on_host_shape()},
ExecutableBuildOptions()));
std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
auto executable_status = client->Compile(
computation, {&buffer.on_host_shape()}, ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
ExecutableRunOptions run_options;
run_options.set_allocator(&allocator).set_stream(stream.get());

View File

@ -194,10 +194,9 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
argument_layouts[i] = &arguments[i]->on_host_shape();
}
TF_ASSIGN_OR_RETURN(
auto executables,
std::unique_ptr<LocalExecutable> executable,
local_client_->Compile(computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
auto device_ordinal =
build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();

View File

@ -65,9 +65,8 @@ void TestWithDeviceCount(const int device_count) {
TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
std::unique_ptr<LocalExecutable> executable,
client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
std::vector<tensorflow::Thread*> threads;
absl::Mutex results_mutex;
std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;

View File

@ -47,12 +47,12 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
computation_status = builder.Build();
TF_ASSERT_OK(computation_status.status());
TF_ASSERT_OK_AND_ASSIGN(
auto executables, local_client_->Compile(computation_status.ValueOrDie(),
{&pair_float, &single_float},
ExecutableBuildOptions()));
HloModule& module =
const_cast<HloModule&>(executables[0]->executable()->module());
auto executable_status = local_client_->Compile(
computation_status.ValueOrDie(), {&pair_float, &single_float},
ExecutableBuildOptions());
TF_ASSERT_OK(executable_status.status());
HloModule& module = const_cast<HloModule&>(
executable_status.ValueOrDie()->executable()->module());
TF_ASSERT_OK(MakeFakeArguments(&module).status());
}

View File

@ -1314,10 +1314,9 @@ void BM_WhileLoop(int num_iters) {
While(condition, body, init);
auto computation = builder.Build().ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
client->Compile(computation, {}, ExecutableBuildOptions()));
auto executable = std::move(executables[0]);
std::unique_ptr<LocalExecutable> executable =
client->Compile(computation, {}, ExecutableBuildOptions())
.ConsumeValueOrDie();
// Run some warm-up executions.
ExecutableRunOptions options;

View File

@ -158,11 +158,11 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
ExecutableBuildOptions build_options;
build_options.mutable_debug_options()->set_xla_hlo_profile(true);
TF_ASSERT_OK_AND_ASSIGN(
auto local_executables,
std::unique_ptr<LocalExecutable> local_executable,
client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
build_options));
Executable* executable = local_executables[0]->executable();
Executable* executable = local_executable->executable();
HloExecutionProfile hlo_execution_profile(
&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());

View File

@ -85,11 +85,10 @@ void RealMain(absl::Span<char* const> args) {
ExecutableBuildOptions build_options;
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
auto executables =
local_service->CompileExecutables(computation, layouts, build_options)
.ConsumeValueOrDie();
CHECK_EQ(executables.size(), 1);
const HloModule& module = executables[0]->module();
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
OperationDumper dumper(arg);
for (auto* computation : module.computations()) {

View File

@ -62,11 +62,10 @@ void RealMain(absl::Span<char* const> args, bool compile) {
ExecutableBuildOptions build_options;
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
auto executables =
local_service->CompileExecutables(computation, layouts, build_options)
.ConsumeValueOrDie();
CHECK_EQ(executables.size(), 1);
const HloModule& module = executables[0]->module();
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
local_service->backend().platform()->Name().c_str(),

View File

@ -125,11 +125,7 @@ StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
}
ExecutableBuildOptions exec_build_options;
*exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags();
TF_ASSIGN_OR_RETURN(
auto executables,
client->Compile(computation, argument_layout_ptrs, exec_build_options));
TF_RET_CHECK(executables.size() == 1);
return std::move(executables[0]);
return client->Compile(computation, argument_layout_ptrs, exec_build_options);
}
absl::optional<Shape> GetXfeedShape(bool is_infeed,

View File

@ -126,11 +126,12 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
}
VLOG(1) << "Building executable";
TF_ASSIGN_OR_RETURN(
auto executables,
client->Compile(computation, argument_layout_ptrs, build_options));
TF_RET_CHECK(executables.size() == 1);
*program = std::move(executables[0]);
auto compile_result =
client->Compile(computation, argument_layout_ptrs, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*program = std::move(compile_result.ValueOrDie());
return Status::OK();
}

View File

@ -286,8 +286,8 @@ xla::ProgramShape XlaCompiledProgramShape(
parameters_shapes.push_back(&input_program_shape.parameters(i));
}
auto local_executable =
std::move(client->Compile(computation, parameters_shapes, exec_options)
.ValueOrDie()[0]);
client->Compile(computation, parameters_shapes, exec_options)
.ValueOrDie();
return local_executable->executable()
->module()
.entry_computation()