Refactor XLA's CompileAheadOfTime out of LocalClient into a new CompileOnlyClient class, and likewise from LocalService into a new CompileOnlyService class.
This also renames AheadOfTimeComputationInstance to AotComputationInstance for consistency with AotCompilationResult and AotCompilationOptions in compiler/xla/service/compiler.h.

Change: 155252320
parent 1a1bef744a
commit 70c60b1491
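For illustration only (not part of this commit), here is a minimal hypothetical sketch of the new compile-only entry points this change introduces. It assumes the 2017-era ComputationBuilder API and the CpuAotCompilationOptions constructor used by tfcompile; the target triple and entry-point name are placeholders:

    #include "tensorflow/compiler/xla/client/client_library.h"
    #include "tensorflow/compiler/xla/client/compile_only_client.h"
    #include "tensorflow/compiler/xla/client/computation_builder.h"
    #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    int main(int argc, char** argv) {
      // A compile-only client never instantiates an execution backend;
      // passing no platform selects the default one.
      xla::CompileOnlyClient* client =
          xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();

      // Build a trivial computation: f(x) = x + x.
      xla::ComputationBuilder builder(client, "example");
      xla::Shape r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto x = builder.Parameter(0, r0f32, "x");
      builder.Add(x, x);
      xla::Computation computation = builder.Build().ConsumeValueOrDie();

      // Describe the computation plus its argument/result layouts
      // (previously xla::LocalClient::AheadOfTimeComputationInstance).
      xla::CompileOnlyClient::AotComputationInstance instance{
          &computation, /*argument_layouts=*/{&r0f32},
          /*result_layout=*/&r0f32};

      // Hypothetical CPU target options; triple and entry point are
      // placeholders, not values taken from the commit.
      xla::cpu::CpuAotCompilationOptions options(
          "x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
          /*entry_point_name=*/"entry",
          xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);

      auto results =
          client->CompileAheadOfTime({instance}, options).ConsumeValueOrDie();
      return results.empty() ? 1 : 0;
    }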
@@ -73,7 +73,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
         "//tensorflow/core:core_cpu",
@@ -27,7 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -274,7 +274,8 @@ Status CreateXlaArgs(const Graph& graph,
 
 // Converts the TensorFlow graph into an XLA computation, by executing the
 // graph symbolically, with each op building up the XLA HLO.
-Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
+Status ConvertGraphToXla(xla::CompileOnlyClient* client,
+                         std::unique_ptr<Graph> graph,
                          xla::Computation* computation, bool* has_context_arg) {
   // Create a device and context to convert the graph into an XLA computation.
   XlaOpRegistry::RegisterCompilationKernels();
@@ -333,7 +334,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
 }
 
 // Compiles the XLA computation into executable code.
-Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
+Status CompileXla(xla::CompileOnlyClient* client,
+                  const xla::Computation& computation,
                   const xla::cpu::CpuAotCompilationOptions& aot_opts,
                   CompileResult* compile_result) {
   // Retrieves arg and result layouts from the computation.
@@ -350,7 +352,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
   for (int i = 0; i < pshape->parameters_size(); ++i) {
     arg_layouts.push_back(pshape->mutable_parameters(i));
   }
-  xla::LocalClient::AheadOfTimeComputationInstance instance;
+  xla::CompileOnlyClient::AotComputationInstance instance;
   instance.computation = &computation;
   instance.argument_layouts = std::move(arg_layouts);
   instance.result_layout = &pshape->result();
@@ -365,7 +367,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
       std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
   compile_result->pointer_size =
-      xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
+      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
 }
 
@@ -394,8 +396,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
   namespace gpu = perftools::gputools;
   gpu::Platform* cpu_platform =
       gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
-  xla::LocalClient* client =
-      xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
+  xla::CompileOnlyClient* client =
+      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
+          .ValueOrDie();
   xla::Computation computation;
   TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
                                        &compile_result->has_context_arg));
@@ -99,6 +99,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "compile_only_client",
+    srcs = ["compile_only_client.cc"],
+    hdrs = ["compile_only_client.h"],
+    deps = [
+        ":client",
+        ":computation",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:compile_only_service",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "@llvm//:support",
+    ],
+)
+
 # This target is used to instantiate the XLA service in-process and create
 # a client for it.
 cc_library(
@@ -106,12 +126,14 @@ cc_library(
     srcs = ["client_library.cc"],
     hdrs = ["client_library.h"],
     deps = [
+        ":compile_only_client",
         ":local_client",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compile_only_service",
         "//tensorflow/compiler/xla/service:device_memory_allocator",
         "//tensorflow/compiler/xla/service:local_service",
         "//tensorflow/compiler/xla/service:platform_util",
@@ -69,8 +69,8 @@ ClientLibrary::~ClientLibrary() = default;
     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
   }
 
-  auto it = client_library.instances_.find(platform->id());
-  if (it != client_library.instances_.end()) {
+  auto it = client_library.local_instances_.find(platform->id());
+  if (it != client_library.local_instances_.end()) {
     return it->second->client.get();
   }
 
||||
@ -78,13 +78,13 @@ ClientLibrary::~ClientLibrary() = default;
|
||||
service_options.set_platform(platform);
|
||||
service_options.set_number_of_replicas(replica_count);
|
||||
|
||||
std::unique_ptr<LocalInstance> instance = MakeUnique<LocalInstance>();
|
||||
auto instance = MakeUnique<LocalInstance>();
|
||||
TF_ASSIGN_OR_RETURN(instance->service,
|
||||
LocalService::NewService(service_options));
|
||||
instance->client = MakeUnique<LocalClient>(instance->service.get());
|
||||
LocalClient* cl = instance->client.get();
|
||||
|
||||
client_library.instances_.insert(
|
||||
client_library.local_instances_.insert(
|
||||
std::make_pair(platform->id(), std::move(instance)));
|
||||
return cl;
|
||||
}
|
||||
@@ -99,9 +99,35 @@ ClientLibrary::~ClientLibrary() = default;
     perftools::gputools::Platform* platform) {
   ClientLibrary& client_library = Singleton();
   tensorflow::mutex_lock lock(client_library.service_mutex_);
-  auto it = client_library.instances_.find(platform->id());
-  CHECK(it != client_library.instances_.end());
+  auto it = client_library.local_instances_.find(platform->id());
+  CHECK(it != client_library.local_instances_.end());
   return it->second->service.get();
 }
 
+/* static */ StatusOr<CompileOnlyClient*>
+ClientLibrary::GetOrCreateCompileOnlyClient(
+    perftools::gputools::Platform* platform) {
+  ClientLibrary& client_library = Singleton();
+  tensorflow::mutex_lock lock(client_library.service_mutex_);
+
+  if (platform == nullptr) {
+    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+  }
+
+  auto it = client_library.compile_only_instances_.find(platform->id());
+  if (it != client_library.compile_only_instances_.end()) {
+    return it->second->client.get();
+  }
+
+  auto instance = MakeUnique<CompileOnlyInstance>();
+  TF_ASSIGN_OR_RETURN(instance->service,
+                      CompileOnlyService::NewService(platform));
+  instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
+  CompileOnlyClient* cl = instance->client.get();
+
+  client_library.compile_only_instances_.insert(
+      std::make_pair(platform->id(), std::move(instance)));
+  return cl;
+}
+
 }  // namespace xla
@@ -26,7 +26,9 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compile_only_service.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/local_service.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -76,6 +78,13 @@ class ClientLibrary {
   // access user computations from client.
   static LocalService* GetXlaService(perftools::gputools::Platform* platform);
 
+  // Singleton constructor-or-accessor for compile-only clients. Arguments:
+  //
+  //   platform : The platform the underlying XLA service should target. If
+  //       null then default platform is used.
+  static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
+      perftools::gputools::Platform* platform = nullptr);
+
  private:
   // Returns the singleton instance of ClientLibrary.
   static ClientLibrary& Singleton();
@@ -90,10 +99,21 @@ class ClientLibrary {
     std::unique_ptr<LocalClient> client;
   };
 
+  struct CompileOnlyInstance {
+    // Service that is wrapped by the singleton client object.
+    std::unique_ptr<CompileOnlyService> service;
+    // Singleton client object.
+    std::unique_ptr<CompileOnlyClient> client;
+  };
+
   tensorflow::mutex service_mutex_;  // Guards the singleton creation state.
   std::unordered_map<perftools::gputools::Platform::Id,
                      std::unique_ptr<LocalInstance>>
-      instances_ GUARDED_BY(service_mutex_);
+      local_instances_ GUARDED_BY(service_mutex_);
+
+  std::unordered_map<perftools::gputools::Platform::Id,
+                     std::unique_ptr<CompileOnlyInstance>>
+      compile_only_instances_ GUARDED_BY(service_mutex_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
 };
tensorflow/compiler/xla/client/compile_only_client.cc (new file, 59 lines)
@@ -0,0 +1,59 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/compile_only_client.h"

#include "external/llvm/include/llvm/ADT/Triple.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace se = ::perftools::gputools;

namespace xla {

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
    const AotCompilationOptions& options) {
  std::vector<CompileOnlyService::AotComputationInstance> service_instances;
  service_instances.reserve(computations.size());
  for (const AotComputationInstance& instance : computations) {
    service_instances.push_back({});
    CompileOnlyService::AotComputationInstance& service_instance =
        service_instances.back();
    TF_RET_CHECK(instance.computation != nullptr);
    service_instance.computation = instance.computation->handle();
    service_instance.argument_layouts = instance.argument_layouts;
    service_instance.result_layout = instance.result_layout;
  }
  return compiler_service_->CompileAheadOfTime(service_instances, options);
}

int64 CompileOnlyClient::PointerSizeForTriple(
    tensorflow::StringPiece target_triple) {
  llvm::Triple triple(
      llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
  if (triple.isArch64Bit()) {
    return 8;
  } else if (triple.isArch32Bit()) {
    return 4;
  } else {
    CHECK(triple.isArch16Bit());
    return 2;
  }
}

}  // namespace xla
tensorflow/compiler/xla/client/compile_only_client.h (new file, 66 lines)
@@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// An XLA Client specialization for doing ahead-of-time compilation. This does
// not require (or attempt to instantiate) an execution-capable backend for the
// relevant platform.
class CompileOnlyClient : public Client {
 public:
  explicit CompileOnlyClient(CompileOnlyService* service)
      : Client(service), compiler_service_(service) {}

  CompileOnlyClient(const CompileOnlyClient&) = delete;
  void operator=(const CompileOnlyClient&) = delete;

  // A description of a computation to compile using CompileAheadOfTime.
  struct AotComputationInstance {
    const Computation* computation;
    // Inform the compiler of the expected layout for arguments.
    std::vector<const Shape*> argument_layouts;
    // Specifies the expected result layout.
    const Shape* result_layout;
  };

  // Compiles a list of computations for ahead-of-time execution. This is
  // intended for use in static compilation. The |options| parameter describes
  // the target for which the compiler should emit code.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
      const AotCompilationOptions& options);

  // Returns the size of a pointer in bytes for a given triple.
  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);

 private:
  CompileOnlyService* compiler_service_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
@@ -261,38 +261,6 @@ tensorflow::Status LocalClient::ResolveArguments(
                                       argument_ptrs);
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalClient::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-        computations,
-    const AotCompilationOptions& options) {
-  std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
-  service_instances.reserve(computations.size());
-  for (const AheadOfTimeComputationInstance& instance : computations) {
-    service_instances.push_back({});
-    LocalService::AheadOfTimeComputationInstance& service_instance =
-        service_instances.back();
-    TF_RET_CHECK(instance.computation != nullptr);
-    service_instance.computation = instance.computation->handle();
-    service_instance.argument_layouts = instance.argument_layouts;
-    service_instance.result_layout = instance.result_layout;
-  }
-  return local_service_->CompileAheadOfTime(service_instances, options);
-}
-
-int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
-  llvm::Triple triple(
-      llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
-  if (triple.isArch64Bit()) {
-    return 8;
-  } else if (triple.isArch32Bit()) {
-    return 4;
-  } else {
-    CHECK(triple.isArch16Bit());
-    return 2;
-  }
-}
-
 se::Platform* LocalClient::platform() const {
   return local_service_->backend().platform();
 }
@@ -148,7 +148,7 @@ class LocalExecutable {
   const ExecutableBuildOptions& build_options_;
 };
 
-// An XLA service client object for use when the client and service run in
+// An XLA Client specialization for use when the client and service run in
 // the same process.
 class LocalClient : public Client {
  public:
@@ -182,30 +182,6 @@ class LocalClient : public Client {
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
       const ExecutableBuildOptions& options);
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AheadOfTimeComputationInstance {
-    const Computation* computation;
-    // Inform the compiler of the expected layout for arguments.
-    std::vector<const Shape*> argument_layouts;
-    // Specifies the expected result layout.
-    const Shape* result_layout;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution. This is
-  // intended for use in static compilation. The |options| parameter describes
-  // the target for which the compiler should emit code.
-  //
-  // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
-  // own library.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-          computations,
-      const AotCompilationOptions& options);
-
-  // Returns the size of a pointer in bytes for a given triple.
-  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
-
   // Returns the platform that the underlying service targets.
   perftools::gputools::Platform* platform() const;
 
@@ -408,6 +408,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "compile_only_service",
+    srcs = ["compile_only_service.cc"],
+    hdrs = ["compile_only_service.h"],
+    deps = [
+        ":backend",
+        ":compiler",
+        ":computation_layout",
+        ":computation_tracker",
+        ":platform_util",
+        ":service",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+    ],
+)
+
 cc_library(
     name = "cpu_plugin",
     deps = [
tensorflow/compiler/xla/service/compile_only_service.cc (new file, 131 lines)
@@ -0,0 +1,131 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/compile_only_service.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace se = ::perftools::gputools;

namespace xla {

/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
CompileOnlyService::NewService(perftools::gputools::Platform* platform) {
  ServiceOptions default_options;
  default_options.set_platform(platform);
  return NewService(default_options);
}

/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
CompileOnlyService::NewService(const ServiceOptions& options) {
  perftools::gputools::Platform* platform = options.platform();
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }

  TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
                      CreateComputeConstantBackend());
  std::unique_ptr<CompileOnlyService> service(
      new CompileOnlyService(compiler, std::move(compute_constant_backend)));
  return std::move(service);
}

CompileOnlyService::CompileOnlyService(
    Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
    : Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
      compiler_(compiler) {
  runs_in_client_process_ = true;
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
    const AotCompilationOptions& options) {
  std::vector<std::unique_ptr<HloModule>> hlo_modules;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  for (const AotComputationInstance& instance : computations) {
    TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
                        computation_tracker_.Resolve(instance.computation));
    VersionedComputationHandle versioned_handle =
        user_computation->GetVersionedHandle();

    // Dump computation proto state if flag is set.
    legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
    const string& directory_path = flags->xla_dump_computations_to;
    if (!directory_path.empty()) {
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<SessionModule> session_module,
          computation_tracker_.SnapshotComputation(versioned_handle.handle));
      string filename = tensorflow::strings::StrCat(
          "computation_", versioned_handle.handle.handle(), "__",
          session_module->entry().name(), "__version_",
          versioned_handle.version);
      TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
                                                     *session_module));
    }

    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
                        computation_tracker_.BuildHloModule(
                            versioned_handle,
                            /*include_unreachable_instructions=*/true));
    hlo_modules.push_back(std::move(hlo_module));

    TF_ASSIGN_OR_RETURN(
        std::shared_ptr<const ProgramShape> program_shape,
        user_computation->ComputeProgramShape(versioned_handle.version));

    module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
    HloModuleConfig* module_config = module_configs.back().get();
    auto* computation_layout =
        module_config->mutable_entry_computation_layout();
    if (flags->xla_hlo_profile) {
      module_config->enable_hlo_profiling(true);
    }
    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
      const Shape& argument_layout = *instance.argument_layouts[i];
      if (ShapeUtil::IsTuple(argument_layout)) {
        return Unimplemented("tuple arguments not supported yet");
      }
      TF_RETURN_IF_ERROR(
          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
              argument_layout));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            *instance.result_layout));
  }

  return compiler_->CompileAheadOfTime(std::move(hlo_modules),
                                       std::move(module_configs),
                                       MakeHloDumper(), options);
}

}  // namespace xla
tensorflow/compiler/xla/service/compile_only_service.h (new file, 128 lines)
@@ -0,0 +1,128 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

// An XLA Service specialization for ahead-of-time compilation. This only
// instantiates a Compiler object for the relevant platform; it does not
// instantiate or require an execution backend.
class CompileOnlyService : public Service {
 public:
  // Factory for creating a CompileOnlyService. The parameter platform is the
  // platform that the service should target. If platform is null then the
  // default platform is used.
  static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
      perftools::gputools::Platform* platform);
  static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
      const ServiceOptions& options);

  // A description of a computation to compile using CompileAheadOfTime.
  struct AotComputationInstance {
    ComputationHandle computation;
    std::vector<const Shape*> argument_layouts;
    const Shape* result_layout = nullptr;
  };

  // Compiles a list of computations for ahead-of-time execution. This is
  // intended for use in static compilation. See
  // |CompileOnlyClient::CompileAheadOfTime| for additional details.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
      const AotCompilationOptions& Options);

  // Override Service methods that require an execute backend.
  tensorflow::Status Execute(const ExecuteRequest* arg,
                             ExecuteResponse* result) override {
    return Unimplemented("CompileOnlyService does not support execution.");
  }
  tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
                                     ExecuteParallelResponse* result) override {
    return Unimplemented("CompileOnlyService does not support execution.");
  }
  tensorflow::Status GetDeviceHandles(
      const GetDeviceHandlesRequest* arg,
      GetDeviceHandlesResponse* result) override {
    return Unimplemented("CompileOnlyService does not support devices.");
  }
  tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
                                  ExecuteAsyncResponse* result) override {
    return Unimplemented("CompileOnlyService does not support execution.");
  }
  tensorflow::Status WaitForExecution(
      const WaitForExecutionRequest* arg,
      WaitForExecutionResponse* result) override {
    return Unimplemented("CompileOnlyService does not support execution.");
  }
  tensorflow::Status TransferToClient(
      const TransferToClientRequest* arg,
      TransferToClientResponse* result) override {
    return Unimplemented("CompileOnlyService does not support data transfers.");
  }
  tensorflow::Status TransferToClientInProcess(
      const TransferToClientInProcessRequest* arg,
      TransferToClientInProcessResponse* result) override {
    return Unimplemented("CompileOnlyService does not support data transfers.");
  }
  tensorflow::Status TransferToServer(
      const TransferToServerRequest* arg,
      TransferToServerResponse* result) override {
    return Unimplemented("CompileOnlyService does not support data transfers.");
  }
  tensorflow::Status TransferToInfeed(
      const TransferToInfeedRequest* arg,
      TransferToInfeedResponse* result) override {
    return Unimplemented("CompileOnlyService does not support data transfers.");
  }
  tensorflow::Status TransferFromOutfeed(
      const TransferFromOutfeedRequest* arg,
      TransferFromOutfeedResponse* result) override {
    return Unimplemented("CompileOnlyService does not support data transfers.");
  }
  tensorflow::Status TransferToServerInProcess(
      const TransferToServerInProcessRequest* arg,
      TransferToServerInProcessResponse* result) override {
    return Unimplemented("CompileOnlyService does not support data transfers.");
  }
  tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
                                 ResetDeviceResponse* result) override {
    return Unimplemented("CompileOnlyService does not support devices.");
  }

 private:
  explicit CompileOnlyService(
      Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
  CompileOnlyService(const CompileOnlyService&) = delete;
  void operator=(const CompileOnlyService&) = delete;

  // The compiler for the target platform. This is included in place of
  // the Service::execute_backend_'s compiler, since execute_backend_ is a
  // nullptr in CompileOnlyService.
  Compiler* compiler_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
@@ -128,70 +128,6 @@ StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
                                              allocation_size));
 }
 
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-LocalService::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-        computations,
-    const AotCompilationOptions& options) {
-  std::vector<std::unique_ptr<HloModule>> hlo_modules;
-  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
-  for (const AheadOfTimeComputationInstance& instance : computations) {
-    TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
-                        computation_tracker_.Resolve(instance.computation));
-    VersionedComputationHandle versioned_handle =
-        user_computation->GetVersionedHandle();
-
-    // Dump computation proto state if flag is set.
-    legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
-    const string& directory_path = flags->xla_dump_computations_to;
-    if (!directory_path.empty()) {
-      TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<SessionModule> session_module,
-          computation_tracker_.SnapshotComputation(versioned_handle.handle));
-      string filename = tensorflow::strings::StrCat(
-          "computation_", versioned_handle.handle.handle(), "__",
-          session_module->entry().name(), "__version_",
-          versioned_handle.version);
-      TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
-                                                     *session_module));
-    }
-
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
-                        computation_tracker_.BuildHloModule(
-                            versioned_handle,
-                            /*include_unreachable_instructions=*/true));
-    hlo_modules.push_back(std::move(hlo_module));
-
-    TF_ASSIGN_OR_RETURN(
-        std::shared_ptr<const ProgramShape> program_shape,
-        user_computation->ComputeProgramShape(versioned_handle.version));
-
-    module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
-    HloModuleConfig* module_config = module_configs.back().get();
-    auto* computation_layout =
-        module_config->mutable_entry_computation_layout();
-    if (flags->xla_hlo_profile) {
-      module_config->enable_hlo_profiling(true);
-    }
-    for (int i = 0; i < instance.argument_layouts.size(); ++i) {
-      const Shape& argument_layout = *instance.argument_layouts[i];
-      if (ShapeUtil::IsTuple(argument_layout)) {
-        return Unimplemented("tuple arguments not supported yet");
-      }
-      TF_RETURN_IF_ERROR(
-          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
-              argument_layout));
-    }
-    TF_RETURN_IF_ERROR(
-        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            *instance.result_layout));
-  }
-
-  return execute_backend_->compiler()->CompileAheadOfTime(
-      std::move(hlo_modules), std::move(module_configs), MakeHloDumper(),
-      options);
-}
-
 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ComputationHandle& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -59,22 +59,6 @@ class LocalService : public Service {
       const Shape& shape, int device_ordinal,
       bool allocate_space_for_deep_copy);
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AheadOfTimeComputationInstance {
-    ComputationHandle computation;
-    std::vector<const Shape*> argument_layouts;
-    const Shape* result_layout = nullptr;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution. This is
-  // intended for use in static compilation. See
-  // |LocalClient::CompileAheadOfTime| for additional details.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
-          computations,
-      const AotCompilationOptions& Options);
-
   // Builds an Executable with the given argument layouts and options. If
   // result_layout is non-null, then the executable is compiled to produce a
   // result of the given layout.
@@ -180,20 +180,24 @@ Service::Service(std::unique_ptr<Backend> execute_backend,
                  std::unique_ptr<Backend> compute_constant_backend)
     : execute_backend_(std::move(execute_backend)),
       compute_constant_backend_(std::move(compute_constant_backend)) {
-  LOG(INFO) << Printf(
-      "XLA service %p executing computations on platform %s. Devices:", this,
-      execute_backend_->platform()->Name().c_str());
-  for (int i = 0; i < execute_backend_->device_count(); ++i) {
-    if (execute_backend_->device_ordinal_supported(i)) {
-      se::StreamExecutor* executor =
-          execute_backend_->stream_executor(i).ValueOrDie();
-      const auto& description = executor->GetDeviceDescription();
-      LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
-                          description.name().c_str(),
-                          description.platform_version().c_str());
-    } else {
-      LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
-    }
-  }
+  if (execute_backend_) {
+    LOG(INFO) << Printf(
+        "XLA service %p executing computations on platform %s. Devices:", this,
+        execute_backend_->platform()->Name().c_str());
+    for (int i = 0; i < execute_backend_->device_count(); ++i) {
+      if (execute_backend_->device_ordinal_supported(i)) {
+        se::StreamExecutor* executor =
+            execute_backend_->stream_executor(i).ValueOrDie();
+        const auto& description = executor->GetDeviceDescription();
+        LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
+                            description.name().c_str(),
+                            description.platform_version().c_str());
+      } else {
+        LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
+      }
+    }
+  } else {
+    VLOG(1) << "XLA compile-only service constructed";
+  }
 }
 
@@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) {
 int main(int argc, char** argv) {
   tensorflow::port::InitMain(argv[0], &argc, &argv);
 
-  auto client = xla::ClientLibrary::LocalClientOrDie();
+  auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
 
   xla::ComputationBuilder builder(client, "aot_test_helper");
   auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
@@ -74,7 +74,7 @@ int main(int argc, char** argv) {
   llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
 
   xla::Computation computation = builder.Build().ConsumeValueOrDie();
-  xla::LocalClient::AheadOfTimeComputationInstance instance{
+  xla::CompileOnlyClient::AotComputationInstance instance{
       &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
 
   xla::cpu::CpuAotCompilationOptions options(
tensorflow/opensource_only/eigen.threadpool (new file, 1 line)
@@ -0,0 +1 @@
#include "unsupported/Eigen/CXX11/ThreadPool"