Implement Serialization/Deserialization for remote compilation cache.
PiperOrigin-RevId: 329561676
Change-Id: If9df54e30be753420ff70e88eb179751fc3de4ae
parent a1e501e957
commit b4ee2c4294
@@ -317,6 +317,7 @@ cc_library(
":tpu_mesh_state_interface",
":tpu_program_c_api_hdrs",
":tpu_program_group_interface",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:compile_only_client",
@@ -467,6 +468,7 @@ cc_library(
deps = [
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = True,
@@ -515,8 +517,8 @@ cc_library(
DEFAULT: [],
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep
}) + [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_interface",
@@ -536,8 +538,16 @@ cc_library(
DEFAULT: [],
}),
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_rpc_support_hdrs",
":tpu_program_group",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/stream_executor/tpu:proto_helper",
],
)
@@ -613,27 +623,22 @@ cc_library(
}),
deps = select({
WITH_TPU_SUPPORT: [
":tpu_compilation_cache_rpc_support",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_rpc_support", # build_cleaner: keep
":tpu_compilation_cache_proto_cc", # build_cleaner: keep
],
DEFAULT: [
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc",
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep
],
}) + [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_grpc",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_rpc_support_hdrs",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/lib/core:threadpool",
"//tensorflow/core/platform:coding",
"//tensorflow/core:protos_all_cc",
tf_grpc_cc_dependency(),
],
)
@@ -497,7 +497,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
*uid = entry->uid;
// Let the caller know the keys for each of the cached protos.
*proto_key = entry->proto_key;
*may_modify_variables = entry->tpu_program_group->may_modify_variables();
*may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
*hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

// If the caller didn't supply a per_step_ref_holder then the caller is going
@@ -160,7 +160,7 @@ Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
<< " in remote subgraph cache status " << s;
TF_RETURN_IF_ERROR(s);

TF_RETURN_IF_ERROR(FillCacheEntryFromGetTpuProgramResponse(
TF_RETURN_IF_ERROR(DeserializeRpcResponseToCacheEntry(
local_proto_key, &response, cache_entry));
cache_.emplace(local_proto_key, (*cache_entry));
cache_size_ += (*cache_entry)->size;
@@ -14,9 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"

#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/casts.h"
#if defined(LIBTFTPU)
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#endif // LIBTFTPU
#endif
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace tpu {
@@ -26,18 +32,127 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {

#if defined(LIBTFTPU)
template <>
Status FillCacheEntryFromGetTpuProgramResponse<GetTpuProgramResponseExternal>(
Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
std::shared_ptr<CacheEntry>* cache_entry) {
// TODO(b/162904194): implement this method.
LOG(FATAL) << "Not implemented yet.";
CHECK_NE(response, nullptr);
CHECK_NE(cache_entry, nullptr);
*cache_entry = std::make_shared<CacheEntry>();
CacheEntry& entry = **cache_entry;
entry.key = local_proto_key;

if (response->is_empty()) {
entry.size = 0;
} else {
// When we lookup from remote cache, we fetch a TPU program for a specific
// core hence we allocate TPU program group for a single program.
entry.tpu_program_group = TpuProgramGroup::Create(/*count=*/1);

TpuSerializedProto serialized_response_proto =
stream_executor::tpu::SerializeProto(*response);
auto cleanup = xla::MakeCleanup([&serialized_response_proto]() {
stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
});
// TODO(b/166575150): can be optimized by sending the buffer over the gRPC
// without an extra deserializing.
TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<TpuProgramGroup*>(entry.tpu_program_group.get());
TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromProto(
/*index=*/0, serialized_response_proto));
entry.size = entry.tpu_program_group->program_size();
}

return Status::OK();
}

void SendGetTpuProgramResponseHelper(
const TpuCompilationCacheEntry& cache_entry,
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn) {
// TODO(b/162904194): implement this method.
LOG(FATAL) << "Not implemented yet.";
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
const TpuCompilationCacheEntry& cache_entry) {
if (cache_entry.tpu_program_group() == nullptr) {
// It's possible that the sharding/unsharding entry does not exist, but the
// main entry must exist.
GetTpuProgramResponseExternal header;
header.set_is_empty(true);
std::string encoded_header;
if (!header.AppendToString(&encoded_header)) {
return errors::Internal("Failed to serialize TPU program metadata.");
}
::grpc::Slice slice(encoded_header);
return std::vector<::grpc::Slice>{slice};
}

const TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<const TpuProgramGroup*>(
cache_entry.tpu_program_group());
CHECK_NE(tpu_program_group, nullptr);
CHECK_GE(tpu_program_group->program_count(), 0);
CHECK_GE(cache_entry.core_index(), 0);
CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
const int64 program_size = tpu_program_group->program_size();
if (program_size > INT_MAX) {
return errors::Internal("TPU program exceeded 2 GiB.");
}

TpuExecutableSerializedProto executable;
auto cleanup_executable = xla::MakeCleanup([&executable]() {
if (executable.size > 0) {
stream_executor::tpu::SerializedProto_Free(executable);
}
});
auto get_executable_status = tpu_program_group->SerializeExecutable(
cache_entry.core_index(), &executable);
if (!get_executable_status.ok()) {
return errors::Internal("Failed to serialize TPU program.");
}

// Encode and serialize header fields.
GetTpuProgramResponseExternal header;
header.mutable_proto()->ParseFromArray(executable.bytes, executable.size);
header.set_is_empty(false);

HostComputeMetadataSerializedProto host_compute_metadata;
auto cleanup_host_compute_metadata =
xla::MakeCleanup([&host_compute_metadata]() {
if (host_compute_metadata.size > 0) {
stream_executor::tpu::SerializedProto_Free(host_compute_metadata);
}
});
Status get_host_compute_metadata_status =
tpu_program_group->SerializeHostComputeMetadata(cache_entry.core_index(),
&host_compute_metadata);
if (!get_host_compute_metadata_status.ok()) {
return errors::Internal("Failed to serialize host compute metadata.");
}
tf2xla::HostComputeMetadata host_compute_metadata_proto =
stream_executor::tpu::DeserializeProto<tf2xla::HostComputeMetadata>(
host_compute_metadata);
*header.mutable_host_compute_metadata() =
std::move(host_compute_metadata_proto);

bool may_modify_variables =
tpu_program_group->may_modify_variables(cache_entry.core_index());
header.set_may_modify_variables(may_modify_variables);

CompilerMetadataSerializedProto compiler_metadata;
auto cleanup_compiler_metadata = xla::MakeCleanup([&compiler_metadata]() {
if (compiler_metadata.size > 0) {
stream_executor::tpu::SerializedProto_Free(compiler_metadata);
}
});
Status get_compiler_metadata_status =
tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
&compiler_metadata);
if (!get_compiler_metadata_status.ok()) {
return errors::Internal("Failed to serialize compiler metadata.");
}
header.mutable_compiler_metadata()->mutable_data()->assign(
compiler_metadata.bytes, compiler_metadata.size);

std::string encoded_header;
if (!header.AppendToString(&encoded_header)) {
return errors::Internal("Failed to serialize TPU program metadata.");
}

return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
}
#endif // LIBTFTPU
} // namespace tpu
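For orientation, the two helpers above are intended to pair up across the cache RPC: the serving side turns a cache entry into gRPC slices that become the response ByteBuffer, and the lookup client turns the fetched response back into a local CacheEntry. A minimal sketch follows; the names `fetched_response` and `entry_to_send` are hypothetical stand-ins, and a LIBTFTPU build is assumed.

#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"  // GetTpuProgramResponseExternal
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"

namespace tensorflow {
namespace tpu {

// Client side: materialize a local CacheEntry from a fetched response.
Status ClientSketch(GetTpuProgramResponseExternal* fetched_response) {
  std::shared_ptr<CacheEntry> cache_entry;
  Status s = DeserializeRpcResponseToCacheEntry(
      /*local_proto_key=*/"example_key", fetched_response, &cache_entry);
  if (!s.ok()) return s;
  VLOG(1) << "Fetched entry of size " << cache_entry->size;
  return Status::OK();
}

// Server side: serialize an entry and wrap the slices into the response
// ByteBuffer, mirroring what TpuCompilationCacheService::GetTpuProgram does
// in the service change further down in this commit.
Status ServerSketch(const TpuCompilationCacheEntry& entry_to_send,
                    ::grpc::ByteBuffer* response) {
  xla::StatusOr<std::vector<::grpc::Slice>> slices =
      SerializeCacheEntryToBufferSlices(entry_to_send);
  if (!slices.ok()) return slices.status();
  *response = ::grpc::ByteBuffer{&slices.ValueOrDie()[0], slices->size()};
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow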
@@ -20,7 +20,9 @@ limitations under the License.
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "grpcpp/support/slice.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"

@@ -82,14 +84,13 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials();
// Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The
// `cache_entry` will be instantiated by the function.
template <typename ResponseType>
Status FillCacheEntryFromGetTpuProgramResponse(
Status DeserializeRpcResponseToCacheEntry(
const absl::string_view local_proto_key, ResponseType* response,
std::shared_ptr<CacheEntry>* cache_entry);

// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel.
void SendGetTpuProgramResponseHelper(
const TpuCompilationCacheEntry& cache_entry,
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn);
// Serializes `TpuCompilationCacheEntry` to gRPC bufer slices.
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
const TpuCompilationCacheEntry& cache_entry);
} // namespace tpu
} // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.

#include <chrono> // NOLINT

#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"

@@ -125,15 +126,17 @@ void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) {
CHECK_NE(call->request.fetch_target(),
tpu::CompilationCacheFetchTarget::MAIN);
}
return SendGetTpuProgramResponseHelper(
cache_entry,
[&call](::grpc::ByteBuffer* buffer, ::grpc::Status error_status) {
if (buffer == nullptr) {
return call->SendResponse(error_status);
}
call->response = *buffer;
return call->SendResponse(::grpc::Status());
});

xla::StatusOr<std::vector<::grpc::Slice>> buffer_slices =
tpu::SerializeCacheEntryToBufferSlices(cache_entry);

if (!buffer_slices.ok()) {
return call->SendResponse(ToGrpcStatus(buffer_slices.status()));
}

call->response =
::grpc::ByteBuffer{&buffer_slices.ValueOrDie()[0], buffer_slices->size()};
return call->SendResponse(::grpc::Status());
}

void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {
@@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

typedef struct XLA_TpuProgram XLA_TpuProgram;

@@ -24,6 +25,21 @@ typedef struct XLA_TpuProgram XLA_TpuProgram;
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };

struct TpuExecutableSerializedProto {
const char* bytes;
size_t size;
};

struct CompilerMetadataSerializedProto {
const char* bytes;
size_t size;
};

struct HostComputeMetadataSerializedProto {
const char* bytes;
size_t size;
};

extern "C" {

// Creates a new TPU program.

@@ -67,7 +83,7 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);

// Check if TPU program has sharding.
// Checks if TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
const XLA_TpuProgram* tpu_program);
@@ -76,6 +92,27 @@ TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);

// Gets TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
SE_Status* status);

// Gets compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
const XLA_TpuProgram* tpu_program,
CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);

// Gets host transfer metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeHostComputeMetadata(
const XLA_TpuProgram* tpu_program,
HostComputeMetadataSerializedProto* host_compute_metadata,
SE_Status* status);

// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
SE_Status* status);

struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);

@@ -90,6 +127,10 @@ struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeHostComputeMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
};

} // extern "C"
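A rough illustration of the new serialization entry points, written as a sketch only: it assumes the TPU library has already been loaded so the TpuProgram_* function pointers are populated, that `tpu_program` is a valid XLA_TpuProgram*, and the include paths for the dispatch-table and StatusHelper headers are assumptions rather than something this diff shows.

#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"                    // TpuProgramApiFn() (path assumed)
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"   // StatusHelper (path assumed)

namespace tensorflow {
namespace tpu {

// Sketch: pull the serialized executable out of one XLA_TpuProgram and
// release the buffer afterwards, following the call pattern used by
// TpuProgramGroup in this commit.
Status SerializeExecutableSketch(const XLA_TpuProgram* tpu_program) {
  TpuExecutableSerializedProto executable = {};
  StatusHelper status;
  TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn(
      tpu_program, &executable, status.c_status);
  if (status.status().ok() && executable.size > 0) {
    // executable.bytes / executable.size now hold the serialized executable.
    stream_executor::tpu::SerializedProto_Free(executable);
  }
  return status.status();
}

}  // namespace tpu
}  // namespace tensorflow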
@@ -29,11 +29,8 @@ limitations under the License.

namespace tensorflow {
namespace tpu {

namespace {

namespace se_tpu = ::stream_executor::tpu;

using stream_executor::port::Status;
using stream_executor::port::StatusOr;
using xla::Shape;

@@ -195,7 +192,20 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
tpu_programs_.clear();
}

/*static*/ Status TpuProgramGroup::Build(
/*static*/
std::unique_ptr<TpuProgramGroup> TpuProgramGroup::Create(int count) {
auto tpu_program_group = std::make_unique<TpuProgramGroup>();
std::vector<XLA_TpuProgram*> tpu_programs;
tpu_programs.resize(count);
for (int i = 0; i < count; ++i) {
tpu_programs[i] = TpuProgramApiFn()->TpuProgram_NewFn();
}
tpu_program_group->set_tpu_programs(tpu_programs);
return tpu_program_group;
}

/*static*/
Status TpuProgramGroup::Build(
const TPUCompileMetadataProto& metadata,
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
const std::vector<ShardingAndIndex>& arg_core_mapping,

@@ -283,7 +293,7 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
return Status::OK();
}

const std::vector<bool>& TpuProgramGroup::may_modify_variables() const {
const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
return may_modify_variables_;
}

@@ -292,6 +302,15 @@ void TpuProgramGroup::set_may_modify_variables(
may_modify_variables_ = may_modify_variables;
}

bool TpuProgramGroup::may_modify_variables(int index) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
bool may_modify_variables;
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
&may_modify_variables);
return may_modify_variables;
}

const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
return tpu_programs_;
}
@@ -371,5 +390,47 @@ std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
}
return tpu_programs;
}

Status TpuProgramGroup::DeserializeFromProto(int index,
TpuSerializedProto proto) {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
CHECK_NE(tpu_programs_[index], nullptr);
TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
proto, tpu_programs_[index], status.c_status);
return status.status();
}

Status TpuProgramGroup::SerializeExecutable(
int index, TpuExecutableSerializedProto* executable) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn(
tpu_programs_[index], executable, status.c_status);
return status.status();
}

Status TpuProgramGroup::SerializeCompilerMetadata(
int index, CompilerMetadataSerializedProto* compiler_metadata) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn(
tpu_programs_[index], compiler_metadata, status.c_status);
return status.status();
}

Status TpuProgramGroup::SerializeHostComputeMetadata(
int index,
HostComputeMetadataSerializedProto* host_compute_metadata) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeHostComputeMetadataFn(
tpu_programs_[index], host_compute_metadata, status.c_status);
return status.status();
}
} // namespace tpu
} // namespace tensorflow
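A condensed sketch of the receiving path these new members enable, following the same sequence DeserializeRpcResponseToCacheEntry uses above; `serialized_response` is a hypothetical TpuSerializedProto wrapping a GetTpuProgramResponse received from a peer cache, and the function is assumed to live inside namespace tensorflow::tpu with tpu_program_group.h included.

// Sketch: rebuild a single-program group from a serialized RPC response.
Status RebuildSingleProgramSketch(TpuSerializedProto serialized_response,
                                  std::unique_ptr<TpuProgramGroup>* out) {
  std::unique_ptr<TpuProgramGroup> group = TpuProgramGroup::Create(/*count=*/1);
  TF_RETURN_IF_ERROR(
      group->DeserializeFromProto(/*index=*/0, serialized_response));
  VLOG(1) << "Deserialized TPU program, size = " << group->program_size();
  *out = std::move(group);
  return Status::OK();
}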
@@ -15,9 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_

#include <memory>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"

@@ -102,6 +104,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
TpuProgramGroupInterface* tpu_program_group_interface);

// Creates the `count` instances of uninitialized `XLA_TpuPrograms`.
static std::unique_ptr<TpuProgramGroup> Create(int count);

// Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);

@@ -122,8 +127,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
Status LogCompilationStats(const TpuCompilationCacheKey& key,
absl::Duration duration) override;

const std::vector<bool>& may_modify_variables() const override;
const std::vector<bool>& may_modify_variables_list() const override;
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
bool may_modify_variables(int index) const override;

const std::vector<XLA_TpuProgram*>& tpu_programs() const;
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;

@@ -137,6 +143,25 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
const xla::HloProto* hlo_metadata(int index) const;
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;

// Deserializes `GetTpuProgramResponse` proto into an `XLA_TpuProgram` for
// the given core `index`.
Status DeserializeFromProto(int index, TpuSerializedProto proto);

// Serializes executable proto from the TPU program for the given core
// `index`.
Status SerializeExecutable(int index,
TpuExecutableSerializedProto* executable) const;

// Serializes compiler metadata of the TPU program for the given core `index`.
Status SerializeCompilerMetadata(
int index, CompilerMetadataSerializedProto* compiler_metadata) const;

// Serializes host compute metadata of the TPU program for the given core
// `index`.
Status SerializeHostComputeMetadata(
int index,
HostComputeMetadataSerializedProto* host_compute_metadata) const;

private:
void RefreshHloMetadatasPtrs();
@@ -61,7 +61,11 @@ class TpuProgramGroupInterface {

// Boolean array to indicate if the modification of variables are
// allowed.
virtual const std::vector<bool>& may_modify_variables() const = 0;
virtual const std::vector<bool>& may_modify_variables_list() const = 0;

// Gets may modify variables value of the TPU program for the given core
// `index`.
virtual bool may_modify_variables(int index) const = 0;
};

} // namespace tpu
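The rename makes room for a per-core accessor next to the whole-compilation list; a two-line sketch against a hypothetical `group` implementing TpuProgramGroupInterface:

// One flag per core, recorded by the group itself at compilation time.
const std::vector<bool>& flags = group.may_modify_variables_list();
// Per-core flag, resolved by the implementation (TpuProgramGroup queries the
// underlying XLA_TpuProgram through TpuProgram_GetMayModifyVariables).
bool core0_may_modify = group.may_modify_variables(/*index=*/0);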
@@ -72,6 +72,11 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeHostComputeMetadata);
TFTPU_SET_FN(tpu_program_fn,
TpuProgram_DeserializeFromGetTpuProgramResponseProto);

return tensorflow::Status::OK();
}
@@ -32,12 +32,13 @@ namespace tpu {

using SerializedProto = TpuSerializedProto;

// Serializes a proto and put the result in the given SerializedProto* argument.
// Serializes a `proto` and put the result in the given `SerializedProtoType*`
// argument.
//
// Users should call SerializedProto_Free on `serialized_proto` afterwards.
template <class Proto>
inline void SerializeProto(const Proto& proto,
SerializedProto* serialized_proto) {
template <class ProtoType, class SerializedProtoType>
inline void SerializeProto(const ProtoType& proto,
SerializedProtoType* serialized_proto) {
auto size = proto.ByteSizeLong();
auto bytes = new char[size];
CHECK(proto.SerializeToArray(bytes, size));

@@ -48,8 +49,8 @@ inline void SerializeProto(const Proto& proto,
// Serializes a proto and return the result as a SerializedProto value.
//
// Users should call SerializedProto_Free on the return value afterwards.
template <class Proto>
inline SerializedProto SerializeProto(const Proto& proto) {
template <class ProtoType>
inline SerializedProto SerializeProto(const ProtoType& proto) {
SerializedProto serialized_proto;
SerializeProto(proto, &serialized_proto);
return serialized_proto;

@@ -57,9 +58,9 @@ inline SerializedProto SerializeProto(const Proto& proto) {

// Deserializes a buffer and return the corresponding proto. If the buffer is
// empty, return an empty proto.
template <class Proto>
inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
Proto proto;
template <class ProtoType, class SerializedProtoType>
inline ProtoType DeserializeProto(const SerializedProtoType& serialized_proto) {
ProtoType proto;
if (serialized_proto.bytes != nullptr) {
CHECK_GT(serialized_proto.size, 0);
CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))

@@ -69,7 +70,8 @@ inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
}

// Releases the memory allocated for serialized protos.
inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
template <class SerializedProtoType>
inline void SerializedProto_Free(const SerializedProtoType& serialized_proto) {
CHECK_NE(serialized_proto.bytes, nullptr);
CHECK_GT(serialized_proto.size, 0);
delete[] serialized_proto.bytes;
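With the helpers templated on the serialized-struct type, any struct exposing `bytes`/`size` members (such as the three new *SerializedProto structs in tpu_program_c_api.h) can be used at the same call sites. A small round-trip sketch, assuming the metadata message has been populated elsewhere:

#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"   // HostComputeMetadataSerializedProto
#include "tensorflow/stream_executor/tpu/proto_helper.h"

// Sketch: serialize into a typed struct, deserialize back, then free the buffer.
tensorflow::tf2xla::HostComputeMetadata RoundTripSketch(
    const tensorflow::tf2xla::HostComputeMetadata& metadata) {
  HostComputeMetadataSerializedProto serialized = {};
  stream_executor::tpu::SerializeProto(metadata, &serialized);
  auto round_tripped = stream_executor::tpu::DeserializeProto<
      tensorflow::tf2xla::HostComputeMetadata>(serialized);
  stream_executor::tpu::SerializedProto_Free(serialized);
  return round_tripped;
}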