Change LocalClient::Compile to support returning multiple executables (one per partition).

PiperOrigin-RevId: 292094485
Change-Id: Idaa4d14246478e5ec9b45d1d17d5610f35d35611
This commit is contained in:
Chris Jones 2020-01-29 01:00:10 -08:00 committed by TensorFlower Gardener
parent 6e4972c241
commit 9e944aa4fc
25 changed files with 225 additions and 136 deletions

View File

@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
*executable = std::move(executables[0]);
return Status::OK();
}

View File

@ -117,8 +117,10 @@ XlaJitCompiledCpuFunction::Compile(
// Compile the executable. The static_cast to the CpuExecutable subclass is
// necessary since the raw function and buffer assignments are only available
// there.
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
TF_ASSIGN_OR_RETURN(auto executables,
client->Compile(computation, arg_shapes, build_options));
TF_RET_CHECK(executables.size() == 1);
std::unique_ptr<xla::LocalExecutable> executable = std::move(executables[0]);
const xla::cpu::CpuExecutable* cpu_executable =
static_cast<xla::cpu::CpuExecutable*>(executable->executable());
XlaCompiledCpuFunction::RawFunction raw_function =

View File

@ -337,7 +337,7 @@ Backend* LocalClient::mutable_backend() {
return local_service_->mutable_backend();
}
StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options) {
@ -347,12 +347,20 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
VLOG(3) << "Set device ordinal to default value of: "
<< updated_options.device_ordinal();
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
local_service_->CompileExecutable(
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
local_service_->CompileExecutables(
computation, argument_layouts, updated_options));
return absl::make_unique<LocalExecutable>(std::move(executable),
local_service_->mutable_backend(),
updated_options);
std::vector<std::unique_ptr<LocalExecutable>> local_executables;
local_executables.reserve(executables.size());
for (auto& executable : executables) {
local_executables.push_back(absl::make_unique<LocalExecutable>(
std::move(executable), local_service_->mutable_backend(),
updated_options));
}
return std::move(local_executables);
}
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
@ -110,12 +111,13 @@ class LocalClient : public Client {
LocalClient(const LocalClient&) = delete;
void operator=(const LocalClient&) = delete;
// Build and return a LocalExecutable object. The executable is compiled using
// the given XlaComputation, argument layouts and options.
// Build and return LocalExecutable objects (one per partition, as specified
// by the build options). The executable is compiled using the given
// XlaComputation, argument layouts and options.
//
// The given ExecutableBuildOptions overrides any values from XLA_FLAGS
// environment variable.
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& options);

View File

@ -676,16 +676,27 @@ static std::shared_ptr<Device> LookupDevice(const PyLocalClient& client,
}
PyLocalExecutable::PyLocalExecutable(
std::shared_ptr<LocalExecutable> executable,
std::vector<std::unique_ptr<LocalExecutable>> executables,
DeviceAssignment device_assignment, std::shared_ptr<PyLocalClient> client)
: client_(std::move(client)),
executable_(std::move(executable)),
device_assignment_(
std::make_shared<DeviceAssignment>(device_assignment)) {
executables_.reserve(executables.size());
for (auto& executable : executables) {
executables_.emplace_back(std::move(executable));
}
// This must go after `executables_` is initialized.
VLOG(1) << "PyLocalExecutable " << name() << " device_assignment:\n"
<< device_assignment_->ToString();
const int num_replicas = device_assignment_->replica_count();
const int num_partitions = device_assignment_->computation_count();
CHECK_EQ(num_partitions, executables_.size())
<< "Number of executables " << executables_.size()
<< " did not match number of partitions " << num_partitions;
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment_)(replica, partition);
@ -704,7 +715,7 @@ PyLocalExecutable::PyLocalExecutable(
}
const std::string& PyLocalExecutable::name() const {
Executable* executable = executable_->executable();
Executable* executable = executables_[0]->executable();
if (executable->has_module()) {
return executable->module().name();
} else {
@ -779,7 +790,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
device_state->compute_semaphore().ScopedAcquire(1));
StatusOr<ScopedShapedBuffer> result_buffer_or_status =
executable_->RunAsync(argument_buffer_ptrs, options);
executables_[partition]->RunAsync(argument_buffer_ptrs, options);
VLOG(1) << "Replica " << replica << " partition " << partition
<< " completed; ok=" << result_buffer_or_status.ok();
@ -809,7 +820,8 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
device_state->ThenRelease(
device_state->compute_stream(),
std::make_tuple(executable_, compute_reservation, device_assignment_));
std::make_tuple(executables_[partition], compute_reservation,
device_assignment_));
return absl::make_unique<PyLocalBuffer>(
result_buffer.on_host_shape(), result_buffer.on_device_shape(),
std::move(out_buffer), client_, device);
@ -1081,13 +1093,14 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
TF_RETURN_IF_ERROR(assign_layouts(&result_layout));
options.set_result_layout(result_layout);
TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalExecutable> local_executable,
client->client()->Compile(
computation, argument_layout_pointers, options));
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<LocalExecutable>> local_executables,
client->client()->Compile(computation, argument_layout_pointers,
options));
return absl::make_unique<PyLocalExecutable>(
std::shared_ptr<LocalExecutable>(std::move(local_executable)),
std::move(*device_assignment), std::move(client));
return absl::make_unique<PyLocalExecutable>(std::move(local_executables),
std::move(*device_assignment),
std::move(client));
}
} // namespace xla

