[XLA] Redesign: support xla::XlaComputation in compile-only client and service.

PiperOrigin-RevId: 193247845
A. Unique TensorFlower 2018-04-17 13:36:46 -07:00 committed by TensorFlower Gardener
parent fabf010116
commit 72df3d60fa
5 changed files with 87 additions and 0 deletions

@@ -130,6 +130,7 @@ cc_library(
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
        "//tensorflow/compiler/xla/service:compile_only_service",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/core:stream_executor_no_cuda",

@@ -39,6 +39,24 @@ CompileOnlyClient::CompileAheadOfTime(
  return compiler_service_->CompileAheadOfTime(service_instances, options);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
    const AotCompilationOptions& options) {
  std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
  service_instances.reserve(computations.size());
  for (const AotXlaComputationInstance& instance : computations) {
    service_instances.emplace_back();
    CompileOnlyService::AotXlaComputationInstance& service_instance =
        service_instances.back();
    TF_RET_CHECK(instance.computation != nullptr);
    service_instance.computation = instance.computation->proto();
    service_instance.argument_layouts = instance.argument_layouts;
    service_instance.result_layout = instance.result_layout;
  }
  return compiler_service_->CompileAheadOfTime(service_instances, options);
}

int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) {
  llvm::Triple llvm_triple(
      llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size())));
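For reference, the new client-side overload above forwards each computation to the service by serializing it with XlaComputation::proto(). A minimal sketch (not part of this change) of building a computation that could feed this path, using the XlaBuilder API from xla/client/xla_client; the helper name BuildAddComputation is illustrative:

// Sketch only: build a small XlaComputation whose proto() the new overload
// hands to CompileOnlyService.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::StatusOr<xla::XlaComputation> BuildAddComputation() {
  xla::XlaBuilder builder("add_f32");
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  auto x = builder.Parameter(0, shape, "x");
  auto y = builder.Parameter(1, shape, "y");
  builder.Add(x, y);       // Root of the computation.
  return builder.Build();  // The result can be referenced by an AotXlaComputationInstance.
}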

@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -54,6 +55,27 @@ class CompileOnlyClient : public Client {
      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
      const AotCompilationOptions& options);

  // A description of an xla computation to compile using CompileAheadOfTime.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  struct AotXlaComputationInstance {
    const XlaComputation* computation;
    // Inform the compiler of the expected layout for arguments.
    std::vector<const Shape*> argument_layouts;
    // Specifies the expected result layout.
    const Shape* result_layout;
  };

  // Compiles a list of xla computations for ahead-of-time execution. This is
  // intended for use in static compilation. The |options| parameter describes
  // the target for which the compiler should emit code.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
      const AotCompilationOptions& options);

  // Returns the size of a pointer in bytes for a given triple.
  static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
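A hedged usage sketch for the declarations above: it assumes the CPU AOT options from the CPU compiler and the default platform returned by ClientLibrary::GetOrCreateCompileOnlyClient(); the helper name CompileForCpu and the triple/entry-point strings are placeholders, not part of this change.

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

std::vector<std::unique_ptr<xla::AotCompilationResult>> CompileForCpu(
    const xla::XlaComputation& computation, const xla::Shape& arg_shape) {
  xla::CompileOnlyClient* client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie();

  // Describe the computation plus its argument and result layouts.
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = {&arg_shape, &arg_shape};
  instance.result_layout = &arg_shape;

  // Target description for ahead-of-time code generation (placeholder values).
  xla::cpu::CpuAotCompilationOptions options(
      /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
      /*entry_point_name=*/"entry",
      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);

  auto results_or = client->CompileAheadOfTime({instance}, options);
  return results_or.ConsumeValueOrDie();
}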

@@ -61,6 +61,33 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
                                       Compiler* compiler)
    : Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
    const AotCompilationOptions& options) {
  std::vector<std::unique_ptr<HloModule>> hlo_modules;
  for (const AotXlaComputationInstance& instance : computations) {
    TF_RET_CHECK(instance.computation.has_program_shape());
    const DebugOptions& debug_options = options.debug_options();
    const auto& program_shape = instance.computation.program_shape();
    ExecutionOptions execution_options;
    *execution_options.mutable_debug_options() = debug_options;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModuleConfig> module_config,
        CreateModuleConfig(program_shape, instance.argument_layouts,
                           &execution_options));
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModule> hlo_module,
        HloModule::CreateFromProto(instance.computation, *module_config));
    TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
    hlo_modules.push_back(std::move(hlo_module));
  }
  return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
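To make the proto round trip explicit: the client serializes an XlaComputation to HloModuleProto, and the service-side overload above rebuilds an HloModule from it. A standalone illustration of just that conversion, assuming HloModuleConfig's ProgramShape constructor; the helper name RebuildModule is hypothetical, and real callers go through CreateModuleConfig as in the code above, which also applies the argument and result layouts.

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/util.h"

xla::StatusOr<std::unique_ptr<xla::HloModule>> RebuildModule(
    const xla::HloModuleProto& proto) {
  if (!proto.has_program_shape()) {
    return xla::InvalidArgument("HloModuleProto is missing a program shape");
  }
  // Default-layout config; CompileAheadOfTime derives the config from the
  // per-instance layouts instead.
  xla::HloModuleConfig config(proto.program_shape());
  return xla::HloModule::CreateFromProto(proto, config);
}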

@@ -53,6 +53,25 @@ class CompileOnlyService : public Service {
      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
      const AotCompilationOptions& Options);

  // A description of an xla computation to compile using CompileAheadOfTime.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  struct AotXlaComputationInstance {
    HloModuleProto computation;
    std::vector<const Shape*> argument_layouts;
    const Shape* result_layout = nullptr;
  };

  // Compiles a list of xla computations for ahead-of-time execution. This is
  // intended for use in static compilation. See
  // |CompileOnlyClient::CompileAheadOfTime| for additional details.
  //
  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(
      const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
      const AotCompilationOptions& options);

  // Override Service methods that require or imply the existence of an
  // execute backend. Note that this does not include TransferToClient, as
  // computing constants produces global data that we may wish to transfer.
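Note the design difference between the two structs: the service-level AotXlaComputationInstance above owns the HloModuleProto by value, while the client-level struct holds a pointer to a live XlaComputation. A hedged sketch of driving the service-level overload directly with the serialized form; the helper name CompileViaService and its parameters are illustrative, not part of this change.

#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"

xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
CompileViaService(xla::CompileOnlyService* service,
                  const xla::XlaComputation& computation,
                  const xla::Shape& arg_shape,
                  const xla::AotCompilationOptions& options) {
  xla::CompileOnlyService::AotXlaComputationInstance instance;
  // Unlike the client-side struct, the service-side struct stores the proto itself.
  instance.computation = computation.proto();
  instance.argument_layouts = {&arg_shape};
  instance.result_layout = &arg_shape;
  return service->CompileAheadOfTime({instance}, options);
}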