[XLA] Redesign: support xla::XlaComputation in compile-only client and service.
PiperOrigin-RevId: 193247845
This commit is contained in:
parent
fabf010116
commit
72df3d60fa
tensorflow/compiler/xla
@ -130,6 +130,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:compile_only_service",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
|
@ -39,6 +39,24 @@ CompileOnlyClient::CompileAheadOfTime(
|
||||
return compiler_service_->CompileAheadOfTime(service_instances, options);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileOnlyClient::CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
|
||||
const AotCompilationOptions& options) {
|
||||
std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
|
||||
service_instances.reserve(computations.size());
|
||||
for (const AotXlaComputationInstance& instance : computations) {
|
||||
service_instances.emplace_back();
|
||||
CompileOnlyService::AotXlaComputationInstance& service_instance =
|
||||
service_instances.back();
|
||||
TF_RET_CHECK(instance.computation != nullptr);
|
||||
service_instance.computation = instance.computation->proto();
|
||||
service_instance.argument_layouts = instance.argument_layouts;
|
||||
service_instance.result_layout = instance.result_layout;
|
||||
}
|
||||
return compiler_service_->CompileAheadOfTime(service_instances, options);
|
||||
}
|
||||
|
||||
int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) {
|
||||
llvm::Triple llvm_triple(
|
||||
llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size())));
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/client.h"
|
||||
#include "tensorflow/compiler/xla/client/computation.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/compile_only_service.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -54,6 +55,27 @@ class CompileOnlyClient : public Client {
|
||||
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||
const AotCompilationOptions& options);
|
||||
|
||||
// A description of an xla computation to compile using CompileAheadOfTime.
|
||||
//
|
||||
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
|
||||
struct AotXlaComputationInstance {
|
||||
const XlaComputation* computation;
|
||||
// Inform the compiler of the expected layout for arguments.
|
||||
std::vector<const Shape*> argument_layouts;
|
||||
// Specifies the expected result layout.
|
||||
const Shape* result_layout;
|
||||
};
|
||||
|
||||
// Compiles a list of xla computations for ahead-of-time execution. This is
|
||||
// intended for use in static compilation. The |options| parameter describes
|
||||
// the target for which the compiler should emit code.
|
||||
//
|
||||
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
|
||||
const AotCompilationOptions& options);
|
||||
|
||||
// Returns the size of a pointer in bytes for a given triple.
|
||||
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
|
||||
|
||||
|
@ -61,6 +61,33 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
|
||||
Compiler* compiler)
|
||||
: Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileOnlyService::CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
|
||||
const AotCompilationOptions& options) {
|
||||
std::vector<std::unique_ptr<HloModule>> hlo_modules;
|
||||
for (const AotXlaComputationInstance& instance : computations) {
|
||||
TF_RET_CHECK(instance.computation.has_program_shape());
|
||||
|
||||
const DebugOptions& debug_options = options.debug_options();
|
||||
const auto& program_shape = instance.computation.program_shape();
|
||||
ExecutionOptions execution_options;
|
||||
*execution_options.mutable_debug_options() = debug_options;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
CreateModuleConfig(program_shape, instance.argument_layouts,
|
||||
&execution_options));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
HloModule::CreateFromProto(instance.computation, *module_config));
|
||||
TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
|
||||
hlo_modules.push_back(std::move(hlo_module));
|
||||
}
|
||||
|
||||
return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileOnlyService::CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||
|
@ -53,6 +53,25 @@ class CompileOnlyService : public Service {
|
||||
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
|
||||
const AotCompilationOptions& Options);
|
||||
|
||||
// A description of a xla computation to compile using CompileAheadOfTime.
|
||||
//
|
||||
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
|
||||
struct AotXlaComputationInstance {
|
||||
HloModuleProto computation;
|
||||
std::vector<const Shape*> argument_layouts;
|
||||
const Shape* result_layout = nullptr;
|
||||
};
|
||||
|
||||
// Compiles a list of xla computations for ahead-of-time execution. This is
|
||||
// intended for use in static compilation. See
|
||||
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
|
||||
//
|
||||
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileAheadOfTime(
|
||||
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
|
||||
const AotCompilationOptions& options);
|
||||
|
||||
// Override Service methods that require or imply the existence of an
|
||||
// execute backend. Note that this does not include TransferToClient, as
|
||||
// computing constants produces global data that we may wish to transfer.
|
||||
|
Loading…
Reference in New Issue
Block a user