Refactor XLA's CompileAheadOfTime out of LocalClient into a new CompileOnlyClient class, and likewise from LocalService into a new CompileOnlyService class.
This also renames AheadOfTimeComputationInstance to AotComputationInstance for consistency with AotCompilationResult and AotCompilationOptions in tensorflow/compiler/xla/service/compiler.h.
Change: 155252320
This commit is contained in:
parent
1a1bef744a
commit
70c60b1491
@ -73,7 +73,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/client:client_library",
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
"//tensorflow/compiler/xla/client:local_client",
|
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||||
"//tensorflow/compiler/xla/service:compiler",
|
"//tensorflow/compiler/xla/service:compiler",
|
||||||
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
|
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -274,7 +274,8 @@ Status CreateXlaArgs(const Graph& graph,
|
|||||||
|
|
||||||
// Converts the TensorFlow graph into an XLA computation, by executing the
|
// Converts the TensorFlow graph into an XLA computation, by executing the
|
||||||
// graph symbolically, with each op building up the XLA HLO.
|
// graph symbolically, with each op building up the XLA HLO.
|
||||||
Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
|
Status ConvertGraphToXla(xla::CompileOnlyClient* client,
|
||||||
|
std::unique_ptr<Graph> graph,
|
||||||
xla::Computation* computation, bool* has_context_arg) {
|
xla::Computation* computation, bool* has_context_arg) {
|
||||||
// Create a device and context to convert the graph into an XLA computation.
|
// Create a device and context to convert the graph into an XLA computation.
|
||||||
XlaOpRegistry::RegisterCompilationKernels();
|
XlaOpRegistry::RegisterCompilationKernels();
|
||||||
@ -333,7 +334,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compiles the XLA computation into executable code.
|
// Compiles the XLA computation into executable code.
|
||||||
Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
|
Status CompileXla(xla::CompileOnlyClient* client,
|
||||||
|
const xla::Computation& computation,
|
||||||
const xla::cpu::CpuAotCompilationOptions& aot_opts,
|
const xla::cpu::CpuAotCompilationOptions& aot_opts,
|
||||||
CompileResult* compile_result) {
|
CompileResult* compile_result) {
|
||||||
// Retrieves arg and result layouts from the computation.
|
// Retrieves arg and result layouts from the computation.
|
||||||
@ -350,7 +352,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
|
|||||||
for (int i = 0; i < pshape->parameters_size(); ++i) {
|
for (int i = 0; i < pshape->parameters_size(); ++i) {
|
||||||
arg_layouts.push_back(pshape->mutable_parameters(i));
|
arg_layouts.push_back(pshape->mutable_parameters(i));
|
||||||
}
|
}
|
||||||
xla::LocalClient::AheadOfTimeComputationInstance instance;
|
xla::CompileOnlyClient::AotComputationInstance instance;
|
||||||
instance.computation = &computation;
|
instance.computation = &computation;
|
||||||
instance.argument_layouts = std::move(arg_layouts);
|
instance.argument_layouts = std::move(arg_layouts);
|
||||||
instance.result_layout = &pshape->result();
|
instance.result_layout = &pshape->result();
|
||||||
@ -365,7 +367,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
|
|||||||
std::move(aot_or.ValueOrDie().back()));
|
std::move(aot_or.ValueOrDie().back()));
|
||||||
compile_result->entry_point = aot_opts.entry_point_name();
|
compile_result->entry_point = aot_opts.entry_point_name();
|
||||||
compile_result->pointer_size =
|
compile_result->pointer_size =
|
||||||
xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
|
xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -394,8 +396,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
|
|||||||
namespace gpu = perftools::gputools;
|
namespace gpu = perftools::gputools;
|
||||||
gpu::Platform* cpu_platform =
|
gpu::Platform* cpu_platform =
|
||||||
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
|
gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
|
||||||
xla::LocalClient* client =
|
xla::CompileOnlyClient* client =
|
||||||
xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
|
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
|
||||||
|
.ValueOrDie();
|
||||||
xla::Computation computation;
|
xla::Computation computation;
|
||||||
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
|
TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
|
||||||
&compile_result->has_context_arg));
|
&compile_result->has_context_arg));
|
||||||
|
@ -99,6 +99,26 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "compile_only_client",
|
||||||
|
srcs = ["compile_only_client.cc"],
|
||||||
|
hdrs = ["compile_only_client.h"],
|
||||||
|
deps = [
|
||||||
|
":client",
|
||||||
|
":computation",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/compiler/xla/service:compile_only_service",
|
||||||
|
"//tensorflow/compiler/xla/service:compiler",
|
||||||
|
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
|
"@llvm//:support",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# This target is used to instantiate the XLA service in-process and create
|
# This target is used to instantiate the XLA service in-process and create
|
||||||
# a client for it.
|
# a client for it.
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -106,12 +126,14 @@ cc_library(
|
|||||||
srcs = ["client_library.cc"],
|
srcs = ["client_library.cc"],
|
||||||
hdrs = ["client_library.h"],
|
hdrs = ["client_library.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":compile_only_client",
|
||||||
":local_client",
|
":local_client",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla/service:backend",
|
"//tensorflow/compiler/xla/service:backend",
|
||||||
|
"//tensorflow/compiler/xla/service:compile_only_service",
|
||||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||||
"//tensorflow/compiler/xla/service:local_service",
|
"//tensorflow/compiler/xla/service:local_service",
|
||||||
"//tensorflow/compiler/xla/service:platform_util",
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
|
@ -69,8 +69,8 @@ ClientLibrary::~ClientLibrary() = default;
|
|||||||
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
|
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto it = client_library.instances_.find(platform->id());
|
auto it = client_library.local_instances_.find(platform->id());
|
||||||
if (it != client_library.instances_.end()) {
|
if (it != client_library.local_instances_.end()) {
|
||||||
return it->second->client.get();
|
return it->second->client.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,13 +78,13 @@ ClientLibrary::~ClientLibrary() = default;
|
|||||||
service_options.set_platform(platform);
|
service_options.set_platform(platform);
|
||||||
service_options.set_number_of_replicas(replica_count);
|
service_options.set_number_of_replicas(replica_count);
|
||||||
|
|
||||||
std::unique_ptr<LocalInstance> instance = MakeUnique<LocalInstance>();
|
auto instance = MakeUnique<LocalInstance>();
|
||||||
TF_ASSIGN_OR_RETURN(instance->service,
|
TF_ASSIGN_OR_RETURN(instance->service,
|
||||||
LocalService::NewService(service_options));
|
LocalService::NewService(service_options));
|
||||||
instance->client = MakeUnique<LocalClient>(instance->service.get());
|
instance->client = MakeUnique<LocalClient>(instance->service.get());
|
||||||
LocalClient* cl = instance->client.get();
|
LocalClient* cl = instance->client.get();
|
||||||
|
|
||||||
client_library.instances_.insert(
|
client_library.local_instances_.insert(
|
||||||
std::make_pair(platform->id(), std::move(instance)));
|
std::make_pair(platform->id(), std::move(instance)));
|
||||||
return cl;
|
return cl;
|
||||||
}
|
}
|
||||||
@ -99,9 +99,35 @@ ClientLibrary::~ClientLibrary() = default;
|
|||||||
perftools::gputools::Platform* platform) {
|
perftools::gputools::Platform* platform) {
|
||||||
ClientLibrary& client_library = Singleton();
|
ClientLibrary& client_library = Singleton();
|
||||||
tensorflow::mutex_lock lock(client_library.service_mutex_);
|
tensorflow::mutex_lock lock(client_library.service_mutex_);
|
||||||
auto it = client_library.instances_.find(platform->id());
|
auto it = client_library.local_instances_.find(platform->id());
|
||||||
CHECK(it != client_library.instances_.end());
|
CHECK(it != client_library.local_instances_.end());
|
||||||
return it->second->service.get();
|
return it->second->service.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ StatusOr<CompileOnlyClient*>
|
||||||
|
ClientLibrary::GetOrCreateCompileOnlyClient(
|
||||||
|
perftools::gputools::Platform* platform) {
|
||||||
|
ClientLibrary& client_library = Singleton();
|
||||||
|
tensorflow::mutex_lock lock(client_library.service_mutex_);
|
||||||
|
|
||||||
|
if (platform == nullptr) {
|
||||||
|
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto it = client_library.compile_only_instances_.find(platform->id());
|
||||||
|
if (it != client_library.compile_only_instances_.end()) {
|
||||||
|
return it->second->client.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto instance = MakeUnique<CompileOnlyInstance>();
|
||||||
|
TF_ASSIGN_OR_RETURN(instance->service,
|
||||||
|
CompileOnlyService::NewService(platform));
|
||||||
|
instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
|
||||||
|
CompileOnlyClient* cl = instance->client.get();
|
||||||
|
|
||||||
|
client_library.compile_only_instances_.insert(
|
||||||
|
std::make_pair(platform->id(), std::move(instance)));
|
||||||
|
return cl;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -26,7 +26,9 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/compile_only_service.h"
|
||||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||||
#include "tensorflow/compiler/xla/service/local_service.h"
|
#include "tensorflow/compiler/xla/service/local_service.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
@ -76,6 +78,13 @@ class ClientLibrary {
|
|||||||
// access user computations from client.
|
// access user computations from client.
|
||||||
static LocalService* GetXlaService(perftools::gputools::Platform* platform);
|
static LocalService* GetXlaService(perftools::gputools::Platform* platform);
|
||||||
|
|
||||||
|
// Singleton constructor-or-accessor for compile-only clients. Arguments:
|
||||||
|
//
|
||||||
|
// platform : The platform the underlying XLA service should target. If
|
||||||
|
// null then default platform is used.
|
||||||
|
static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
|
||||||
|
perftools::gputools::Platform* platform = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Returns the singleton instance of ClientLibrary.
|
// Returns the singleton instance of ClientLibrary.
|
||||||
static ClientLibrary& Singleton();
|
static ClientLibrary& Singleton();
|
||||||
@ -90,10 +99,21 @@ class ClientLibrary {
|
|||||||
std::unique_ptr<LocalClient> client;
|
std::unique_ptr<LocalClient> client;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct CompileOnlyInstance {
|
||||||
|
// Service that is wrapped by the singleton client object.
|
||||||
|
std::unique_ptr<CompileOnlyService> service;
|
||||||
|
// Singleton client object.
|
||||||
|
std::unique_ptr<CompileOnlyClient> client;
|
||||||
|
};
|
||||||
|
|
||||||
tensorflow::mutex service_mutex_; // Guards the singleton creation state.
|
tensorflow::mutex service_mutex_; // Guards the singleton creation state.
|
||||||
std::unordered_map<perftools::gputools::Platform::Id,
|
std::unordered_map<perftools::gputools::Platform::Id,
|
||||||
std::unique_ptr<LocalInstance>>
|
std::unique_ptr<LocalInstance>>
|
||||||
instances_ GUARDED_BY(service_mutex_);
|
local_instances_ GUARDED_BY(service_mutex_);
|
||||||
|
|
||||||
|
std::unordered_map<perftools::gputools::Platform::Id,
|
||||||
|
std::unique_ptr<CompileOnlyInstance>>
|
||||||
|
compile_only_instances_ GUARDED_BY(service_mutex_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
|
TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
|
||||||
};
|
};
|
||||||
|
59
tensorflow/compiler/xla/client/compile_only_client.cc
Normal file
59
tensorflow/compiler/xla/client/compile_only_client.cc
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||||
|
|
||||||
|
#include "external/llvm/include/llvm/ADT/Triple.h"
|
||||||
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
|
||||||
|
namespace se = ::perftools::gputools;
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
|
CompileOnlyClient::CompileAheadOfTime(
|
||||||
|
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||||
|
const AotCompilationOptions& options) {
|
||||||
|
std::vector<CompileOnlyService::AotComputationInstance> service_instances;
|
||||||
|
service_instances.reserve(computations.size());
|
||||||
|
for (const AotComputationInstance& instance : computations) {
|
||||||
|
service_instances.push_back({});
|
||||||
|
CompileOnlyService::AotComputationInstance& service_instance =
|
||||||
|
service_instances.back();
|
||||||
|
TF_RET_CHECK(instance.computation != nullptr);
|
||||||
|
service_instance.computation = instance.computation->handle();
|
||||||
|
service_instance.argument_layouts = instance.argument_layouts;
|
||||||
|
service_instance.result_layout = instance.result_layout;
|
||||||
|
}
|
||||||
|
return compiler_service_->CompileAheadOfTime(service_instances, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 CompileOnlyClient::PointerSizeForTriple(
|
||||||
|
tensorflow::StringPiece target_triple) {
|
||||||
|
llvm::Triple triple(
|
||||||
|
llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
|
||||||
|
if (triple.isArch64Bit()) {
|
||||||
|
return 8;
|
||||||
|
} else if (triple.isArch32Bit()) {
|
||||||
|
return 4;
|
||||||
|
} else {
|
||||||
|
CHECK(triple.isArch16Bit());
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
66
tensorflow/compiler/xla/client/compile_only_client.h
Normal file
66
tensorflow/compiler/xla/client/compile_only_client.h
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/client/client.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/compile_only_service.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// An XLA Client specialization for doing ahead-of-time compilation. This does
|
||||||
|
// not require (or attempt to instantiate) an execution-capable backend for the
|
||||||
|
// relevant platform.
|
||||||
|
class CompileOnlyClient : public Client {
|
||||||
|
public:
|
||||||
|
explicit CompileOnlyClient(CompileOnlyService* service)
|
||||||
|
: Client(service), compiler_service_(service) {}
|
||||||
|
|
||||||
|
CompileOnlyClient(const CompileOnlyClient&) = delete;
|
||||||
|
void operator=(const CompileOnlyClient&) = delete;
|
||||||
|
|
||||||
|
// A description of a computation to compile using CompileAheadOfTime.
|
||||||
|
struct AotComputationInstance {
|
||||||
|
const Computation* computation;
|
||||||
|
// Inform the compiler of the expected layout for arguments.
|
||||||
|
std::vector<const Shape*> argument_layouts;
|
||||||
|
// Specifies the expected result layout.
|
||||||
|
const Shape* result_layout;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compiles a list of computations for ahead-of-time execution. This is
|
||||||
|
// intended for use in static compilation. The |options| parameter describes
|
||||||
|
// the target for which the compiler should emit code.
|
||||||
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
|
CompileAheadOfTime(
|
||||||
|
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||||
|
const AotCompilationOptions& options);
|
||||||
|
|
||||||
|
// Returns the size of a pointer in bytes for a given triple.
|
||||||
|
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
|
||||||
|
|
||||||
|
private:
|
||||||
|
CompileOnlyService* compiler_service_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
|
@ -261,38 +261,6 @@ tensorflow::Status LocalClient::ResolveArguments(
|
|||||||
argument_ptrs);
|
argument_ptrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
|
||||||
LocalClient::CompileAheadOfTime(
|
|
||||||
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
|
|
||||||
computations,
|
|
||||||
const AotCompilationOptions& options) {
|
|
||||||
std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
|
|
||||||
service_instances.reserve(computations.size());
|
|
||||||
for (const AheadOfTimeComputationInstance& instance : computations) {
|
|
||||||
service_instances.push_back({});
|
|
||||||
LocalService::AheadOfTimeComputationInstance& service_instance =
|
|
||||||
service_instances.back();
|
|
||||||
TF_RET_CHECK(instance.computation != nullptr);
|
|
||||||
service_instance.computation = instance.computation->handle();
|
|
||||||
service_instance.argument_layouts = instance.argument_layouts;
|
|
||||||
service_instance.result_layout = instance.result_layout;
|
|
||||||
}
|
|
||||||
return local_service_->CompileAheadOfTime(service_instances, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
|
|
||||||
llvm::Triple triple(
|
|
||||||
llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
|
|
||||||
if (triple.isArch64Bit()) {
|
|
||||||
return 8;
|
|
||||||
} else if (triple.isArch32Bit()) {
|
|
||||||
return 4;
|
|
||||||
} else {
|
|
||||||
CHECK(triple.isArch16Bit());
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
se::Platform* LocalClient::platform() const {
|
se::Platform* LocalClient::platform() const {
|
||||||
return local_service_->backend().platform();
|
return local_service_->backend().platform();
|
||||||
}
|
}
|
||||||
|
@ -148,7 +148,7 @@ class LocalExecutable {
|
|||||||
const ExecutableBuildOptions& build_options_;
|
const ExecutableBuildOptions& build_options_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// An XLA service client object for use when the client and service run in
|
// An XLA Client specialization for use when the client and service run in
|
||||||
// the same process.
|
// the same process.
|
||||||
class LocalClient : public Client {
|
class LocalClient : public Client {
|
||||||
public:
|
public:
|
||||||
@ -182,30 +182,6 @@ class LocalClient : public Client {
|
|||||||
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
||||||
const ExecutableBuildOptions& options);
|
const ExecutableBuildOptions& options);
|
||||||
|
|
||||||
// A description of a computation to compile using CompileAheadOfTime.
|
|
||||||
struct AheadOfTimeComputationInstance {
|
|
||||||
const Computation* computation;
|
|
||||||
// Inform the compiler of the expected layout for arguments.
|
|
||||||
std::vector<const Shape*> argument_layouts;
|
|
||||||
// Specifies the expected result layout.
|
|
||||||
const Shape* result_layout;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Compiles a list of computations for ahead-of-time execution. This is
|
|
||||||
// intended for use in static compilation. The |options| parameter describes
|
|
||||||
// the target for which the compiler should emit code.
|
|
||||||
//
|
|
||||||
// TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
|
|
||||||
// own library.
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
|
||||||
CompileAheadOfTime(
|
|
||||||
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
|
|
||||||
computations,
|
|
||||||
const AotCompilationOptions& options);
|
|
||||||
|
|
||||||
// Returns the size of a pointer in bytes for a given triple.
|
|
||||||
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
|
|
||||||
|
|
||||||
// Returns the platform that the underlying service targets.
|
// Returns the platform that the underlying service targets.
|
||||||
perftools::gputools::Platform* platform() const;
|
perftools::gputools::Platform* platform() const;
|
||||||
|
|
||||||
|
@ -408,6 +408,27 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "compile_only_service",
|
||||||
|
srcs = ["compile_only_service.cc"],
|
||||||
|
hdrs = ["compile_only_service.h"],
|
||||||
|
deps = [
|
||||||
|
":backend",
|
||||||
|
":compiler",
|
||||||
|
":computation_layout",
|
||||||
|
":computation_tracker",
|
||||||
|
":platform_util",
|
||||||
|
":service",
|
||||||
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cpu_plugin",
|
name = "cpu_plugin",
|
||||||
deps = [
|
deps = [
|
||||||
|
131
tensorflow/compiler/xla/service/compile_only_service.cc
Normal file
131
tensorflow/compiler/xla/service/compile_only_service.cc
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/compile_only_service.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/backend.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/computation_tracker.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
|
|
||||||
|
namespace se = ::perftools::gputools;
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
|
||||||
|
CompileOnlyService::NewService(perftools::gputools::Platform* platform) {
|
||||||
|
ServiceOptions default_options;
|
||||||
|
default_options.set_platform(platform);
|
||||||
|
return NewService(default_options);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* static */ StatusOr<std::unique_ptr<CompileOnlyService>>
|
||||||
|
CompileOnlyService::NewService(const ServiceOptions& options) {
|
||||||
|
perftools::gputools::Platform* platform = options.platform();
|
||||||
|
if (platform == nullptr) {
|
||||||
|
TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||||
|
CreateComputeConstantBackend());
|
||||||
|
std::unique_ptr<CompileOnlyService> service(
|
||||||
|
new CompileOnlyService(compiler, std::move(compute_constant_backend)));
|
||||||
|
return std::move(service);
|
||||||
|
}
|
||||||
|
|
||||||
|
CompileOnlyService::CompileOnlyService(
|
||||||
|
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend)
|
||||||
|
: Service(/*backend=*/nullptr, std::move(compute_constant_backend)),
|
||||||
|
compiler_(compiler) {
|
||||||
|
runs_in_client_process_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
|
CompileOnlyService::CompileAheadOfTime(
|
||||||
|
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||||
|
const AotCompilationOptions& options) {
|
||||||
|
std::vector<std::unique_ptr<HloModule>> hlo_modules;
|
||||||
|
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
|
||||||
|
for (const AotComputationInstance& instance : computations) {
|
||||||
|
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
|
||||||
|
computation_tracker_.Resolve(instance.computation));
|
||||||
|
VersionedComputationHandle versioned_handle =
|
||||||
|
user_computation->GetVersionedHandle();
|
||||||
|
|
||||||
|
// Dump computation proto state if flag is set.
|
||||||
|
legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
|
||||||
|
const string& directory_path = flags->xla_dump_computations_to;
|
||||||
|
if (!directory_path.empty()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::unique_ptr<SessionModule> session_module,
|
||||||
|
computation_tracker_.SnapshotComputation(versioned_handle.handle));
|
||||||
|
string filename = tensorflow::strings::StrCat(
|
||||||
|
"computation_", versioned_handle.handle.handle(), "__",
|
||||||
|
session_module->entry().name(), "__version_",
|
||||||
|
versioned_handle.version);
|
||||||
|
TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
|
||||||
|
*session_module));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
|
||||||
|
computation_tracker_.BuildHloModule(
|
||||||
|
versioned_handle,
|
||||||
|
/*include_unreachable_instructions=*/true));
|
||||||
|
hlo_modules.push_back(std::move(hlo_module));
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::shared_ptr<const ProgramShape> program_shape,
|
||||||
|
user_computation->ComputeProgramShape(versioned_handle.version));
|
||||||
|
|
||||||
|
module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
|
||||||
|
HloModuleConfig* module_config = module_configs.back().get();
|
||||||
|
auto* computation_layout =
|
||||||
|
module_config->mutable_entry_computation_layout();
|
||||||
|
if (flags->xla_hlo_profile) {
|
||||||
|
module_config->enable_hlo_profiling(true);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < instance.argument_layouts.size(); ++i) {
|
||||||
|
const Shape& argument_layout = *instance.argument_layouts[i];
|
||||||
|
if (ShapeUtil::IsTuple(argument_layout)) {
|
||||||
|
return Unimplemented("tuple arguments not supported yet");
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
|
||||||
|
argument_layout));
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
|
||||||
|
*instance.result_layout));
|
||||||
|
}
|
||||||
|
|
||||||
|
return compiler_->CompileAheadOfTime(std::move(hlo_modules),
|
||||||
|
std::move(module_configs),
|
||||||
|
MakeHloDumper(), options);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
128
tensorflow/compiler/xla/service/compile_only_service.h
Normal file
128
tensorflow/compiler/xla/service/compile_only_service.h
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/backend.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/service.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// An XLA Service specialization for ahead-of-time compilation. This only
|
||||||
|
// instantiates a Compiler object for the relevant platform; it does not
|
||||||
|
// instantiate or require an execution backend.
|
||||||
|
class CompileOnlyService : public Service {
|
||||||
|
public:
|
||||||
|
// Factory for creating a CompileOnlyService. The parameter platform is the
|
||||||
|
// platform that the service should target. If platform is null then the
|
||||||
|
// default platform is used.
|
||||||
|
static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
|
||||||
|
perftools::gputools::Platform* platform);
|
||||||
|
static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
|
||||||
|
const ServiceOptions& options);
|
||||||
|
|
||||||
|
// A description of a computation to compile using CompileAheadOfTime.
|
||||||
|
struct AotComputationInstance {
|
||||||
|
ComputationHandle computation;
|
||||||
|
std::vector<const Shape*> argument_layouts;
|
||||||
|
const Shape* result_layout = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compiles a list of computations for ahead-of-time execution. This is
|
||||||
|
// intended for use in static compilation. See
|
||||||
|
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
|
||||||
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
|
CompileAheadOfTime(
|
||||||
|
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||||
|
const AotCompilationOptions& Options);
|
||||||
|
|
||||||
|
// Override Service methods that require an execute backend.
|
||||||
|
tensorflow::Status Execute(const ExecuteRequest* arg,
|
||||||
|
ExecuteResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support execution.");
|
||||||
|
}
|
||||||
|
tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
|
||||||
|
ExecuteParallelResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support execution.");
|
||||||
|
}
|
||||||
|
tensorflow::Status GetDeviceHandles(
|
||||||
|
const GetDeviceHandlesRequest* arg,
|
||||||
|
GetDeviceHandlesResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support devices.");
|
||||||
|
}
|
||||||
|
tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
|
||||||
|
ExecuteAsyncResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support execution.");
|
||||||
|
}
|
||||||
|
tensorflow::Status WaitForExecution(
|
||||||
|
const WaitForExecutionRequest* arg,
|
||||||
|
WaitForExecutionResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support execution.");
|
||||||
|
}
|
||||||
|
tensorflow::Status TransferToClient(
|
||||||
|
const TransferToClientRequest* arg,
|
||||||
|
TransferToClientResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support data transfers.");
|
||||||
|
}
|
||||||
|
tensorflow::Status TransferToClientInProcess(
|
||||||
|
const TransferToClientInProcessRequest* arg,
|
||||||
|
TransferToClientInProcessResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support data transfers.");
|
||||||
|
}
|
||||||
|
tensorflow::Status TransferToServer(
|
||||||
|
const TransferToServerRequest* arg,
|
||||||
|
TransferToServerResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support data transfers.");
|
||||||
|
}
|
||||||
|
tensorflow::Status TransferToInfeed(
|
||||||
|
const TransferToInfeedRequest* arg,
|
||||||
|
TransferToInfeedResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support data transfers.");
|
||||||
|
}
|
||||||
|
tensorflow::Status TransferFromOutfeed(
|
||||||
|
const TransferFromOutfeedRequest* arg,
|
||||||
|
TransferFromOutfeedResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support data transfers.");
|
||||||
|
}
|
||||||
|
tensorflow::Status TransferToServerInProcess(
|
||||||
|
const TransferToServerInProcessRequest* arg,
|
||||||
|
TransferToServerInProcessResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support data transfers.");
|
||||||
|
}
|
||||||
|
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
||||||
|
ResetDeviceResponse* result) override {
|
||||||
|
return Unimplemented("CompileOnlyService does not support devices.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
explicit CompileOnlyService(
|
||||||
|
Compiler* compiler, std::unique_ptr<Backend> compute_constant_backend);
|
||||||
|
CompileOnlyService(const CompileOnlyService&) = delete;
|
||||||
|
void operator=(const CompileOnlyService&) = delete;
|
||||||
|
|
||||||
|
// The compiler for the target platform. This is included in place of
|
||||||
|
// the Service::execute_backend_'s compiler, since execute_backend_ is a
|
||||||
|
// nullptr in CompileOnlyService.
|
||||||
|
Compiler* compiler_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_
|
@ -128,70 +128,6 @@ StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
|
|||||||
allocation_size));
|
allocation_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compiles every entry of `computations` ahead of time using the execute
// backend's compiler.
//
// For each instance this:
//   1. resolves the computation handle through the computation tracker,
//   2. dumps a snapshot of the computation when the
//      xla_dump_computations_to flag names a directory,
//   3. builds an HloModule (unreachable instructions included), and
//   4. builds an HloModuleConfig whose entry computation layout is copied
//      from the instance's argument and result layouts.
//
// Returns the compiler's results, or the first error hit along the way.
// Tuple-shaped arguments are reported as Unimplemented. Null layout pointers
// and surplus argument layouts are reported as InvalidArgument instead of
// being dereferenced or indexed out of range.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
LocalService::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
        computations,
    const AotCompilationOptions& options) {
  std::vector<std::unique_ptr<HloModule>> hlo_modules;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  for (const AheadOfTimeComputationInstance& instance : computations) {
    TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
                        computation_tracker_.Resolve(instance.computation));
    VersionedComputationHandle versioned_handle =
        user_computation->GetVersionedHandle();

    // Dump computation proto state if flag is set.
    legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
    const string& directory_path = flags->xla_dump_computations_to;
    if (!directory_path.empty()) {
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<SessionModule> session_module,
          computation_tracker_.SnapshotComputation(versioned_handle.handle));
      string filename = tensorflow::strings::StrCat(
          "computation_", versioned_handle.handle.handle(), "__",
          session_module->entry().name(), "__version_",
          versioned_handle.version);
      TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
                                                     *session_module));
    }

    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
                        computation_tracker_.BuildHloModule(
                            versioned_handle,
                            /*include_unreachable_instructions=*/true));
    hlo_modules.push_back(std::move(hlo_module));

    TF_ASSIGN_OR_RETURN(
        std::shared_ptr<const ProgramShape> program_shape,
        user_computation->ComputeProgramShape(versioned_handle.version));

    // The entry computation layout is sized from the program shape, so more
    // layouts than parameters would index past its end below.
    if (instance.argument_layouts.size() >
        static_cast<size_t>(program_shape->parameters_size())) {
      return InvalidArgument(
          "too many argument layouts for computation: expected at most %d, "
          "got %zu",
          program_shape->parameters_size(), instance.argument_layouts.size());
    }

    module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
    HloModuleConfig* module_config = module_configs.back().get();
    auto* computation_layout =
        module_config->mutable_entry_computation_layout();
    if (flags->xla_hlo_profile) {
      module_config->enable_hlo_profiling(true);
    }
    // size_t index: avoids comparing a signed int against the unsigned size().
    for (size_t i = 0; i < instance.argument_layouts.size(); ++i) {
      if (instance.argument_layouts[i] == nullptr) {
        return InvalidArgument("argument layout %zu must not be null", i);
      }
      const Shape& argument_layout = *instance.argument_layouts[i];
      if (ShapeUtil::IsTuple(argument_layout)) {
        return Unimplemented("tuple arguments not supported yet");
      }
      TF_RETURN_IF_ERROR(
          computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
              argument_layout));
    }
    // result_layout defaults to nullptr in the instance struct; reject it
    // here rather than dereferencing a null pointer.
    if (instance.result_layout == nullptr) {
      return InvalidArgument("result layout must not be null");
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            *instance.result_layout));
  }

  return execute_backend_->compiler()->CompileAheadOfTime(
      std::move(hlo_modules), std::move(module_configs), MakeHloDumper(),
      options);
}
|
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
||||||
const ComputationHandle& computation,
|
const ComputationHandle& computation,
|
||||||
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
|
||||||
|
@ -59,22 +59,6 @@ class LocalService : public Service {
|
|||||||
const Shape& shape, int device_ordinal,
|
const Shape& shape, int device_ordinal,
|
||||||
bool allocate_space_for_deep_copy);
|
bool allocate_space_for_deep_copy);
|
||||||
|
|
||||||
// A description of a computation to compile using CompileAheadOfTime.
|
|
||||||
struct AheadOfTimeComputationInstance {
|
|
||||||
// Handle of the computation to compile; CompileAheadOfTime resolves it
// through the service's computation tracker.
ComputationHandle computation;
|
|
||||||
// Layouts for the computation's parameters, in parameter order. The shapes
// are referenced, not owned. Tuple shapes are rejected by CompileAheadOfTime
// as unimplemented.
std::vector<const Shape*> argument_layouts;
|
|
||||||
// Layout for the result; referenced, not owned. Defaults to nullptr, but
// CompileAheadOfTime dereferences it unconditionally, so it must be set.
const Shape* result_layout = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Compiles a list of computations for ahead-of-time execution. This is
|
|
||||||
// intended for use in static compilation. See
|
|
||||||
// |LocalClient::CompileAheadOfTime| for additional details.
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
|
||||||
CompileAheadOfTime(
|
|
||||||
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
|
|
||||||
computations,
|
|
||||||
const AotCompilationOptions& Options);
|
|
||||||
|
|
||||||
// Builds an Executable with the given argument layouts and options. If
|
// Builds an Executable with the given argument layouts and options. If
|
||||||
// result_layout is non-null, then the executable is compiled to produce a
|
// result_layout is non-null, then the executable is compiled to produce a
|
||||||
// result of the given layout.
|
// result of the given layout.
|
||||||
|
@ -180,20 +180,24 @@ Service::Service(std::unique_ptr<Backend> execute_backend,
|
|||||||
std::unique_ptr<Backend> compute_constant_backend)
|
std::unique_ptr<Backend> compute_constant_backend)
|
||||||
: execute_backend_(std::move(execute_backend)),
|
: execute_backend_(std::move(execute_backend)),
|
||||||
compute_constant_backend_(std::move(compute_constant_backend)) {
|
compute_constant_backend_(std::move(compute_constant_backend)) {
|
||||||
LOG(INFO) << Printf(
|
if (execute_backend_) {
|
||||||
"XLA service %p executing computations on platform %s. Devices:", this,
|
LOG(INFO) << Printf(
|
||||||
execute_backend_->platform()->Name().c_str());
|
"XLA service %p executing computations on platform %s. Devices:", this,
|
||||||
for (int i = 0; i < execute_backend_->device_count(); ++i) {
|
execute_backend_->platform()->Name().c_str());
|
||||||
if (execute_backend_->device_ordinal_supported(i)) {
|
for (int i = 0; i < execute_backend_->device_count(); ++i) {
|
||||||
se::StreamExecutor* executor =
|
if (execute_backend_->device_ordinal_supported(i)) {
|
||||||
execute_backend_->stream_executor(i).ValueOrDie();
|
se::StreamExecutor* executor =
|
||||||
const auto& description = executor->GetDeviceDescription();
|
execute_backend_->stream_executor(i).ValueOrDie();
|
||||||
LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
|
const auto& description = executor->GetDeviceDescription();
|
||||||
description.name().c_str(),
|
LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i,
|
||||||
description.platform_version().c_str());
|
description.name().c_str(),
|
||||||
} else {
|
description.platform_version().c_str());
|
||||||
LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
|
} else {
|
||||||
|
LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
VLOG(1) << "XLA compile-only service constructed";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) {
|
|||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||||
|
|
||||||
auto client = xla::ClientLibrary::LocalClientOrDie();
|
auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();
|
||||||
|
|
||||||
xla::ComputationBuilder builder(client, "aot_test_helper");
|
xla::ComputationBuilder builder(client, "aot_test_helper");
|
||||||
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
|
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
|
||||||
@ -74,7 +74,7 @@ int main(int argc, char** argv) {
|
|||||||
llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
|
llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
|
||||||
|
|
||||||
xla::Computation computation = builder.Build().ConsumeValueOrDie();
|
xla::Computation computation = builder.Build().ConsumeValueOrDie();
|
||||||
xla::LocalClient::AheadOfTimeComputationInstance instance{
|
xla::CompileOnlyClient::AotComputationInstance instance{
|
||||||
&computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
|
&computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
|
||||||
|
|
||||||
xla::cpu::CpuAotCompilationOptions options(
|
xla::cpu::CpuAotCompilationOptions options(
|
||||||
|
1
tensorflow/opensource_only/eigen.threadpool
Normal file
1
tensorflow/opensource_only/eigen.threadpool
Normal file
@ -0,0 +1 @@
|
|||||||
|
#include "unsupported/Eigen/CXX11/ThreadPool"
|
Loading…
Reference in New Issue
Block a user