[XLA] Split ExecuteGraph into Compile and Execute.

- The ExecuteGraph method is removed from the service interface.
- Client::Execute calls Compile + Execute instead of ExecuteGraph. The existing
  users of this method are not affected.
- The Compile compiles the graph into an executable. Since the argument shapes
  will affect how the graph is compiled, the Client::Compile has an argument
  `argument_shapes`, which must be the same as the shapes of the arguments being
  used in the Execute method.
- The service caches the executables.

PiperOrigin-RevId: 220569355
This commit is contained in:
Cong Liu 2018-11-07 19:16:10 -08:00 committed by TensorFlower Gardener
parent bb878129e1
commit 029d65ecbe
14 changed files with 400 additions and 92 deletions

View File

@ -210,11 +210,10 @@ StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
return XlaComputation(module.hlo().hlo_module());
}
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
ExecuteGraphRequest request;
StatusOr<ExecutionHandle> Client::Compile(
const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
const ExecutionOptions* execution_options) {
CompileRequest request;
*request.mutable_computation() = computation.proto();
if (execution_options == nullptr) {
@ -222,6 +221,34 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
} else {
*request.mutable_execution_options() = *execution_options;
}
if (request.execution_options().device_handles_size() > 1) {
return InvalidArgument(
"Compiling with multiple device handles is not supported. Use "
"'Execute' instead.");
}
// The argument shapes affect how the computation is compiled.
for (const auto& arg_shape : argument_shapes) {
*request.add_input_shape_with_layout() = arg_shape;
}
CompileResponse response;
VLOG(1) << "making compile request: " << request.ShortDebugString();
Status s = stub_->Compile(&request, &response);
VLOG(1) << "done with request";
if (!s.ok()) {
return s;
}
TF_RET_CHECK(response.has_handle());
return response.handle();
}
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
ExecutionProfile* execution_profile) {
ExecuteRequest request;
*request.mutable_handle() = handle;
for (GlobalData* argument : arguments) {
CHECK(argument != nullptr) << "Argument pointers must not be null.";
*request.add_arguments() = argument->handle();
@ -229,7 +256,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
ExecuteResponse response;
VLOG(1) << "making execute request: " << request.ShortDebugString();
Status s = stub_->ExecuteGraph(&request, &response);
Status s = stub_->Execute(&request, &response);
VLOG(1) << "done with request";
if (!s.ok()) {
@ -238,15 +265,62 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
if (execution_profile != nullptr) {
*execution_profile = response.profile();
}
return absl::make_unique<GlobalData>(stub_, response.output());
}
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
if (execution_options != nullptr &&
execution_options->device_handles_size() > 1) {
std::vector<XlaComputationInstance> computation_instances = {
XlaComputationInstance{
computation,
std::vector<GlobalData*>(arguments.begin(), arguments.end()),
*execution_options, execution_profile}};
TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
// The result selection is a bit hacky, but better than assuming it is
// device 0.
//
// TODO(b/118493728): Allow Execute to return one result per computation.
for (int64 i = 0; i < results.size(); i++) {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
if (!ShapeUtil::IsEmptyTuple(shape)) {
VLOG(3) << "Fetching result from device " << i << ": "
<< ShapeUtil::HumanString(shape);
return std::move(results[i]);
}
}
TF_RET_CHECK(!results.empty());
VLOG(1) << "Defaulting to device 0 result";
return std::move(results[0]);
}
// The argument shapes affect how the computation is compiled.
std::vector<Shape> arg_shapes(arguments.size());
for (int i = 0; i < arguments.size(); i++) {
TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i]));
}
TF_ASSIGN_OR_RETURN(auto handle,
Compile(computation, arg_shapes, execution_options));
TF_ASSIGN_OR_RETURN(auto result,
Execute(handle, arguments, execution_profile));
if (execution_profile != nullptr) {
if (VLOG_IS_ON(1)) {
TF_ASSIGN_OR_RETURN(
auto execution_stats,
ExecutionStatsAsString(computation, response.profile()));
ExecutionStatsAsString(computation, *execution_profile));
VLOG(1) << execution_stats;
}
}
return absl::make_unique<GlobalData>(stub_, response.output());
return std::move(result);
}
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
@ -274,10 +348,11 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
}
std::vector<std::unique_ptr<GlobalData>> outputs;
for (size_t i = 0; i < computations.size(); ++i) {
for (size_t i = 0; i < response.responses_size(); ++i) {
outputs.push_back(
absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
if (computations[i].execution_profile != nullptr) {
if (i < computations.size() &&
computations[i].execution_profile != nullptr) {
*computations[i].execution_profile = response.responses(i).profile();
}
}

View File

@ -40,6 +40,31 @@ class Client {
explicit Client(ServiceInterface* stub);
virtual ~Client();
// Compiles the computation with the given argument shapes and returns the
// handle to the compiled executable. The compiled executable is cached on the
// service, and the returned handle can be used for execution without
// recompiling.
// * The shape and layout of the arguments being executed with will affect how
// the computation is compiled. If argument_shapes is empty, the parameters'
// shape and layout will be used in the compilation.
// * If execution_options is not nullptr, these options are passed to the
// service to affect how it compiles our computation. (The pointer does not
// need to live beyond this call.)
// * execution_options.device_handles must not contain more than one handle;
// Compile rejects such requests. If you need multiple device handles, call
// 'Execute' instead.
StatusOr<ExecutionHandle> Compile(
const XlaComputation& computation,
absl::Span<const Shape> argument_shapes,
const ExecutionOptions* execution_options = nullptr);
// Executes the compiled executable for the given handle with the given
// arguments and returns the global data that was produced from the execution.
// * If execution_profile is not nullptr then the pointed-to ExecutionProfile
// will be filled with profile data from the execution.
StatusOr<std::unique_ptr<GlobalData>> Execute(
const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
ExecutionProfile* execution_profile = nullptr);
// Executes the computation with the given arguments and returns the global
// data that was produced from the execution.
// * If execution_options is not nullptr, these options are passed to the

View File

@ -47,11 +47,18 @@ namespace xla {
});
}
::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/,
const CompileRequest* arg,
CompileResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->ExecuteGraph(arg, result); });
[this, arg, result]() { return service_->Compile(arg, result); });
}
// Handles the Execute RPC by forwarding the request and response objects to
// the wrapped service's Execute method via DelegateRPC. The server context is
// not used.
::grpc::Status GRPCService::Execute(::grpc::ServerContext* /*context*/,
                                    const ExecuteRequest* arg,
                                    ExecuteResponse* result) {
  auto run_execute = [this, arg, result]() {
    return service_->Execute(arg, result);
  };
  return DelegateRPC(run_execute);
}
::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context,

