Change LocalClient::Compile
to support returning multiple executables (one per partition).
PiperOrigin-RevId: 291343358 Change-Id: I0550040ddbb67e78e9e4078185e0af6b11b96e35
This commit is contained in:
parent
c42a05f658
commit
a3edaa7235
@ -163,11 +163,12 @@ Status XlaCompilationCache::BuildExecutable(
|
||||
build_options.set_device_allocator(options.device_allocator);
|
||||
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executables,
|
||||
client_->Compile(*result.computation, argument_layouts, build_options));
|
||||
TF_RET_CHECK(executables.size() == 1);
|
||||
*executable = std::move(executables[0]);
|
||||
auto compile_result =
|
||||
client_->Compile(*result.computation, argument_layouts, build_options);
|
||||
if (!compile_result.ok()) {
|
||||
return compile_result.status();
|
||||
}
|
||||
*executable = std::move(compile_result.ValueOrDie());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -117,10 +117,8 @@ XlaJitCompiledCpuFunction::Compile(
|
||||
// Compile the executable. The static_cast to the CpuExecutable subclass is
|
||||
// necessary since the raw function and buffer assignments are only available
|
||||
// there.
|
||||
TF_ASSIGN_OR_RETURN(auto executables,
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
|
||||
client->Compile(computation, arg_shapes, build_options));
|
||||
TF_RET_CHECK(executables.size() == 1);
|
||||
std::unique_ptr<xla::LocalExecutable> executable = std::move(executables[0]);
|
||||
const xla::cpu::CpuExecutable* cpu_executable =
|
||||
static_cast<xla::cpu::CpuExecutable*>(executable->executable());
|
||||
XlaCompiledCpuFunction::RawFunction raw_function =
|
||||
|
@ -337,7 +337,7 @@ Backend* LocalClient::mutable_backend() {
|
||||
return local_service_->mutable_backend();
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
|
||||
StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
|
||||
const XlaComputation& computation,
|
||||
const absl::Span<const Shape* const> argument_layouts,
|
||||
const ExecutableBuildOptions& options) {
|
||||
@ -347,20 +347,12 @@ StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> LocalClient::Compile(
|
||||
VLOG(3) << "Set device ordinal to default value of: "
|
||||
<< updated_options.device_ordinal();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
|
||||
local_service_->CompileExecutables(
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
|
||||
local_service_->CompileExecutable(
|
||||
computation, argument_layouts, updated_options));
|
||||
|
||||
std::vector<std::unique_ptr<LocalExecutable>> local_executables;
|
||||
local_executables.reserve(executables.size());
|
||||
|
||||
for (auto& executable : executables) {
|
||||
local_executables.push_back(absl::make_unique<LocalExecutable>(
|
||||
std::move(executable), local_service_->mutable_backend(),
|
||||
updated_options));
|
||||
}
|
||||
|
||||
return std::move(local_executables);
|
||||
return absl::make_unique<LocalExecutable>(std::move(executable),
|
||||
local_service_->mutable_backend(),
|
||||
updated_options);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/client.h"
|
||||
@ -111,13 +110,12 @@ class LocalClient : public Client {
|
||||
LocalClient(const LocalClient&) = delete;
|
||||
void operator=(const LocalClient&) = delete;
|
||||
|
||||
// Build and return LocalExecutable objects (one per partition, as specified
|
||||
// by the build options). The executable is compiled using the given
|
||||
// XlaComputation, argument layouts and options.
|
||||
// Build and return a LocalExecutable object. The executable is compiled using
|
||||
// the given XlaComputation, argument layouts and options.
|
||||
//
|
||||
// The given ExecutableBuildOptions overrides any values from XLA_FLAGS
|
||||
// environment variable.
|
||||
StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
|
||||
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
|
||||
const XlaComputation& computation,
|
||||
const absl::Span<const Shape* const> argument_layouts,
|
||||
const ExecutableBuildOptions& options);
|
||||
|
@ -676,27 +676,16 @@ static std::shared_ptr<Device> LookupDevice(const PyLocalClient& client,
|
||||
}
|
||||
|
||||
PyLocalExecutable::PyLocalExecutable(
|
||||
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
std::shared_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment, std::shared_ptr<PyLocalClient> client)
|
||||
: client_(std::move(client)),
|
||||
executable_(std::move(executable)),
|
||||
device_assignment_(
|
||||
std::make_shared<DeviceAssignment>(device_assignment)) {
|
||||
executables_.reserve(executables.size());
|
||||
for (auto& executable : executables) {
|
||||
executables_.emplace_back(std::move(executable));
|
||||
}
|
||||
|
||||
// This must go after `executables_` is initialized.
|
||||
VLOG(1) << "PyLocalExecutable " << name() << " device_assignment:\n"
|
||||
<< device_assignment_->ToString();
|
||||
|
||||
const int num_replicas = device_assignment_->replica_count();
|
||||
const int num_partitions = device_assignment_->computation_count();
|
||||
|
||||
CHECK_EQ(num_partitions, executables_.size())
|
||||
<< "Number of executables " << executables_.size()
|
||||
<< " did not match number of partitions " << num_partitions;
|
||||
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
for (int partition = 0; partition < num_partitions; ++partition) {
|
||||
int device_id = (*device_assignment_)(replica, partition);
|
||||
@ -715,7 +704,7 @@ PyLocalExecutable::PyLocalExecutable(
|
||||
}
|
||||
|
||||
const std::string& PyLocalExecutable::name() const {
|
||||
Executable* executable = executables_[0]->executable();
|
||||
Executable* executable = executable_->executable();
|
||||
if (executable->has_module()) {
|
||||
return executable->module().name();
|
||||
} else {
|
||||
@ -788,7 +777,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
|
||||
options.set_run_id(run_id);
|
||||
|
||||
StatusOr<ScopedShapedBuffer> result_buffer_or_status =
|
||||
executables_[partition]->RunAsync(argument_buffer_ptrs, options);
|
||||
executable_->RunAsync(argument_buffer_ptrs, options);
|
||||
|
||||
VLOG(1) << "Replica " << replica << " partition " << partition
|
||||
<< " completed; ok=" << result_buffer_or_status.ok();
|
||||
@ -818,8 +807,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
|
||||
|
||||
device_state->ThenRelease(
|
||||
device_state->compute_stream(),
|
||||
std::make_tuple(executables_[partition], compute_reservation,
|
||||
device_assignment_));
|
||||
std::make_tuple(executable_, compute_reservation, device_assignment_));
|
||||
return absl::make_unique<PyLocalBuffer>(
|
||||
result_buffer.on_host_shape(), result_buffer.on_device_shape(),
|
||||
std::move(out_buffer), client_, device);
|
||||
@ -1044,14 +1032,13 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
TF_RETURN_IF_ERROR(assign_layouts(&result_layout));
|
||||
options.set_result_layout(result_layout);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::unique_ptr<LocalExecutable>> local_executables,
|
||||
client->client()->Compile(computation, argument_layout_pointers,
|
||||
options));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalExecutable> local_executable,
|
||||
client->client()->Compile(
|
||||
computation, argument_layout_pointers, options));
|
||||
|
||||
return absl::make_unique<PyLocalExecutable>(std::move(local_executables),
|
||||
std::move(*device_assignment),
|
||||
std::move(client));
|
||||
return absl::make_unique<PyLocalExecutable>(
|
||||
std::shared_ptr<LocalExecutable>(std::move(local_executable)),
|
||||
std::move(*device_assignment), std::move(client));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -283,8 +283,7 @@ class PyLocalBuffer {
|
||||
};
|
||||
|
||||
// Represents a compiled computation that can be executed given handles to
|
||||
// device-allocated literals. Wraps one or more XLA LocalExecutables (one per
|
||||
// partition, as specified by the build options).
|
||||
// device-allocated literals. Wraps an XLA LocalExecutable.
|
||||
class PyLocalExecutable {
|
||||
public:
|
||||
// Compiles a computation to an executable.
|
||||
@ -295,24 +294,20 @@ class PyLocalExecutable {
|
||||
std::shared_ptr<PyLocalClient> client,
|
||||
absl::optional<DeviceAssignment> device_assignment);
|
||||
|
||||
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
PyLocalExecutable(std::shared_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment,
|
||||
std::shared_ptr<PyLocalClient> client);
|
||||
|
||||
int num_replicas() const {
|
||||
return executables_[0]->build_options().num_replicas();
|
||||
return executable_->build_options().num_replicas();
|
||||
}
|
||||
|
||||
int num_partitions() const {
|
||||
return executables_[0]->build_options().num_partitions();
|
||||
return executable_->build_options().num_partitions();
|
||||
}
|
||||
|
||||
int64 SizeOfGeneratedCodeInBytes() const {
|
||||
int64 size = 0;
|
||||
for (auto& executable : executables_) {
|
||||
size += executable->executable()->SizeOfGeneratedCodeInBytes();
|
||||
}
|
||||
return size;
|
||||
return executable_->executable()->SizeOfGeneratedCodeInBytes();
|
||||
}
|
||||
|
||||
const DeviceAssignment& device_assignment() const {
|
||||
@ -341,7 +336,7 @@ class PyLocalExecutable {
|
||||
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevices(
|
||||
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles);
|
||||
|
||||
void Delete() { executables_.clear(); }
|
||||
void Delete() { executable_ = nullptr; }
|
||||
|
||||
const string& name() const;
|
||||
|
||||
@ -354,8 +349,7 @@ class PyLocalExecutable {
|
||||
// asynchronous execution, the process being executed can outlive the
|
||||
// executable itself.
|
||||
std::shared_ptr<PyLocalClient> const client_;
|
||||
// One executable per partition.
|
||||
std::vector<std::shared_ptr<LocalExecutable>> executables_;
|
||||
std::shared_ptr<LocalExecutable> executable_;
|
||||
std::shared_ptr<DeviceAssignment> device_assignment_;
|
||||
|
||||
// The replica and partition indices of device_assignment_ to be run by this
|
||||
|
@ -119,8 +119,7 @@ ExecutionOptions CreateExecutionOptions(
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>>
|
||||
LocalService::CompileExecutables(
|
||||
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
||||
const XlaComputation& computation,
|
||||
const absl::Span<const Shape* const> argument_layouts,
|
||||
const ExecutableBuildOptions& build_options) {
|
||||
@ -179,29 +178,9 @@ LocalService::CompileExecutables(
|
||||
se::StreamExecutor * executor,
|
||||
execute_backend_->stream_executor(build_options.device_ordinal()));
|
||||
|
||||
// TODO(cjfj): Investigate why there are a couple of test failures when the
|
||||
// single partition computations are built using `BuildExecutables`, fix it,
|
||||
// and remove this special case (provided the performance is similar).
|
||||
if (build_options.num_partitions() == 1) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
BuildExecutable(proto, std::move(module_config), execute_backend_.get(),
|
||||
executor, build_options.device_allocator()));
|
||||
std::vector<std::unique_ptr<Executable>> executables;
|
||||
executables.push_back(std::move(executable));
|
||||
return executables;
|
||||
} else {
|
||||
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
|
||||
module_configs.push_back(std::move(module_config));
|
||||
// BuildExecutables uses the executors length to determine the number of
|
||||
// cores per module, but otherwise only uses the first executor.
|
||||
std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
|
||||
executor);
|
||||
|
||||
return BuildExecutables({&proto}, std::move(module_configs),
|
||||
execute_backend_.get(), {executors},
|
||||
build_options.device_allocator());
|
||||
}
|
||||
return BuildExecutable(proto, std::move(module_config),
|
||||
execute_backend_.get(), executor,
|
||||
build_options.device_allocator());
|
||||
}
|
||||
|
||||
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
@ -42,12 +41,12 @@ class LocalService : public Service {
|
||||
static StatusOr<std::unique_ptr<LocalService>> NewService(
|
||||
const ServiceOptions& options);
|
||||
|
||||
// Builds Executables with the given XlaComputation, argument layouts and
|
||||
// Builds an Executable with the given XlaComputation, argument layouts and
|
||||
// options. If result_layout is non-null, then the executable is compiled to
|
||||
// produce a result of the given layout. If device_allocator is non-null,
|
||||
// then the compiler may use it to allocate temp space on the device. The
|
||||
// compiler is responsible for freeing any memory it allocates this way.
|
||||
StatusOr<std::vector<std::unique_ptr<Executable>>> CompileExecutables(
|
||||
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
|
||||
const XlaComputation& computation,
|
||||
const absl::Span<const Shape* const> argument_layouts,
|
||||
const ExecutableBuildOptions& build_options);
|
||||
|
@ -452,10 +452,7 @@ xla_test(
|
||||
name = "while_test",
|
||||
srcs = ["while_test.cc"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -467,6 +464,9 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
@ -480,9 +480,7 @@ xla_test(
|
||||
"interpreter",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -491,6 +489,8 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:stream_pool",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
@ -950,12 +950,7 @@ xla_test(
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
@ -964,6 +959,11 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -984,12 +984,7 @@ xla_test(
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
@ -998,6 +993,11 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1048,12 +1048,7 @@ xla_test(
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
@ -1062,6 +1057,11 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -1383,10 +1383,7 @@ xla_test(
|
||||
timeout = "moderate",
|
||||
srcs = ["dynamic_ops_test.cc"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
@ -1398,6 +1395,9 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:test",
|
||||
@ -2203,11 +2203,7 @@ xla_test(
|
||||
name = "cpu_gpu_fusion_test",
|
||||
srcs = ["cpu_gpu_fusion_test.cc"],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
":hlo_test_base",
|
||||
":literal_test_util",
|
||||
":test_macros_header",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:array2d",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -2216,6 +2212,10 @@ xla_test(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
@ -2293,11 +2293,7 @@ xla_test(
|
||||
shard_count = 30,
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":literal_test_util",
|
||||
":local_client_test_base",
|
||||
":test_macros_header",
|
||||
":test_utils",
|
||||
":xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -2306,12 +2302,16 @@ xla_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:sharding_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:local_service",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:local_client_test_base",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core:test",
|
||||
@ -2524,13 +2524,13 @@ tf_cc_test(
|
||||
srcs = ["multiple_devices_on_host_test.cc"],
|
||||
args = ["--xla_force_host_platform_device_count=4"],
|
||||
deps = [
|
||||
":xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
|
@ -882,14 +882,13 @@ void BM_ParallelFusion(int num_iters) {
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
// Build executable.
|
||||
auto executables =
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
client
|
||||
->Compile(computation,
|
||||
{&buffer0.on_host_shape(), &buffer1.on_host_shape(),
|
||||
&buffer2.on_host_shape()},
|
||||
ExecutableBuildOptions())
|
||||
.ConsumeValueOrDie();
|
||||
auto executable = std::move(executables[0]);
|
||||
|
||||
se::Stream stream(executors[device_ordinal]);
|
||||
stream.Init();
|
||||
|
@ -1672,10 +1672,11 @@ void DOT_ReorderContracting(int num_iters) {
|
||||
client->LiteralToShapedBuffer(input_literal, device_ordinal)
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
|
||||
ExecutableBuildOptions()));
|
||||
auto executable = std::move(executables[0]);
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
client
|
||||
->Compile(computation, {&buffer0.on_host_shape()},
|
||||
ExecutableBuildOptions())
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
se::Stream stream(executors[device_ordinal]);
|
||||
stream.Init();
|
||||
|
@ -779,10 +779,9 @@ void BM_DynamicSlice(int num_iters) {
|
||||
DynamicSlice(input, start_indices, {1, 1, 1, 1});
|
||||
auto computation = builder.Build().ConsumeValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
client->Compile(computation, host_shapes, ExecutableBuildOptions()));
|
||||
auto executable = std::move(executables[0]);
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
client->Compile(computation, host_shapes, ExecutableBuildOptions())
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
// Run some warm-up executions.
|
||||
ExecutableRunOptions options;
|
||||
|
@ -242,7 +242,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
[&](const Literal* input_literal) { return &input_literal->shape(); });
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executables,
|
||||
auto executable,
|
||||
client_->Compile(computation, input_shapes, build_opts));
|
||||
|
||||
std::vector<ScopedShapedBuffer> input_buffers;
|
||||
@ -264,7 +264,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase {
|
||||
run_opts.set_intra_op_thread_pool(
|
||||
client_->backend().eigen_intra_op_thread_pool_device());
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
||||
executables[0]->Run(input_buffer_pointers, run_opts));
|
||||
executable->Run(input_buffer_pointers, run_opts));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Literal result_literal,
|
||||
client_->ShapedBufferToLiteral(result));
|
||||
|
@ -46,13 +46,12 @@ TEST_F(HloMetadataTest, MetadataPropagation) {
|
||||
|
||||
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
std::unique_ptr<LocalExecutable> executable,
|
||||
local_client_->Compile(builder.Build().ValueOrDie(),
|
||||
{&argument_layout, &argument_layout},
|
||||
ExecutableBuildOptions()));
|
||||
|
||||
auto instruction = executables[0]
|
||||
->executable()
|
||||
auto instruction = executable->executable()
|
||||
->module()
|
||||
.entry_computation()
|
||||
->root_instruction();
|
||||
@ -68,14 +67,15 @@ TEST_F(HloMetadataTest, MetadataClearing) {
|
||||
BuildAddComputation(&builder);
|
||||
|
||||
Shape argument_layout = ShapeUtil::MakeShape(F32, {});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
local_client_->Compile(builder.Build().ValueOrDie(),
|
||||
{&argument_layout, &argument_layout},
|
||||
ExecutableBuildOptions()));
|
||||
auto executable_status = local_client_->Compile(
|
||||
builder.Build().ValueOrDie(), {&argument_layout, &argument_layout},
|
||||
ExecutableBuildOptions());
|
||||
ASSERT_IS_OK(executable_status);
|
||||
|
||||
auto instruction = executables[0]
|
||||
->executable()
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
executable_status.ConsumeValueOrDie();
|
||||
|
||||
auto instruction = executable->executable()
|
||||
->module()
|
||||
.entry_computation()
|
||||
->root_instruction();
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/sharding_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
@ -760,17 +759,17 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
|
||||
|
||||
Shape argument_layout =
|
||||
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
auto executable_status =
|
||||
local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
|
||||
ExecutableBuildOptions()));
|
||||
EXPECT_EQ(1, executables.size());
|
||||
ExecutableBuildOptions());
|
||||
ASSERT_IS_OK(executable_status);
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
executable_status.ConsumeValueOrDie();
|
||||
|
||||
auto x_array =
|
||||
LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
|
||||
ScopedShapedBuffer result =
|
||||
executables[0]
|
||||
->Run({&x_array}, DefaultExecutableRunOptions())
|
||||
executable->Run({&x_array}, DefaultExecutableRunOptions())
|
||||
.ConsumeValueOrDie();
|
||||
ASSERT_IS_OK(local_client_->mutable_backend()
|
||||
->BorrowStream(0)
|
||||
@ -781,31 +780,6 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
|
||||
{2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
|
||||
if (local_client_->device_count() < 2) {
|
||||
GTEST_SKIP_("requires two devices");
|
||||
}
|
||||
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
|
||||
auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
|
||||
auto z = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
|
||||
auto r = Add(x, y);
|
||||
builder.SetSharding(sharding_builder::AssignDevice(1));
|
||||
Add(r, z);
|
||||
builder.ClearSharding();
|
||||
|
||||
Shape argument_layout =
|
||||
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
|
||||
ExecutableBuildOptions build_options;
|
||||
build_options.set_num_partitions(2);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
|
||||
build_options));
|
||||
EXPECT_EQ(2, executables.size());
|
||||
}
|
||||
|
||||
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
|
||||
// Test copying Literals to the device as ShapedBuffers, then copying them
|
||||
// back again to Literals.
|
||||
@ -954,10 +928,11 @@ void BM_LocalClientOverhead(int num_iters) {
|
||||
|
||||
const int kWarmups = 2;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables, client->Compile(computation, {&buffer.on_host_shape()},
|
||||
ExecutableBuildOptions()));
|
||||
std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
|
||||
auto executable_status = client->Compile(
|
||||
computation, {&buffer.on_host_shape()}, ExecutableBuildOptions());
|
||||
ASSERT_IS_OK(executable_status);
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
executable_status.ConsumeValueOrDie();
|
||||
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_allocator(&allocator).set_stream(stream.get());
|
||||
|
@ -194,10 +194,9 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
|
||||
argument_layouts[i] = &arguments[i]->on_host_shape();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executables,
|
||||
std::unique_ptr<LocalExecutable> executable,
|
||||
local_client_->Compile(computation, argument_layouts, build_options));
|
||||
TF_RET_CHECK(executables.size() == 1);
|
||||
TF_ASSIGN_OR_RETURN(auto ret, executables[0]->Run(arguments, run_options));
|
||||
TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
|
||||
|
||||
auto device_ordinal =
|
||||
build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
|
||||
|
@ -65,9 +65,8 @@ void TestWithDeviceCount(const int device_count) {
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
std::unique_ptr<LocalExecutable> executable,
|
||||
client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
|
||||
std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
|
||||
std::vector<tensorflow::Thread*> threads;
|
||||
absl::Mutex results_mutex;
|
||||
std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;
|
||||
|
@ -47,12 +47,12 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
|
||||
computation_status = builder.Build();
|
||||
TF_ASSERT_OK(computation_status.status());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables, local_client_->Compile(computation_status.ValueOrDie(),
|
||||
{&pair_float, &single_float},
|
||||
ExecutableBuildOptions()));
|
||||
HloModule& module =
|
||||
const_cast<HloModule&>(executables[0]->executable()->module());
|
||||
auto executable_status = local_client_->Compile(
|
||||
computation_status.ValueOrDie(), {&pair_float, &single_float},
|
||||
ExecutableBuildOptions());
|
||||
TF_ASSERT_OK(executable_status.status());
|
||||
HloModule& module = const_cast<HloModule&>(
|
||||
executable_status.ValueOrDie()->executable()->module());
|
||||
TF_ASSERT_OK(MakeFakeArguments(&module).status());
|
||||
}
|
||||
|
||||
|
@ -1314,10 +1314,9 @@ void BM_WhileLoop(int num_iters) {
|
||||
While(condition, body, init);
|
||||
auto computation = builder.Build().ConsumeValueOrDie();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto executables,
|
||||
client->Compile(computation, {}, ExecutableBuildOptions()));
|
||||
auto executable = std::move(executables[0]);
|
||||
std::unique_ptr<LocalExecutable> executable =
|
||||
client->Compile(computation, {}, ExecutableBuildOptions())
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
// Run some warm-up executions.
|
||||
ExecutableRunOptions options;
|
||||
|
@ -158,11 +158,11 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
|
||||
ExecutableBuildOptions build_options;
|
||||
build_options.mutable_debug_options()->set_xla_hlo_profile(true);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto local_executables,
|
||||
std::unique_ptr<LocalExecutable> local_executable,
|
||||
client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
|
||||
build_options));
|
||||
|
||||
Executable* executable = local_executables[0]->executable();
|
||||
Executable* executable = local_executable->executable();
|
||||
HloExecutionProfile hlo_execution_profile(
|
||||
&executable->hlo_profile_printer_data(),
|
||||
&executable->hlo_profile_index_map());
|
||||
|
@ -85,11 +85,10 @@ void RealMain(absl::Span<char* const> args) {
|
||||
ExecutableBuildOptions build_options;
|
||||
build_options.set_device_ordinal(0);
|
||||
build_options.set_result_layout(program_shape->result());
|
||||
auto executables =
|
||||
local_service->CompileExecutables(computation, layouts, build_options)
|
||||
.ConsumeValueOrDie();
|
||||
CHECK_EQ(executables.size(), 1);
|
||||
const HloModule& module = executables[0]->module();
|
||||
StatusOr<std::unique_ptr<Executable>> executable =
|
||||
local_service->CompileExecutable(computation, layouts, build_options);
|
||||
|
||||
const HloModule& module = executable.ValueOrDie()->module();
|
||||
|
||||
OperationDumper dumper(arg);
|
||||
for (auto* computation : module.computations()) {
|
||||
|
@ -62,11 +62,10 @@ void RealMain(absl::Span<char* const> args, bool compile) {
|
||||
ExecutableBuildOptions build_options;
|
||||
build_options.set_device_ordinal(0);
|
||||
build_options.set_result_layout(program_shape->result());
|
||||
auto executables =
|
||||
local_service->CompileExecutables(computation, layouts, build_options)
|
||||
.ConsumeValueOrDie();
|
||||
CHECK_EQ(executables.size(), 1);
|
||||
const HloModule& module = executables[0]->module();
|
||||
StatusOr<std::unique_ptr<Executable>> executable =
|
||||
local_service->CompileExecutable(computation, layouts, build_options);
|
||||
|
||||
const HloModule& module = executable.ValueOrDie()->module();
|
||||
|
||||
fprintf(stdout, "HLO compiled for %s backend:\n%s\n",
|
||||
local_service->backend().platform()->Name().c_str(),
|
||||
|
@ -125,11 +125,7 @@ StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
|
||||
}
|
||||
ExecutableBuildOptions exec_build_options;
|
||||
*exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executables,
|
||||
client->Compile(computation, argument_layout_ptrs, exec_build_options));
|
||||
TF_RET_CHECK(executables.size() == 1);
|
||||
return std::move(executables[0]);
|
||||
return client->Compile(computation, argument_layout_ptrs, exec_build_options);
|
||||
}
|
||||
|
||||
absl::optional<Shape> GetXfeedShape(bool is_infeed,
|
||||
|
@ -126,11 +126,12 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
VLOG(1) << "Building executable";
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto executables,
|
||||
client->Compile(computation, argument_layout_ptrs, build_options));
|
||||
TF_RET_CHECK(executables.size() == 1);
|
||||
*program = std::move(executables[0]);
|
||||
auto compile_result =
|
||||
client->Compile(computation, argument_layout_ptrs, build_options);
|
||||
if (!compile_result.ok()) {
|
||||
return compile_result.status();
|
||||
}
|
||||
*program = std::move(compile_result.ValueOrDie());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -286,8 +286,8 @@ xla::ProgramShape XlaCompiledProgramShape(
|
||||
parameters_shapes.push_back(&input_program_shape.parameters(i));
|
||||
}
|
||||
auto local_executable =
|
||||
std::move(client->Compile(computation, parameters_shapes, exec_options)
|
||||
.ValueOrDie()[0]);
|
||||
client->Compile(computation, parameters_shapes, exec_options)
|
||||
.ValueOrDie();
|
||||
return local_executable->executable()
|
||||
->module()
|
||||
.entry_computation()
|
||||
|
Loading…
Reference in New Issue
Block a user