[XLA] Split ExecuteGraph into Compile and Execute.
- The ExecuteGraph method is removed from the service interface.
- Client::Execute now calls Compile + Execute instead of ExecuteGraph; existing users of this method are not affected.
- Compile compiles the graph into an executable. Since the argument shapes affect how the graph is compiled, Client::Compile takes an `argument_shapes` parameter, which must match the shapes of the arguments later passed to the Execute method.
- The service caches the executables.

PiperOrigin-RevId: 220569355
parent bb878129e1 · commit 029d65ecbe
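For orientation, a minimal sketch of how a client might drive the new two-step API (an illustration rather than code from this change: `client`, `computation`, and a std::vector<xla::GlobalData*> named `arguments` are assumed to already exist, and the TF_ASSIGN_OR_RETURN macros assume an enclosing function that returns a Status):

// Compile once. The returned handle refers to an executable cached by the
// service; the shapes passed here must match the shapes of the arguments
// later passed to Execute.
std::vector<xla::Shape> argument_shapes(arguments.size());
for (int i = 0; i < arguments.size(); i++) {
  TF_ASSIGN_OR_RETURN(argument_shapes[i], client->GetShape(*arguments[i]));
}
TF_ASSIGN_OR_RETURN(xla::ExecutionHandle handle,
                    client->Compile(computation, argument_shapes,
                                    /*execution_options=*/nullptr));

// Execute (possibly many times) against the cached executable.
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::GlobalData> result,
                    client->Execute(handle, arguments,
                                    /*execution_profile=*/nullptr));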
@@ -210,11 +210,10 @@ StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
return XlaComputation(module.hlo().hlo_module());
}

StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
ExecuteGraphRequest request;
StatusOr<ExecutionHandle> Client::Compile(
const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
const ExecutionOptions* execution_options) {
CompileRequest request;
*request.mutable_computation() = computation.proto();

if (execution_options == nullptr) {
@@ -222,6 +221,34 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
} else {
*request.mutable_execution_options() = *execution_options;
}
if (request.execution_options().device_handles_size() > 1) {
return InvalidArgument(
"Compiling with multiple device handles is not supported. Use "
"'Execute' instead.");
}

// The argument shapes affect how the computation is compiled.
for (const auto& arg_shape : argument_shapes) {
*request.add_input_shape_with_layout() = arg_shape;
}

CompileResponse response;
VLOG(1) << "making compile request: " << request.ShortDebugString();
Status s = stub_->Compile(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
return s;
}
TF_RET_CHECK(response.has_handle());
return response.handle();
}

StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
ExecutionProfile* execution_profile) {
ExecuteRequest request;
*request.mutable_handle() = handle;
for (GlobalData* argument : arguments) {
CHECK(argument != nullptr) << "Argument pointers must not be null.";
*request.add_arguments() = argument->handle();
@@ -229,7 +256,7 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(

ExecuteResponse response;
VLOG(1) << "making execute request: " << request.ShortDebugString();
Status s = stub_->ExecuteGraph(&request, &response);
Status s = stub_->Execute(&request, &response);
VLOG(1) << "done with request";

if (!s.ok()) {
@@ -238,15 +265,62 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(

if (execution_profile != nullptr) {
*execution_profile = response.profile();
}

return absl::make_unique<GlobalData>(stub_, response.output());
}

StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
const ExecutionOptions* execution_options,
ExecutionProfile* execution_profile) {
if (execution_options != nullptr &&
execution_options->device_handles_size() > 1) {
std::vector<XlaComputationInstance> computation_instances = {
XlaComputationInstance{
computation,
std::vector<GlobalData*>(arguments.begin(), arguments.end()),
*execution_options, execution_profile}};
TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
// The result selection is a bit hacky, but better than assuming it is
// device 0.
//
// TODO(b/118493728): Allow Execute to return one result per computation.
for (int64 i = 0; i < results.size(); i++) {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
if (!ShapeUtil::IsEmptyTuple(shape)) {
VLOG(3) << "Fetching result from device " << i << ": "
<< ShapeUtil::HumanString(shape);
return std::move(results[i]);
}
}
TF_RET_CHECK(!results.empty());
VLOG(1) << "Defaulting to device 0 result";
return std::move(results[0]);
}

// The argument shapes affect how the computation is compiled.
std::vector<Shape> arg_shapes(arguments.size());
for (int i = 0; i < arguments.size(); i++) {
TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i]));
}

TF_ASSIGN_OR_RETURN(auto handle,
Compile(computation, arg_shapes, execution_options));

TF_ASSIGN_OR_RETURN(auto result,
Execute(handle, arguments, execution_profile));

if (execution_profile != nullptr) {
if (VLOG_IS_ON(1)) {
TF_ASSIGN_OR_RETURN(
auto execution_stats,
ExecutionStatsAsString(computation, response.profile()));
ExecutionStatsAsString(computation, *execution_profile));
VLOG(1) << execution_stats;
}
}

return absl::make_unique<GlobalData>(stub_, response.output());
return std::move(result);
}

StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
@@ -274,10 +348,11 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
}

std::vector<std::unique_ptr<GlobalData>> outputs;
for (size_t i = 0; i < computations.size(); ++i) {
for (size_t i = 0; i < response.responses_size(); ++i) {
outputs.push_back(
absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
if (computations[i].execution_profile != nullptr) {
if (i < computations.size() &&
computations[i].execution_profile != nullptr) {
*computations[i].execution_profile = response.responses(i).profile();
}
}
@@ -40,6 +40,31 @@ class Client {
explicit Client(ServiceInterface* stub);
virtual ~Client();

// Compiles the computation with the given argument shapes and returns the
// handle to the compiled executable. The compiled executable is cached on the
// service, and the returned handle can be used for execution without
// re-compilation.
// * The shape and layout of the arguments being executed with will affect how
// the computation is compiled. If argument_shapes is empty, the parameters'
// shape and layout will be used in the compilation.
// * If execution_options is not nullptr, these options are passed to the
// service to affect how it compiles our computation. (The pointer does not
// need to live beyond this call.)
// * execution_options.device_handles should be empty. If you need
// non-empty device handles, call 'Execute' instead.
StatusOr<ExecutionHandle> Compile(
const XlaComputation& computation,
absl::Span<const Shape> argument_shapes,
const ExecutionOptions* execution_options = nullptr);

// Executes the compiled executable for the given handle with the given
// arguments and returns the global data that was produced from the execution.
// * If execution_profile is not nullptr then the pointed-to ExecutionProfile
// will be filled with profile data from the execution.
StatusOr<std::unique_ptr<GlobalData>> Execute(
const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
ExecutionProfile* execution_profile = nullptr);

// Executes the computation with the given arguments and returns the global
// data that was produced from the execution.
// * If execution_options is not nullptr, these options are passed to the
@@ -47,11 +47,18 @@ namespace xla {
});
}

::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/,
const CompileRequest* arg,
CompileResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->ExecuteGraph(arg, result); });
[this, arg, result]() { return service_->Compile(arg, result); });
}

::grpc::Status GRPCService::Execute(::grpc::ServerContext* /*context*/,
const ExecuteRequest* arg,
ExecuteResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->Execute(arg, result); });
}

::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context,
@@ -39,9 +39,13 @@ class GRPCService : public grpc::XlaService::Service {
const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;

::grpc::Status ExecuteGraph(::grpc::ServerContext* context,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
::grpc::Status Compile(::grpc::ServerContext* context,
const CompileRequest* arg,
CompileResponse* result) override;

::grpc::Status Execute(::grpc::ServerContext* context,
const ExecuteRequest* arg,
ExecuteResponse* result) override;

::grpc::Status WaitForExecution(::grpc::ServerContext* context,
const WaitForExecutionRequest* arg,
@@ -62,10 +62,17 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
});
}

Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) {
Status GRPCStub::Compile(const CompileRequest* request,
CompileResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteGraph(context, *request, response);
return grpc_stub_->Compile(context, *request, response);
});
}

Status GRPCStub::Execute(const ExecuteRequest* request,
ExecuteResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->Execute(context, *request, response);
});
}
@@ -43,8 +43,11 @@ class GRPCStub : public ServiceInterface {
Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;

Status ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) override;
Status Compile(const CompileRequest* request,
CompileResponse* response) override;

Status Execute(const ExecuteRequest* request,
ExecuteResponse* response) override;

Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) override;
@@ -128,11 +128,14 @@ service XlaService {
returns (CreateChannelHandleResponse) {
}

// Invokes the provided computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Compiles the provided computation into an executable. Returns the handle of
// the executable.
rpc Compile(CompileRequest) returns (CompileResponse) {}

// Invokes the provided executable with the provided global data passed as
// immutable arguments. The request contains the handle to the executable.
// Returns global data output and execution timing.
rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) {
}
rpc Execute(ExecuteRequest) returns (ExecuteResponse) {}

// Invokes the provided list of computations in parallel with the provided
// global data for each computation. Returns a list of global data output and
@@ -647,6 +647,7 @@ cc_library(
":allocation_tracker",
":backend",
":channel_tracker",
":compilation_cache",
":compiler",
":computation_layout",
":device_memory_allocator",
@@ -2337,6 +2338,20 @@ tf_cc_test(
],
)

cc_library(
name = "compilation_cache",
srcs = ["compilation_cache.cc"],
hdrs = ["compilation_cache.h"],
deps = [
":executable",
":hlo_module_config",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)

cc_library(
name = "layout_assignment",
srcs = [
tensorflow/compiler/xla/service/compilation_cache.cc (new file, 70 lines)
@@ -0,0 +1,70 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/compilation_cache.h"

#include <utility>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

int64 GetUniqueId() {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static int64 counter = 0;
tensorflow::mutex_lock loc(mu);
const int64 id = counter++;
return id;
}

}  // namespace

ExecutionHandle CompilationCache::Insert(
std::unique_ptr<Executable> executable) {
tensorflow::mutex_lock lock(mutex_);

CacheKey key = GetUniqueId();
VLOG(2) << "inserting cache key: " << key;
CHECK_EQ(cache_.count(key), 0);
cache_.emplace(key, std::move(executable));

ExecutionHandle handle;
handle.set_handle(key);
return handle;
}

StatusOr<std::shared_ptr<Executable>> CompilationCache::LookUp(
const ExecutionHandle& handle) const {
tensorflow::mutex_lock lock(mutex_);

CacheKey key = handle.handle();
VLOG(2) << "looking up cache key: " << key;
if (cache_.count(key) == 0) {
VLOG(2) << "cache key not found: " << key;
return InvalidArgumentStrCat("can not find executable with handle ", key);
} else {
auto& result = cache_.at(key);
VLOG(2) << "hit executable: " << result->module().name();
return result;
}
}

}  // namespace xla
tensorflow/compiler/xla/service/compilation_cache.h (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_

#include <map>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {

// A cache which stores Executables indexed by computation handle and version.
//
// TODO(b/119042872): Provide mechanism for removing computations from the
// compilation cache.
class CompilationCache {
public:
CompilationCache() {}

ExecutionHandle Insert(std::unique_ptr<Executable> executable);

// Lookup the Executable for the specified handle in the cache. Return a
// shared_ptr to the Executable if it exists in the cache.
StatusOr<std::shared_ptr<Executable>> LookUp(
const ExecutionHandle& handle) const;

protected:
mutable tensorflow::mutex mutex_;

using CacheKey = int64;

absl::flat_hash_map<CacheKey, std::shared_ptr<Executable>> cache_
GUARDED_BY(mutex_);

private:
TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
@@ -760,38 +760,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
return Status::OK();
}

Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
ExecuteGraphParallelRequest parallel_arg;
*parallel_arg.add_requests() = *arg;
ExecuteParallelResponse parallel_result;
TF_RETURN_IF_ERROR(ExecuteGraphParallel(&parallel_arg, &parallel_result));
return PickParallelResponse(parallel_result, result);
}

Status Service::PickParallelResponse(
const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) {
// The "result device" selection is a bit hacky, but better than assuming it
// is device 0. We have b/76035356 for restructuring the client API to clean
// up the current asymmetries and support more functionalities.
for (int64 i = 0; i < parallel_result.responses_size(); ++i) {
TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
allocation_tracker_.ResolveForReplica(
parallel_result.responses(i).output(), 0));
const Shape& shape = buffer->on_host_shape();
if (!ShapeUtil::IsEmptyTuple(shape)) {
*result = parallel_result.responses(i);
VLOG(3) << "Fetching result from device " << i << ": "
<< ShapeUtil::HumanString(shape);
return Status::OK();
}
}
TF_RET_CHECK(parallel_result.responses_size() > 0);
*result = parallel_result.responses(0);
VLOG(1) << "Defaulting to device 0 result";
return Status::OK();
}

StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -836,10 +804,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
return std::move(executable);
}

Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
VLOG(1) << "running execute-graph request";

Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
VLOG(1) << "running compile request";
if (!arg->has_computation()) {
return InvalidArgument("computations may not be empty");
}
@@ -847,22 +813,21 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
return InvalidArgument("program shape may not be empty");
}

// If we received multiple device handles, we must partition the module.
if (arg->execution_options().device_handles_size() > 1) {
return ExecuteOneToN(arg, result);
return InvalidArgument(
"The compile request does not support multiple device handles.");
}

TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
SingleComputationDeviceHandle()));
TF_ASSIGN_OR_RETURN(
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
ResolveAndValidateArguments(arg->arguments(), replicas));

std::vector<const Shape*> argument_shapes;
absl::c_transform(arg->input_shape_with_layout(),
std::back_inserter(argument_shapes),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(arg->computation().host_program_shape(),
replicated_arguments.front(),
arg->execution_options()));
argument_shapes, &arg->execution_options()));
VLOG(3) << "Compile created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();

TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
@@ -871,6 +836,48 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr));

*result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

VLOG(1) << "successfully completed 'compile' request";
return Status::OK();
}

Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
VLOG(1) << "running execute request";
if (!arg->has_handle()) {
return InvalidArgument("execution handle should not be empty");
}
TF_ASSIGN_OR_RETURN(auto executable,
compilation_cache_.LookUp(arg->handle()));

TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
SingleComputationDeviceHandle()));
TF_ASSIGN_OR_RETURN(
std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
ResolveAndValidateArguments(arg->arguments(), replicas));

// Check that the replicated_arguments have the same shape and layout as the
// module config used when creating the executable.
const int64 num_module_args =
executable->module_config().entry_computation_layout().parameter_count();
if (num_module_args != arg->arguments_size()) {
return InvalidArgument(
"The executable expects %lld arguments, but sees %lld.",
num_module_args, arg->arguments_size());
}
for (int64 i = 0; i < num_module_args; i++) {
const Shape& shape_module =
executable->module_config().entry_computation_layout().parameter_shape(
i);
const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape();
if (!ShapeUtil::Equal(shape_module, shape_arg)) {
return InvalidArgumentStrCat(
"The executable expects the ", i, "th argument in shape ",
ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
ShapeUtil::HumanStringWithLayout(shape_arg));
}
}

TF_ASSIGN_OR_RETURN(auto stream,
execute_backend_->BorrowStream(
execute_backend_->default_stream_executor()));
@@ -884,9 +891,10 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,

TF_ASSIGN_OR_RETURN(
*result->mutable_output(),
ExecuteAndRegisterResult(
executable.get(), replicated_arguments, execute_backend_.get(),
"result of " + arg->computation().name(), result->mutable_profile()));
ExecuteAndRegisterResult(executable.get(), replicated_arguments,
execute_backend_.get(),
"result of " + executable->module().name(),
result->mutable_profile()));

if (executable->dumping_snapshot()) {
TF_ASSIGN_OR_RETURN(
@@ -898,7 +906,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
TF_RETURN_IF_ERROR(executable->DumpHloSnapshot());
}

VLOG(1) << "successfully completed 'execute-graph' request";
VLOG(1) << "successfully completed 'execute' request";
return Status::OK();
}
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
@@ -90,11 +91,14 @@ class Service : public ServiceInterface {
Status DeconstructTuple(const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;

// Executes a computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
// Compiles a computation into an executable. The request contains the whole
// computation graph. Returns the handle to the executable.
Status Compile(const CompileRequest* arg, CompileResponse* result) override;

// Executes an executable with the provided global data passed as immutable
// arguments. The request contains the handle to the executable. Returns
// global data output and execution timing.
Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;

// Executes one or more computations in parallel with the provided global data
// passed as immutable arguments. Returns global data output for each
@@ -179,10 +183,6 @@ class Service : public ServiceInterface {
absl::Span<const ShapedBuffer* const> arguments,
const ExecutionOptions& execution_options);

// Picks a parallel response and fills the result.
Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
ExecuteResponse* result);

// Prepare the executors for executing parallel.
StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
const ExecutionOptions& execution_options, int64 requests_size,
@@ -254,11 +254,6 @@ class Service : public ServiceInterface {
Backend* backend, absl::Span<const DeviceHandle> device_handles,
absl::Span<const string> result_tags, ExecutionProfile* profile);

// Executes a single computation which has more than one target device.
// The N devices are expected to all return an empty tuple, but one, which
// will be the result of this computation.
Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result);

// Convenience function which checks whether the given client_shape
// (presumably passed by the client to set the result layout) is valid for the
// given computation result shape.
@@ -281,6 +276,9 @@ class Service : public ServiceInterface {

ServiceOptions options_;

// Cache containing previously built Executables.
CompilationCache compilation_cache_;

// Tracks channels created via the API.
ChannelTracker channel_tracker_;
@@ -47,8 +47,11 @@ class ServiceInterface {
virtual Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) = 0;

virtual Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) = 0;
virtual Status Compile(const CompileRequest* arg,
CompileResponse* result) = 0;

virtual Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) = 0;

virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
ExecuteParallelResponse* result) = 0;
@@ -322,6 +322,34 @@ message UnregisterRequest {
message UnregisterResponse {
}

message CompileRequest {
// The graph to be compiled.
HloModuleProto computation = 1;

// Options that affect how XLA compiles code to service this request.
ExecutionOptions execution_options = 2;

// The layouts of the input arguments. If not set, the default layout will be
// used. Although the real arguments are not needed in compilation, the
// layouts of the arguments can affect the compilation.
repeated Shape input_shape_with_layout = 3;
}

message CompileResponse {
// The handle to the executable.
ExecutionHandle handle = 1;
}

message ExecuteRequest {
ExecutionHandle handle = 1;

// The shape and layout of the arguments must be the same as those of the
// executable's parameters.
repeated GlobalDataHandle arguments = 2;
}

// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
HloModuleProto computation = 1;
repeated GlobalDataHandle arguments = 2;
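For a concrete picture of the wire protocol, here is a hedged sketch of one Compile/Execute round trip through a ServiceInterface stub, assembled only from the messages and stub methods shown in this diff (`stub`, `module_proto`, and `argument_handle` are assumed to already exist; execution options and argument shapes are omitted, and TF_RETURN_IF_ERROR assumes an enclosing Status-returning function):

// Compile: send the computation graph, get back a handle to the cached
// executable.
CompileRequest compile_request;
*compile_request.mutable_computation() = module_proto;
CompileResponse compile_response;
TF_RETURN_IF_ERROR(stub->Compile(&compile_request, &compile_response));

// Execute: reference the executable by handle and pass argument handles.
ExecuteRequest execute_request;
*execute_request.mutable_handle() = compile_response.handle();
*execute_request.add_arguments() = argument_handle;
ExecuteResponse execute_response;
TF_RETURN_IF_ERROR(stub->Execute(&execute_request, &execute_response));
// execute_response.output() holds the GlobalDataHandle of the result and
// execute_response.profile() carries execution timing.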