View File

@ -39,9 +39,13 @@ class GRPCService : public grpc::XlaService::Service {
const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
::grpc::Status ExecuteGraph(::grpc::ServerContext* context,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
::grpc::Status Compile(::grpc::ServerContext* context,
const CompileRequest* arg,
CompileResponse* result) override;
::grpc::Status Execute(::grpc::ServerContext* context,
const ExecuteRequest* arg,
ExecuteResponse* result) override;
::grpc::Status WaitForExecution(::grpc::ServerContext* context,
const WaitForExecutionRequest* arg,

View File

@ -62,10 +62,17 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
});
}
Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) {
Status GRPCStub::Compile(const CompileRequest* request,
CompileResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteGraph(context, *request, response);
return grpc_stub_->Compile(context, *request, response);
});
}
// Issues the Execute RPC: wraps the call on the underlying gRPC stub in a
// callable and hands it to MakeRPC, returning the status MakeRPC produces.
Status GRPCStub::Execute(const ExecuteRequest* request,
                         ExecuteResponse* response) {
  auto issue_call = [this, request, response](::grpc::ClientContext* context) {
    return grpc_stub_->Execute(context, *request, response);
  };
  return MakeRPC(issue_call);
}

View File

@ -43,8 +43,11 @@ class GRPCStub : public ServiceInterface {
Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
Status ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) override;
Status Compile(const CompileRequest* request,
CompileResponse* response) override;
Status Execute(const ExecuteRequest* request,
ExecuteResponse* response) override;
Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) override;

