Refactor XLA's CompileAheadOfTime out of LocalClient into a new CompileOnlyClient class, and likewise from LocalService into a new CompileOnlyService class.

This also renames AheadOfTimeComputationInstance to AotComputationInstance for consistency with AotCompilationResult and AotCompilationOptions in compiler/xla/service/compiler.h.
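
For illustration, a minimal sketch of the new call path (adapted from the updated caller code in this change; cpu_platform, computation, arg_layouts, result_shape, and aot_opts are assumed to be set up by the caller as before):

  // Sketch only: acquire a compile-only client instead of a LocalClient.
  xla::CompileOnlyClient* client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
          .ValueOrDie();

  // Describe the computation and the expected argument/result layouts.
  xla::CompileOnlyClient::AotComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(arg_layouts);
  instance.result_layout = &result_shape;

  // aot_opts is assumed to be an AotCompilationOptions (e.g. a
  // xla::cpu::CpuAotCompilationOptions) describing the target.
  auto aot_or = client->CompileAheadOfTime({instance}, aot_opts);
  std::unique_ptr<xla::AotCompilationResult> result =
      std::move(aot_or.ValueOrDie().back());
  int64 pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
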
Change: 155252320
Authored by A. Unique TensorFlower on 2017-05-05 15:10:26 -08:00; committed by TensorFlower Gardener.
Commit 70c60b1491 (parent 1a1bef744a)
17 changed files with 512 additions and 167 deletions

tensorflow/compiler/aot/BUILD

@@ -73,7 +73,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
         "//tensorflow/core:core_cpu",

tensorflow/compiler/aot/compile.cc

@@ -27,7 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -274,7 +274,8 @@ Status CreateXlaArgs(const Graph& graph,
 
 // Converts the TensorFlow graph into an XLA computation, by executing the
 // graph symbolically, with each op building up the XLA HLO.
-Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
+Status ConvertGraphToXla(xla::CompileOnlyClient* client,
+                         std::unique_ptr<Graph> graph,
                          xla::Computation* computation, bool* has_context_arg) {
   // Create a device and context to convert the graph into an XLA computation.
   XlaOpRegistry::RegisterCompilationKernels();
@@ -333,7 +334,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
 }
 
 // Compiles the XLA computation into executable code.
-Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
+Status CompileXla(xla::CompileOnlyClient* client,
+                  const xla::Computation& computation,
                   const xla::cpu::CpuAotCompilationOptions& aot_opts,
                   CompileResult* compile_result) {
   // Retrieves arg and result layouts from the computation.
@@ -350,7 +352,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
   for (int i = 0; i < pshape->parameters_size(); ++i) {
     arg_layouts.push_back(pshape->mutable_parameters(i));
   }
-  xla::LocalClient::AheadOfTimeComputationInstance instance;
+  xla::CompileOnlyClient::AotComputationInstance instance;
   instance.computation = &computation;
   instance.argument_layouts = std::move(arg_layouts);
   instance.result_layout = &pshape->result();
@@ -365,7 +367,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
       std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
   compile_result->pointer_size =
-      xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
+      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
 }
@@ -394,8 +396,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
   namespace gpu = perftools::gputools;
   gpu::Platform* cpu_platform =
       gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
-  xla::LocalClient* client =
-      xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
+  xla::CompileOnlyClient* client =
+      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
+          .ValueOrDie();
   xla::Computation computation;
   TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
                                        &compile_result->has_context_arg));

tensorflow/compiler/xla/client/BUILD

@@ -99,6 +99,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "compile_only_client",
+    srcs = ["compile_only_client.cc"],
+    hdrs = ["compile_only_client.h"],
+    deps = [
+        ":client",
+        ":computation",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:compile_only_service",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "@llvm//:support",
+    ],
+)
+
 # This target is used to instantiate the XLA service in-process and create
 # a client for it.
 cc_library(
@@ -106,12 +126,14 @@ cc_library(
     srcs = ["client_library.cc"],
     hdrs = ["client_library.h"],
     deps = [
+        ":compile_only_client",
         ":local_client",
        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compile_only_service",
         "//tensorflow/compiler/xla/service:device_memory_allocator",
         "//tensorflow/compiler/xla/service:local_service",
         "//tensorflow/compiler/xla/service:platform_util",

tensorflow/compiler/xla/client/client_library.cc

@@ -69,8 +69,8 @@ ClientLibrary::~ClientLibrary() = default;
     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
   }
 
-  auto it = client_library.instances_.find(platform->id());
-  if (it != client_library.instances_.end()) {
+  auto it = client_library.local_instances_.find(platform->id());
+  if (it != client_library.local_instances_.end()) {
     return it->second->client.get();
   }
@@ -78,13 +78,13 @@ ClientLibrary::~ClientLibrary() = default;
   service_options.set_platform(platform);
   service_options.set_number_of_replicas(replica_count);
 
-  std::unique_ptr<LocalInstance> instance = MakeUnique<LocalInstance>();
+  auto instance = MakeUnique<LocalInstance>();
   TF_ASSIGN_OR_RETURN(instance->service,
                       LocalService::NewService(service_options));
   instance->client = MakeUnique<LocalClient>(instance->service.get());
   LocalClient* cl = instance->client.get();
 
-  client_library.instances_.insert(
+  client_library.local_instances_.insert(
       std::make_pair(platform->id(), std::move(instance)));
   return cl;
 }
@@ -99,9 +99,35 @@ ClientLibrary::~ClientLibrary() = default;
     perftools::gputools::Platform* platform) {
   ClientLibrary& client_library = Singleton();
   tensorflow::mutex_lock lock(client_library.service_mutex_);
-  auto it = client_library.instances_.find(platform->id());
-  CHECK(it != client_library.instances_.end());
+  auto it = client_library.local_instances_.find(platform->id());
+  CHECK(it != client_library.local_instances_.end());
   return it->second->service.get();
 }
 
+/* static */ StatusOr<CompileOnlyClient*>
+ClientLibrary::GetOrCreateCompileOnlyClient(
+    perftools::gputools::Platform* platform) {
+  ClientLibrary& client_library = Singleton();
+  tensorflow::mutex_lock lock(client_library.service_mutex_);
+
+  if (platform == nullptr) {
+    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+  }
+
+  auto it = client_library.compile_only_instances_.find(platform->id());
+  if (it != client_library.compile_only_instances_.end()) {
+    return it->second->client.get();
+  }
+
+  auto instance = MakeUnique<CompileOnlyInstance>();
+  TF_ASSIGN_OR_RETURN(instance->service,
+                      CompileOnlyService::NewService(platform));
+  instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
+  CompileOnlyClient* cl = instance->client.get();
+
+  client_library.compile_only_instances_.insert(
+      std::make_pair(platform->id(), std::move(instance)));
+  return cl;
+}
+
 }  // namespace xla

tensorflow/compiler/xla/client/client_library.h

@@ -26,7 +26,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/local_service.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -76,6 +78,13 @@ class ClientLibrary {
   // access user computations from client.
   static LocalService* GetXlaService(perftools::gputools::Platform* platform);
 
+  // Singleton constructor-or-accessor for compile-only clients. Arguments:
+  //
+  //   platform : The platform the underlying XLA service should target. If
+  //       null then default platform is used.
+  static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
+      perftools::gputools::Platform* platform = nullptr);
+
  private:
   // Returns the singleton instance of ClientLibrary.
   static ClientLibrary& Singleton();
@@ -90,10 +99,21 @@ class ClientLibrary {
     std::unique_ptr<LocalClient> client;
   };
 
+  struct CompileOnlyInstance {
+    // Service that is wrapped by the singleton client object.
+    std::unique_ptr<CompileOnlyService> service;
+    // Singleton client object.
+    std::unique_ptr<CompileOnlyClient> client;
+  };
+
   tensorflow::mutex service_mutex_;  // Guards the singleton creation state.
   std::unordered_map<perftools::gputools::Platform::Id,
                      std::unique_ptr<LocalInstance>>
-      instances_ GUARDED_BY(service_mutex_);
+      local_instances_ GUARDED_BY(service_mutex_);
+
+  std::unordered_map<perftools::gputools::Platform::Id,
+                     std::unique_ptr<CompileOnlyInstance>>
+      compile_only_instances_ GUARDED_BY(service_mutex_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
 };

tensorflow/compiler/xla/client/compile_only_client.cc (new file)

@@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "external/llvm/include/llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace se = ::perftools::gputools;
namespace xla {
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
const AotCompilationOptions& options) {
std::vector<CompileOnlyService::AotComputationInstance> service_instances;
service_instances.reserve(computations.size());
for (const AotComputationInstance& instance : computations) {
service_instances.push_back({});
CompileOnlyService::AotComputationInstance& service_instance =
service_instances.back();
TF_RET_CHECK(instance.computation != nullptr);
service_instance.computation = instance.computation->handle();
service_instance.argument_layouts = instance.argument_layouts;
service_instance.result_layout = instance.result_layout;
}
return compiler_service_->CompileAheadOfTime(service_instances, options);
}
int64 CompileOnlyClient::PointerSizeForTriple(
tensorflow::StringPiece target_triple) {
llvm::Triple triple(
llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
if (triple.isArch64Bit()) {
return 8;
} else if (triple.isArch32Bit()) {
return 4;
} else {
CHECK(triple.isArch16Bit());
return 2;
}
}
} // namespace xla

tensorflow/compiler/xla/client/compile_only_client.h (new file)

@@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
// An XLA Client specialization for doing ahead-of-time compilation. This does
// not require (or attempt to instantiate) an execution-capable backend for the
// relevant platform.
class CompileOnlyClient : public Client {
public:
explicit CompileOnlyClient(CompileOnlyService* service)
: Client(service), compiler_service_(service) {}
CompileOnlyClient(const CompileOnlyClient&) = delete;
void operator=(const CompileOnlyClient&) = delete;
// A description of a computation to compile using CompileAheadOfTime.
struct AotComputationInstance {
const Computation* computation;
// Inform the compiler of the expected layout for arguments.
std::vector<const Shape*> argument_layouts;
// Specifies the expected result layout.
const Shape* result_layout;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. The |options| parameter describes
// the target for which the compiler should emit code.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
const AotCompilationOptions& options);
// Returns the size of a pointer in bytes for a given triple.
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
private:
CompileOnlyService* compiler_service_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_

tensorflow/compiler/xla/client/local_client.cc

@@ -261,38 +261,6 @@ tensorflow::Status LocalClient::ResolveArguments(
                                          argument_ptrs);
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalClient::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-        computations,
-    const AotCompilationOptions& options) {
-  std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
-  service_instances.reserve(computations.size());
-  for (const AheadOfTimeComputationInstance& instance : computations) {
-    service_instances.push_back({});
-    LocalService::AheadOfTimeComputationInstance& service_instance =
-        service_instances.back();
-    TF_RET_CHECK(instance.computation != nullptr);
-    service_instance.computation = instance.computation->handle();
-    service_instance.argument_layouts = instance.argument_layouts;
-    service_instance.result_layout = instance.result_layout;
-  }
-  return local_service_->CompileAheadOfTime(service_instances, options);
-}
-
-int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
-  llvm::Triple triple(
-      llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
-  if (triple.isArch64Bit()) {
-    return 8;
-  } else if (triple.isArch32Bit()) {
-    return 4;
-  } else {
-    CHECK(triple.isArch16Bit());
-    return 2;
-  }
-}
-
 se::Platform* LocalClient::platform() const {
   return local_service_->backend().platform();
 }

tensorflow/compiler/xla/client/local_client.h

@@ -148,7 +148,7 @@ class LocalExecutable {
   const ExecutableBuildOptions& build_options_;
 };
 
-// An XLA service client object for use when the client and service run in
+// An XLA Client specialization for use when the client and service run in
 // the same process.
 class LocalClient : public Client {
  public:
@@ -182,30 +182,6 @@ class LocalClient : public Client {
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
       const ExecutableBuildOptions& options);
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AheadOfTimeComputationInstance {
-    const Computation* computation;
-    // Inform the compiler of the expected layout for arguments.
-    std::vector<const Shape*> argument_layouts;
-    // Specifies the expected result layout.
-    const Shape* result_layout;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution. This is
-  // intended for use in static compilation. The |options| parameter describes
-  // the target for which the compiler should emit code.
-  //
-  // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
-  // own library.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-          computations,
-      const AotCompilationOptions& options);
-
-  // Returns the size of a pointer in bytes for a given triple.
-  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
-
   // Returns the platform that the underlying service targets.
   perftools::gputools::Platform* platform() const;

tensorflow/compiler/xla/service/BUILD

@@ -408,6 +408,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "compile_only_service",
+    srcs = ["compile_only_service.cc"],
+    hdrs = ["compile_only_service.h"],
+    deps = [
+        ":backend",
+        ":compiler",
+        ":computation_layout",
+        ":computation_tracker",
+        ":platform_util",
+        ":service",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
 cc_library(
     name = "cpu_plugin",
     deps = [

tensorflow/compiler/xla/service/compile_only_service.cc (new file)

@@ -0,0 +1,131 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace se = ::perftools::gputools;
namespace xla {
/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
CompileOnlyService::NewService(perftools::gputools::Platform* platform) {
ServiceOptions default_options;
default_options.set_platform(platform);
return NewService(default_options);
}
/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
CompileOnlyService::NewService(const ServiceOptions& options) {
perftools::gputools::Platform* platform = options.platform();
if (platform == nullptr) {
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
}
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<CompileOnlyService> service(
new CompileOnlyService(compiler, std::move(compute_constant_backend)));
return std::move(service);
}
CompileOnlyService::CompileOnlyService(
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
: Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
compiler_(compiler) {
runs_in_client_process_ = true;
}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
const AotCompilationOptions& options) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
for (const AotComputationInstance& instance : computations) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(instance.computation));
VersionedComputationHandle versioned_handle =
user_computation->GetVersionedHandle();
// Dump computation proto state if flag is set.
legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
const string& directory_path = flags->xla_dump_computations_to;
if (!directory_path.empty()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<SessionModule> session_module,
computation_tracker_.SnapshotComputation(versioned_handle.handle));
string filename = tensorflow::strings::StrCat(
"computation_", versioned_handle.handle.handle(), "__",
session_module->entry().name(), "__version_",
versioned_handle.version);
TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
*session_module));
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
versioned_handle,
/*include_unreachable_instructions=*/true));
hlo_modules.push_back(std::move(hlo_module));
TF_ASSIGN_OR_RETURN(
std::shared_ptr<const ProgramShape> program_shape,
user_computation->ComputeProgramShape(versioned_handle.version));
module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
HloModuleConfig* module_config = module_configs.back().get();
auto* computation_layout =
module_config->mutable_entry_computation_layout();
if (flags->xla_hlo_profile) {
module_config->enable_hlo_profiling(true);
}
for (int i = 0; i < instance.argument_layouts.size(); ++i) {
const Shape& argument_layout = *instance.argument_layouts[i];
if (ShapeUtil::IsTuple(argument_layout)) {
return Unimplemented("tuple arguments not supported yet");
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
argument_layout));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
*instance.result_layout));
}
return compiler_->CompileAheadOfTime(std::move(hlo_modules),
std::move(module_configs),
MakeHloDumper(), options);
}
} // namespace xla

tensorflow/compiler/xla/service/compile_only_service.h (new file)

@@ -0,0 +1,128 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
// An XLA Service specialization for ahead-of-time compilation. This only
// instantiates a Compiler object for the relevant platform; it does not
// instantiate or require an execution backend.
class CompileOnlyService : public Service {
public:
// Factory for creating a CompileOnlyService. The parameter platform is the
// platform that the service should target. If platform is null then the
// default platform is used.
static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
perftools::gputools::Platform* platform);
static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
const ServiceOptions& options);
// A description of a computation to compile using CompileAheadOfTime.
struct AotComputationInstance {
ComputationHandle computation;
std::vector<const Shape*> argument_layouts;
const Shape* result_layout = nullptr;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. See
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
const AotCompilationOptions& Options);
// Override Service methods that require an execute backend.
tensorflow::Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
ExecuteParallelResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
tensorflow::Status GetDeviceHandles(
const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
}
tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
tensorflow::Status WaitForExecution(
const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
tensorflow::Status TransferToClient(
const TransferToClientRequest* arg,
TransferToClientResponse* result) override {
return Unimplemented("CompileOnlyService does not support data transfers.");
}
tensorflow::Status TransferToClientInProcess(
const TransferToClientInProcessRequest* arg,
TransferToClientInProcessResponse* result) override {
return Unimplemented("CompileOnlyService does not support data transfers.");
}
tensorflow::Status TransferToServer(
const TransferToServerRequest* arg,
TransferToServerResponse* result) override {
return Unimplemented("CompileOnlyService does not support data transfers.");
}
tensorflow::Status TransferToInfeed(
const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) override {
return Unimplemented("CompileOnlyService does not support data transfers.");
}
tensorflow::Status TransferFromOutfeed(
const TransferFromOutfeedRequest* arg,
TransferFromOutfeedResponse* result) override {
return Unimplemented("CompileOnlyService does not support data transfers.");
}
tensorflow::Status TransferToServerInProcess(
const TransferToServerInProcessRequest* arg,
TransferToServerInProcessResponse* result) override {
return Unimplemented("CompileOnlyService does not support data transfers.");
}
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
}
private:
explicit CompileOnlyService(
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
CompileOnlyService(const CompileOnlyService&) = delete;
void operator=(const CompileOnlyService&) = delete;
// The compiler for the target platform. This is included in place of
// the Service::execute_backend_'s compiler, since execute_backend_ is a
// nullptr in CompileOnlyService.
Compiler* compiler_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_

tensorflow/compiler/xla/service/local_service.cc

@@ -128,70 +128,6 @@ StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
                                                     allocation_size));
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalService::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-        computations,
-    const AotCompilationOptions& options) {
-  std::vector<std::unique_ptr<HloModule>> hlo_modules;
-  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
-  for (const AheadOfTimeComputationInstance& instance : computations) {
-    TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
-                        computation_tracker_.Resolve(instance.computation));
-    VersionedComputationHandle versioned_handle =
-        user_computation->GetVersionedHandle();
-
-    // Dump computation proto state if flag is set.
-    legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
-    const string& directory_path = flags->xla_dump_computations_to;
-    if (!directory_path.empty()) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<SessionModule> session_module,
-          computation_tracker_.SnapshotComputation(versioned_handle.handle));
-      string filename = tensorflow::strings::StrCat(
-          "computation_", versioned_handle.handle.handle(), "__",
-          session_module->entry().name(), "__version_",
-          versioned_handle.version);
-      TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
-                                                     *session_module));
-    }
-
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                        computation_tracker_.BuildHloModule(
-                            versioned_handle,
-                            /*include_unreachable_instructions=*/true));
-    hlo_modules.push_back(std::move(hlo_module));
-
-    TF_ASSIGN_OR_RETURN(
-        std::shared_ptr<const ProgramShape> program_shape,
-        user_computation->ComputeProgramShape(versioned_handle.version));
-
-    module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
-    HloModuleConfig* module_config = module_configs.back().get();
-    auto* computation_layout =
-        module_config->mutable_entry_computation_layout();
-    if (flags->xla_hlo_profile) {
-      module_config->enable_hlo_profiling(true);
-    }
-    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
-      const Shape& argument_layout = *instance.argument_layouts[i];
-      if (ShapeUtil::IsTuple(argument_layout)) {
-        return Unimplemented("tuple arguments not supported yet");
-      }
-      TF_RETURN_IF_ERROR(
-          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-              argument_layout));
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *instance.result_layout));
-  }
-
-  return execute_backend_->compiler()->CompileAheadOfTime(
-      std::move(hlo_modules), std::move(module_configs), MakeHloDumper(),
-      options);
-}
-
 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ComputationHandle& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,

tensorflow/compiler/xla/service/local_service.h

@@ -59,22 +59,6 @@ class LocalService : public Service {
       const Shape& shape, int device_ordinal,
       bool allocate_space_for_deep_copy);
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AheadOfTimeComputationInstance {
-    ComputationHandle computation;
-    std::vector<const Shape*> argument_layouts;
-    const Shape* result_layout = nullptr;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution. This is
-  // intended for use in static compilation. See
-  // |LocalClient::CompileAheadOfTime| for additional details.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-          computations,
-      const AotCompilationOptions& Options);
-
   // Builds an Executable with the given argument layouts and options. If
   // result_layout is non-null, then the executable is compiled to produce a
   // result of the given layout.

tensorflow/compiler/xla/service/service.cc

@@ -180,20 +180,24 @@ Service::Service(std::unique_ptr<Backend> execute_backend,
                  std::unique_ptr<Backend> compute_constant_backend)
     : execute_backend_(std::move(execute_backend)),
      compute_constant_backend_(std::move(compute_constant_backend)) {
-  LOG(INFO) << Printf(
-      "XLA service %p executing computations on platform %s. Devices:", this,
-      execute_backend_->platform()->Name().c_str());
-  for (int i = 0; i < execute_backend_->device_count(); ++i) {
-    if (execute_backend_->device_ordinal_supported(i)) {
-      se::StreamExecutor* executor =
-          execute_backend_->stream_executor(i).ValueOrDie();
-      const auto& description = executor->GetDeviceDescription();
-      LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
-                          description.name().c_str(),
-                          description.platform_version().c_str());
-    } else {
-      LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+  if (execute_backend_) {
+    LOG(INFO) << Printf(
+        "XLA service %p executing computations on platform %s. Devices:", this,
+        execute_backend_->platform()->Name().c_str());
+    for (int i = 0; i < execute_backend_->device_count(); ++i) {
+      if (execute_backend_->device_ordinal_supported(i)) {
+        se::StreamExecutor* executor =
+            execute_backend_->stream_executor(i).ValueOrDie();
+        const auto& description = executor->GetDeviceDescription();
+        LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
+                            description.name().c_str(),
+                            description.platform_version().c_str());
+      } else {
+        LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+      }
     }
+  } else {
+    VLOG(1) << "XLA compile-only service constructed";
   }
 }

tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc

@@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) {
 
 int main(int argc, char** argv) {
   tensorflow::port::InitMain(argv[0], &argc, &argv);
-  auto client = xla::ClientLibrary::LocalClientOrDie();
+  auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
 
   xla::ComputationBuilder builder(client, "aot_test_helper");
   auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
   llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
   xla::Computation computation = builder.Build().ConsumeValueOrDie();
-  xla::LocalClient::AheadOfTimeComputationInstance instance{
+  xla::CompileOnlyClient::AotComputationInstance instance{
       &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
 
   xla::cpu::CpuAotCompilationOptions options(

third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool (new file)

@@ -0,0 +1 @@
#include "unsupported/Eigen/CXX11/ThreadPool"