View File

@ -283,7 +283,8 @@ class PyLocalBuffer {
};
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps an XLA LocalExecutable.
// device-allocated literals. Wraps one or more XLA LocalExecutables (one per
// partition, as specified by the build options).
class PyLocalExecutable {
public:
// Compiles a computation to an executable.
@ -304,20 +305,24 @@ class PyLocalExecutable {
std::shared_ptr<PyLocalClient> client,
absl::optional<DeviceAssignment> device_assignment);
PyLocalExecutable(std::shared_ptr<LocalExecutable> executable,
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
DeviceAssignment device_assignment,
std::shared_ptr<PyLocalClient> client);
int num_replicas() const {
return executable_->build_options().num_replicas();
return executables_[0]->build_options().num_replicas();
}
int num_partitions() const {
return executable_->build_options().num_partitions();
return executables_[0]->build_options().num_partitions();
}
int64 SizeOfGeneratedCodeInBytes() const {
return executable_->executable()->SizeOfGeneratedCodeInBytes();
int64 size = 0;
for (auto& executable : executables_) {
size += executable->executable()->SizeOfGeneratedCodeInBytes();
}
return size;
}
const DeviceAssignment& device_assignment() const {
@ -346,7 +351,7 @@ class PyLocalExecutable {
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevices(
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles);
void Delete() { executable_ = nullptr; }
void Delete() { executables_.clear(); }
const string& name() const;
@ -359,7 +364,8 @@ class PyLocalExecutable {
// asynchronous execution, the process being executed can outlive the
// executable itself.
std::shared_ptr<PyLocalClient> const client_;
std::shared_ptr<LocalExecutable> executable_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// The replica and partition indices of device_assignment_ to be run by this

View File

@ -119,7 +119,8 @@ ExecutionOptions CreateExecutionOptions(
} // namespace
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
StatusOr<std::vector<std::unique_ptr<Executable>>>
LocalService::CompileExecutables(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options) {
@ -178,9 +179,29 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
se::StreamExecutor * executor,
execute_backend_->stream_executor(build_options.device_ordinal()));
return BuildExecutable(proto, std::move(module_config),
execute_backend_.get(), executor,
build_options.device_allocator());
// TODO(cjfj): Investigate why there are a couple of test failures when the
// single partition computations are built using `BuildExecutables`, fix it,
// and remove this special case (provided the performance if similar).
if (build_options.num_partitions() == 1) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
executor, build_options.device_allocator()));
std::vector<std::unique_ptr<Executable>> executables;
executables.push_back(std::move(executable));
return executables;
} else {
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
module_configs.push_back(std::move(module_config));
// BuildExecutables uses the executors length to determine the number of
// cores per module, but otherwise only uses the first executor.
std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
executor);
return BuildExecutables({&proto}, std::move(module_configs),
execute_backend_.get(), {executors},
build_options.device_allocator());
}
}
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
@ -41,12 +42,12 @@ class LocalService : public Service {
static StatusOr<std::unique_ptr<LocalService>> NewService(
const ServiceOptions& options);
// Builds an Executable with the given XlaComputation, argument layouts and
// Builds Executables with the given XlaComputation, argument layouts and
// options. If result_layout is non-null, then the executable is compiled to
// produce a result of the given layout. If device_allocator is non-null,
// then the compiler may use it to allocate temp space on the device. The
// compiler is responsible for freeing any memory it allocates this way.
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
StatusOr<std::vector<std::unique_ptr<Executable>>> CompileExecutables(
const XlaComputation& computation,
const absl::Span<const Shape* const> argument_layouts,
const ExecutableBuildOptions& build_options);

View File

@ -452,7 +452,10 @@ xla_test(
name = "while_test",
srcs = ["while_test.cc"],
deps = [
":client_library_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -464,9 +467,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@ -480,7 +480,9 @@ xla_test(
"interpreter",
],
deps = [
":client_library_test_base",
":test_macros_header",
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@ -489,8 +491,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
@ -950,7 +950,12 @@ xla_test(
"optonly",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
@ -959,11 +964,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -984,7 +984,12 @@ xla_test(
"optonly",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
@ -993,11 +998,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1048,7 +1048,12 @@ xla_test(
"optonly",
],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:reference_util",
@ -1057,11 +1062,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -1383,7 +1383,10 @@ xla_test(
timeout = "moderate",
srcs = ["dynamic_ops_test.cc"],
deps = [
":client_library_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:test_helpers",
@ -1395,9 +1398,6 @@ xla_test(
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@ -2205,7 +2205,11 @@ xla_test(
name = "cpu_gpu_fusion_test",
srcs = ["cpu_gpu_fusion_test.cc"],
deps = [
":client_library_test_base",
":hlo_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -2214,10 +2218,6 @@ xla_test(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@ -2295,7 +2295,11 @@ xla_test(
shard_count = 30,
tags = ["optonly"],
deps = [
":literal_test_util",
":local_client_test_base",
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@ -2304,16 +2308,12 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:local_client_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@ -2526,13 +2526,13 @@ tf_cc_test(
srcs = ["multiple_devices_on_host_test.cc"],
args = ["--xla_force_host_platform_device_count=4"],
deps = [
":xla_internal_test_main", # fixdeps: keep
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/synchronization",

View File

@ -882,13 +882,14 @@ void BM_ParallelFusion(int num_iters) {
.ConsumeValueOrDie();
// Build executable.
std::unique_ptr<LocalExecutable> executable =
auto executables =
client
->Compile(computation,
{&buffer0.on_host_shape(), &buffer1.on_host_shape(),
&buffer2.on_host_shape()},
ExecutableBuildOptions())
.ConsumeValueOrDie();
auto executable = std::move(executables[0]);
se::Stream stream(executors[device_ordinal]);
stream.Init();

View File

@ -1672,11 +1672,10 @@ void DOT_ReorderContracting(int num_iters) {
client->LiteralToShapedBuffer(input_literal, device_ordinal)
.ConsumeValueOrDie();
std::unique_ptr<LocalExecutable> executable =
client
->Compile(computation, {&buffer0.on_host_shape()},
ExecutableBuildOptions())
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
ExecutableBuildOptions()));
auto executable = std::move(executables[0]);
se::Stream stream(executors[device_ordinal]);
stream.Init();

View File

@ -779,9 +779,10 @@ void BM_DynamicSlice(int num_iters) {
DynamicSlice(input, start_indices, {1, 1, 1, 1});
auto computation = builder.Build().ConsumeValueOrDie();
std::unique_ptr<LocalExecutable> executable =
client->Compile(computation, host_shapes, ExecutableBuildOptions())
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
client->Compile(computation, host_shapes, ExecutableBuildOptions()));
auto executable = std::move(executables[0]);
// Run some warm-up executions.
ExecutableRunOptions options;

View File

@ -242,7 +242,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
[&](const Literal* input_literal) { return &input_literal->shape(); });
TF_ASSIGN_OR_RETURN(
auto executable,
auto executables,
client_->Compile(computation, input_shapes, build_opts));
std::vector<ScopedShapedBuffer> input_buffers;
@ -264,7 +264,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
run_opts.set_intra_op_thread_pool(
client_->backend().eigen_intra_op_thread_pool_device());
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
executable->Run(input_buffer_pointers, run_opts));
executables[0]->Run(input_buffer_pointers, run_opts));
TF_ASSIGN_OR_RETURN(Literal result_literal,
client_->ShapedBufferToLiteral(result));

View File

@ -46,12 +46,13 @@ TEST_F(HloMetadataTest, MetadataPropagation) {
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> executable,
auto executables,
local_client_->Compile(builder.Build().ValueOrDie(),
{&argument_layout, &argument_layout},
ExecutableBuildOptions()));
auto instruction = executable->executable()
auto instruction = executables[0]
->executable()
->module()
.entry_computation()
->root_instruction();
@ -67,15 +68,14 @@ TEST_F(HloMetadataTest, MetadataClearing) {
BuildAddComputation(&builder);
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
auto executable_status = local_client_->Compile(
builder.Build().ValueOrDie(), {&argument_layout, &argument_layout},
ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
local_client_->Compile(builder.Build().ValueOrDie(),
{&argument_layout, &argument_layout},
ExecutableBuildOptions()));
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
auto instruction = executable->executable()
auto instruction = executables[0]
->executable()
->module()
.entry_computation()
->root_instruction();

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@ -759,17 +760,17 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
Shape argument_layout =
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
auto executable_status =
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
ExecutableBuildOptions()));
EXPECT_EQ(1, executables.size());
auto x_array =
LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
executables[0]
->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
ASSERT_IS_OK(local_client_->mutable_backend()
->BorrowStream(0)
@ -780,6 +781,31 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
{2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
if (local_client_->device_count() < 2) {
GTEST_SKIP_("requires two devices");
}
XlaBuilder builder(TestName());
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
auto z = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
auto r = Add(x, y);
builder.SetSharding(sharding_builder::AssignDevice(1));
Add(r, z);
builder.ClearSharding();
Shape argument_layout =
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
ExecutableBuildOptions build_options;
build_options.set_num_partitions(2);
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
build_options));
EXPECT_EQ(2, executables.size());
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
// Test copying Literals to the device as ShapedBuffers, then copying them
// back again to Literals.
@ -928,11 +954,10 @@ void BM_LocalClientOverhead(int num_iters) {
const int kWarmups = 2;
auto executable_status = client->Compile(
computation, {&buffer.on_host_shape()}, ExecutableBuildOptions());
ASSERT_IS_OK(executable_status);
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables, client->Compile(computation, {&buffer.on_host_shape()},
ExecutableBuildOptions()));
std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
ExecutableRunOptions run_options;
run_options.set_allocator(&allocator).set_stream(stream.get());

View File

@ -194,9 +194,10 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
argument_layouts[i] = &arguments[i]->on_host_shape();
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<LocalExecutable> executable,
auto executables,
local_client_->Compile(computation, argument_layouts, build_options));
TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
TF_RET_CHECK(executables.size() == 1);
TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
auto device_ordinal =
build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();

View File

@ -65,8 +65,9 @@ void TestWithDeviceCount(const int device_count) {
TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> executable,
auto executables,
client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
std::vector<tensorflow::Thread*> threads;
absl::Mutex results_mutex;
std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;

View File

@ -47,12 +47,12 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
computation_status = builder.Build();
TF_ASSERT_OK(computation_status.status());
auto executable_status = local_client_->Compile(
computation_status.ValueOrDie(), {&pair_float, &single_float},
ExecutableBuildOptions());
TF_ASSERT_OK(executable_status.status());
HloModule& module = const_cast<HloModule&>(
executable_status.ValueOrDie()->executable()->module());
TF_ASSERT_OK_AND_ASSIGN(
auto executables, local_client_->Compile(computation_status.ValueOrDie(),
{&pair_float, &single_float},
ExecutableBuildOptions()));
HloModule& module =
const_cast<HloModule&>(executables[0]->executable()->module());
TF_ASSERT_OK(MakeFakeArguments(&module).status());
}

View File

@ -1314,9 +1314,10 @@ void BM_WhileLoop(int num_iters) {
While(condition, body, init);
auto computation = builder.Build().ConsumeValueOrDie();
std::unique_ptr<LocalExecutable> executable =
client->Compile(computation, {}, ExecutableBuildOptions())
.ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto executables,
client->Compile(computation, {}, ExecutableBuildOptions()));
auto executable = std::move(executables[0]);
// Run some warm-up executions.
ExecutableRunOptions options;

View File

@ -158,11 +158,11 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
ExecutableBuildOptions build_options;
build_options.mutable_debug_options()->set_xla_hlo_profile(true);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
auto local_executables,
client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
build_options));
Executable* executable = local_executable->executable();
Executable* executable = local_executables[0]->executable();
HloExecutionProfile hlo_execution_profile(
&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());

View File

@ -85,10 +85,11 @@ void RealMain(absl::Span<char* const> args) {
ExecutableBuildOptions build_options;
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
auto executables =
local_service->CompileExecutables(computation, layouts, build_options)
.ConsumeValueOrDie();
CHECK_EQ(executables.size(), 1);
const HloModule& module = executables[0]->module();
OperationDumper dumper(arg);
for (auto* computation : module.computations()) {

View File

@ -62,10 +62,11 @@ void RealMain(absl::Span<char* const> args, bool compile) {
ExecutableBuildOptions build_options;
build_options.set_device_ordinal(0);
build_options.set_result_layout(program_shape->result());
StatusOr<std::unique_ptr<Executable>> executable =
local_service->CompileExecutable(computation, layouts, build_options);
const HloModule& module = executable.ValueOrDie()->module();
auto executables =
local_service->CompileExecutables(computation, layouts, build_options)
.ConsumeValueOrDie();
CHECK_EQ(executables.size(), 1);
const HloModule& module = executables[0]->module();
fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
local_service->backend().platform()->Name().c_str(),

View File

@ -125,7 +125,11 @@ StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
}
ExecutableBuildOptions exec_build_options;
*exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags();
return client->Compile(computation, argument_layout_ptrs, exec_build_options);
TF_ASSIGN_OR_RETURN(
auto executables,
client->Compile(computation, argument_layout_ptrs, exec_build_options));
TF_RET_CHECK(executables.size() == 1);
return std::move(executables[0]);
}
absl::optional<Shape> GetXfeedShape(bool is_infeed,

View File

@ -128,12 +128,11 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
}
VLOG(1) << "Building executable";
auto compile_result =
client->Compile(computation, argument_layout_ptrs, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*program = std::move(compile_result.ValueOrDie());
TF_ASSIGN_OR_RETURN(
auto executables,
client->Compile(computation, argument_layout_ptrs, build_options));
TF_RET_CHECK(executables.size() == 1);
*program = std::move(executables[0]);
return Status::OK();
}

View File

@ -285,9 +285,12 @@ xla::ProgramShape XlaCompiledProgramShape(
for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) {
parameters_shapes.push_back(&input_program_shape.parameters(i));
}
auto local_executable =
std::vector<std::unique_ptr<xla::LocalExecutable>> local_executables =
client->Compile(computation, parameters_shapes, exec_options)
.ValueOrDie();
.ConsumeValueOrDie();
EXPECT_EQ(local_executables.size(), 1);
std::unique_ptr<xla::LocalExecutable> local_executable =
std::move(local_executables[0]);
return local_executable->executable()
->module()
.entry_computation()