View File

@ -128,11 +128,14 @@ service XlaService {
returns (CreateChannelHandleResponse) {
}
// Invokes the provided computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Compiles the provided computation into an executable. Returns the handle
// of the executable.
rpc Compile(CompileRequest) returns (CompileResponse) {}
// Invokes the provided executable with the provided global data passed as
// immutable arguments. The request contains the handle to the executable.
// Returns global data output and execution timing.
rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) {
}
rpc Execute(ExecuteRequest) returns (ExecuteResponse) {}
// Invokes the provided list of computations in parallel with the provided
// global data for each computation. Returns a list of global data output and

View File

@ -647,6 +647,7 @@ cc_library(
":allocation_tracker",
":backend",
":channel_tracker",
":compilation_cache",
":compiler",
":computation_layout",
":device_memory_allocator",
@ -2337,6 +2338,20 @@ tf_cc_test(
],
)
# Library target for the compilation cache (compilation_cache.h /
# compilation_cache.cc).
cc_library(
name = "compilation_cache",
srcs = ["compilation_cache.cc"],
hdrs = ["compilation_cache.h"],
deps = [
":executable",
":hlo_module_config",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
cc_library(
name = "layout_assignment",
srcs = [

View File

@ -0,0 +1,70 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include <utility>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace {
// Returns a fresh id on every call: the first call yields 0 and each later
// call yields the previous value plus one. The shared counter is only touched
// while holding a function-local mutex.
int64 GetUniqueId() {
  static tensorflow::mutex id_mutex(tensorflow::LINKER_INITIALIZED);
  static int64 next_id = 0;
  tensorflow::mutex_lock guard(id_mutex);
  return next_id++;
}
} // namespace
// Stores `executable` in the cache under a newly generated key and returns an
// ExecutionHandle carrying that key. CHECK-fails if the generated key is
// already present in the cache.
ExecutionHandle CompilationCache::Insert(
    std::unique_ptr<Executable> executable) {
  tensorflow::mutex_lock lock(mutex_);

  const CacheKey new_key = GetUniqueId();
  VLOG(2) << "inserting cache key: " << new_key;

  // Keys come from GetUniqueId, so a collision is treated as a fatal error.
  CHECK_EQ(cache_.count(new_key), 0);
  cache_.emplace(new_key, std::move(executable));

  ExecutionHandle result;
  result.set_handle(new_key);
  return result;
}
// Looks up the executable stored under the key carried by `handle`.
// Returns a shared_ptr to the cached Executable when the key is present, and
// an InvalidArgument status when it is not.
StatusOr<std::shared_ptr<Executable>> CompilationCache::LookUp(
    const ExecutionHandle& handle) const {
  tensorflow::mutex_lock lock(mutex_);

  CacheKey key = handle.handle();
  VLOG(2) << "looking up cache key: " << key;

  const auto entry = cache_.find(key);
  if (entry == cache_.end()) {
    VLOG(2) << "cache key not found: " << key;
    return InvalidArgumentStrCat("can not find executable with handle ", key);
  }

  const std::shared_ptr<Executable>& found = entry->second;
  VLOG(2) << "hit executable: " << found->module().name();
  return found;
}
} // namespace xla

View File

@ -0,0 +1,62 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
#include <map>
#include <memory>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace xla {
// A cache of compiled Executables. Every inserted executable is assigned an
// ExecutionHandle, which can later be passed to LookUp to retrieve it. Access
// to the stored entries is guarded by a mutex.
//
// TODO(b/119042872): Provide mechanism for removing computations from the
// compilation cache.
class CompilationCache {
 public:
  CompilationCache() = default;

  // Takes ownership of `executable`, stores it in the cache, and returns the
  // handle under which it can be retrieved with LookUp.
  ExecutionHandle Insert(std::unique_ptr<Executable> executable);

  // Finds the Executable registered under `handle`. Returns a shared_ptr to
  // the Executable if it exists in the cache, and an error status otherwise.
  StatusOr<std::shared_ptr<Executable>> LookUp(
      const ExecutionHandle& handle) const;

 protected:
  // Key type under which executables are stored.
  using CacheKey = int64;

  // Guards cache_. Declared mutable so the const LookUp can acquire it.
  mutable tensorflow::mutex mutex_;

  absl::flat_hash_map<CacheKey, std::shared_ptr<Executable>> cache_
      GUARDED_BY(mutex_);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_

View File

@ -760,38 +760,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
return Status::OK();
}
Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
ExecuteGraphParallelRequest parallel_arg;
*parallel_arg.add_requests() = *arg;
ExecuteParallelResponse parallel_result;
TF_RETURN_IF_ERROR(ExecuteGraphParallel(&parallel_arg, &parallel_result));
return PickParallelResponse(parallel_result, result);
}
Status Service::PickParallelResponse(
const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) {
// The "result device" selection is a bit hacky, but better than assuming it
// is device 0. We have b/76035356 for restructuring the client API to clean
// up the current asymmetries and support more functionalities.
for (int64 i = 0; i < parallel_result.responses_size(); ++i) {
TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
allocation_tracker_.ResolveForReplica(
parallel_result.responses(i).output(), 0));
const Shape& shape = buffer->on_host_shape();
if (!ShapeUtil::IsEmptyTuple(shape)) {
*result = parallel_result.responses(i);
VLOG(3) << "Fetching result from device " << i << ": "
<< ShapeUtil::HumanString(shape);
return Status::OK();
}
}
TF_RET_CHECK(parallel_result.responses_size() > 0);
*result = parallel_result.responses(0);
VLOG(1) << "Defaulting to device 0 result";
return Status::OK();
}
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@ -836,10 +804,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
return std::move(executable);
}
Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
VLOG(1) << "running execute-graph request";
Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
VLOG(1) << "running compile request";
if (!arg->has_computation()) {
return InvalidArgument("computations may not be empty");
}
@ -847,22 +813,21 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
return InvalidArgument("programe shape may not be empty");
}
// If we received multiple device handles, we must partition the module.
if (arg->execution_options().device_handles_size() > 1) {
return ExecuteOneToN(arg, result);
return InvalidArgument(
"The compile request does not support multiple device handles.");
}
TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
SingleComputationDeviceHandle()));
TF_ASSIGN_OR_RETURN(
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
ResolveAndValidateArguments(arg->arguments(), replicas));
std::vector<const Shape*> argument_shapes;
absl::c_transform(arg->input_shape_with_layout(),
std::back_inserter(argument_shapes),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(arg->computation().host_program_shape(),
replicated_arguments.front(),
arg->execution_options()));
argument_shapes, &arg->execution_options()));
VLOG(3) << "Compile created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
@ -871,6 +836,48 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr));
*result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
VLOG(1) << "successfully completed 'compile' request";
return Status::OK();
}
Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
VLOG(1) << "running execute request";
if (!arg->has_handle()) {
return InvalidArgument("execution handle should not be empty");
}
TF_ASSIGN_OR_RETURN(auto executable,
compilation_cache_.LookUp(arg->handle()));
TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
SingleComputationDeviceHandle()));
TF_ASSIGN_OR_RETURN(
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
ResolveAndValidateArguments(arg->arguments(), replicas));
// Check that the replicated_arguments have the same shape and layout as the
// module config used when creating the executable.
const int64 num_module_args =
executable->module_config().entry_computation_layout().parameter_count();
if (num_module_args != arg->arguments_size()) {
return InvalidArgument(
"The executable expects %lld arguments, but sees %lld.",
num_module_args, arg->arguments_size());
}
for (int64 i = 0; i < num_module_args; i++) {
const Shape& shape_module =
executable->module_config().entry_computation_layout().parameter_shape(
i);
const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
if (!ShapeUtil::Equal(shape_module, shape_arg)) {
return InvalidArgumentStrCat(
"The executable exepcts the ", i, "th argument in shape ",
ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
ShapeUtil::HumanStringWithLayout(shape_arg));
}
}
TF_ASSIGN_OR_RETURN(auto stream,
execute_backend_->BorrowStream(
execute_backend_->default_stream_executor()));
@ -884,9 +891,10 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
TF_ASSIGN_OR_RETURN(
*result->mutable_output(),
ExecuteAndRegisterResult(
executable.get(), replicated_arguments, execute_backend_.get(),
"result of " + arg->computation().name(), result->mutable_profile()));
ExecuteAndRegisterResult(executable.get(), replicated_arguments,
execute_backend_.get(),
"result of " + executable->module().name(),
result->mutable_profile()));
if (executable->dumping_snapshot()) {
TF_ASSIGN_OR_RETURN(
@ -898,7 +906,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
TF_RETURN_IF_ERROR(executable->DumpHloSnapshot());
}
VLOG(1) << "successfully completed 'execute-graph' request";
VLOG(1) << "successfully completed 'execute' request";
return Status::OK();
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
@ -90,11 +91,14 @@ class Service : public ServiceInterface {
Status DeconstructTuple(const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
// Executes a computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
// Compiles a computation into an executable. The request contains the whole
// computation graph. Returns the handle to the executable.
Status Compile(const CompileRequest* arg, CompileResponse* result) override;
// Executes an executable with the provided global data passed as immutable
// arguments. The request contains the handle to the executable. Returns
// global data output and execution timing.
Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
// Executes one or more computations in parallel with the provided global data
// passed as immutable arguments. Returns global data output for each
@ -179,10 +183,6 @@ class Service : public ServiceInterface {
absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options);
// Picks a parallel response and fills the result.
Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
ExecuteResponse* result);
// Prepare the executors for executing parallel.
StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
const ExecutionOptions& execution_options, int64 requests_size,
@ -254,11 +254,6 @@ class Service : public ServiceInterface {
Backend* backend, absl::Span<const DeviceHandle> device_handles,
absl::Span<const string> result_tags, ExecutionProfile* profile);
// Executes a single computation which has more than one target device.
// The N devices are expected to all return an empty tuple, but one, which
// will be the result of this computation.
Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result);
// Convenience function which checks whether the given client_shape
// (presumably passed by the client to set the result layout) is valid for the
// given computation result shape.
@ -281,6 +276,9 @@ class Service : public ServiceInterface {
ServiceOptions options_;
// Cache containing previously built Executables.
CompilationCache compilation_cache_;
// Tracks channels created via the API.
ChannelTracker channel_tracker_;

View File

@ -47,8 +47,11 @@ class ServiceInterface {
virtual Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) = 0;
virtual Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) = 0;
virtual Status Compile(const CompileRequest* arg,
CompileResponse* result) = 0;
virtual Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) = 0;
virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
ExecuteParallelResponse* result) = 0;

View File

@ -322,6 +322,34 @@ message UnregisterRequest {
message UnregisterResponse {
}
// Request message for the Compile RPC: carries the computation to compile
// together with the options and argument shapes that influence compilation.
message CompileRequest {
// The graph to be compiled.
HloModuleProto computation = 1;
// Options that affect how XLA compiles code to service this request.
ExecutionOptions execution_options = 2;
// The layouts of the input arguments. If not set, the default layout will be
// used. Although the real arguments are not needed in compilation, the
// layouts of the arguments can affect the compilation.
repeated Shape input_shape_with_layout = 3;
}
// Response message for the Compile RPC.
message CompileResponse {
// The handle to the executable.
ExecutionHandle handle = 1;
}
// Request message for the Execute RPC: runs an executable identified by its
// handle with the given arguments.
message ExecuteRequest {
// The handle to the executable to run.
ExecutionHandle handle = 1;
// The shape and layout of the arguments must be the same as those of the
// executable's parameters.
repeated GlobalDataHandle arguments = 2;
}
// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
HloModuleProto computation = 1;
repeated GlobalDataHandle arguments = 2;