Implement Serialization/Deserialization for remote compilation cache.

PiperOrigin-RevId: 329561676
Change-Id: If9df54e30be753420ff70e88eb179751fc3de4ae
This commit is contained in:
Henry Tan 2020-09-01 12:45:00 -07:00 committed by TensorFlower Gardener
parent a1e501e957
commit b4ee2c4294
12 changed files with 316 additions and 54 deletions

View File

@ -317,6 +317,7 @@ cc_library(
":tpu_mesh_state_interface",
":tpu_program_c_api_hdrs",
":tpu_program_group_interface",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:compile_only_client",
@ -467,6 +468,7 @@ cc_library(
deps = [
":tpu_util_c_api_hdrs",
"//tensorflow/core/tpu:libtftpu_header",
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = True,
@ -515,8 +517,8 @@ cc_library(
DEFAULT: [],
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep
}) + [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_interface",
@ -536,8 +538,16 @@ cc_library(
DEFAULT: [],
}),
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_rpc_support_hdrs",
":tpu_program_group",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/stream_executor/tpu:proto_helper",
],
)
@ -613,27 +623,22 @@ cc_library(
}),
deps = select({
WITH_TPU_SUPPORT: [
":tpu_compilation_cache_rpc_support",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_rpc_support", # build_cleaner: keep
":tpu_compilation_cache_proto_cc", # build_cleaner: keep
],
DEFAULT: [
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc",
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep
],
}) + [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_grpc",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_rpc_support_hdrs",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/lib/core:threadpool",
"//tensorflow/core/platform:coding",
"//tensorflow/core:protos_all_cc",
tf_grpc_cc_dependency(),
],
)

View File

@ -497,7 +497,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
*uid = entry->uid;
// Let the caller know the keys for each of the cached protos.
*proto_key = entry->proto_key;
*may_modify_variables = entry->tpu_program_group->may_modify_variables();
*may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
*hlo_metadatas = entry->tpu_program_group->hlo_metadatas();
// If the caller didn't supply a per_step_ref_holder then the caller is going

View File

@ -160,7 +160,7 @@ Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
<< " in remote subgraph cache status " << s;
TF_RETURN_IF_ERROR(s);
TF_RETURN_IF_ERROR(FillCacheEntryFromGetTpuProgramResponse(
TF_RETURN_IF_ERROR(DeserializeRpcResponseToCacheEntry(
local_proto_key, &response, cache_entry));
cache_.emplace(local_proto_key, (*cache_entry));
cache_size_ += (*cache_entry)->size;

View File

@ -14,9 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/casts.h"
#if defined(LIBTFTPU)
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#endif // LIBTFTPU
#endif
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow {
namespace tpu {
@ -26,18 +32,127 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
#if defined(LIBTFTPU)
template <>
Status FillCacheEntryFromGetTpuProgramResponse<GetTpuProgramResponseExternal>(
Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
std::shared_ptr<CacheEntry>* cache_entry) {
// TODO(b/162904194): implement this method.
LOG(FATAL) << "Not implemented yet.";
CHECK_NE(response, nullptr);
CHECK_NE(cache_entry, nullptr);
*cache_entry = std::make_shared<CacheEntry>();
CacheEntry& entry = **cache_entry;
entry.key = local_proto_key;
if (response->is_empty()) {
entry.size = 0;
} else {
// When we lookup from remote cache, we fetch a TPU program for a specific
// core hence we allocate TPU program group for a single program.
entry.tpu_program_group = TpuProgramGroup::Create(/*count=*/1);
TpuSerializedProto serialized_response_proto =
stream_executor::tpu::SerializeProto(*response);
auto cleanup = xla::MakeCleanup([&serialized_response_proto]() {
stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
});
// TODO(b/166575150): can be optimized by sending the buffer over the gRPC
// without an extra deserializing.
TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<TpuProgramGroup*>(entry.tpu_program_group.get());
TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromProto(
/*index=*/0, serialized_response_proto));
entry.size = entry.tpu_program_group->program_size();
}
return Status::OK();
}
void SendGetTpuProgramResponseHelper(
const TpuCompilationCacheEntry& cache_entry,
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn) {
// TODO(b/162904194): implement this method.
LOG(FATAL) << "Not implemented yet.";
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
const TpuCompilationCacheEntry& cache_entry) {
if (cache_entry.tpu_program_group() == nullptr) {
// It's possible that the sharding/unsharding entry does not exist, but the
// main entry must exist.
GetTpuProgramResponseExternal header;
header.set_is_empty(true);
std::string encoded_header;
if (!header.AppendToString(&encoded_header)) {
return errors::Internal("Failed to serialize TPU program metadata.");
}
::grpc::Slice slice(encoded_header);
return std::vector<::grpc::Slice>{slice};
}
const TpuProgramGroup* tpu_program_group =
tensorflow::down_cast<const TpuProgramGroup*>(
cache_entry.tpu_program_group());
CHECK_NE(tpu_program_group, nullptr);
CHECK_GE(tpu_program_group->program_count(), 0);
CHECK_GE(cache_entry.core_index(), 0);
CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
const int64 program_size = tpu_program_group->program_size();
if (program_size > INT_MAX) {
return errors::Internal("TPU program exceeded 2 GiB.");
}
TpuExecutableSerializedProto executable;
auto cleanup_executable = xla::MakeCleanup([&executable]() {
if (executable.size > 0) {
stream_executor::tpu::SerializedProto_Free(executable);
}
});
auto get_executable_status = tpu_program_group->SerializeExecutable(
cache_entry.core_index(), &executable);
if (!get_executable_status.ok()) {
return errors::Internal("Failed to serialize TPU program.");
}
// Encode and serialize header fields.
GetTpuProgramResponseExternal header;
header.mutable_proto()->ParseFromArray(executable.bytes, executable.size);
header.set_is_empty(false);
HostComputeMetadataSerializedProto host_compute_metadata;
auto cleanup_host_compute_metadata =
xla::MakeCleanup([&host_compute_metadata]() {
if (host_compute_metadata.size > 0) {
stream_executor::tpu::SerializedProto_Free(host_compute_metadata);
}
});
Status get_host_compute_metadata_status =
tpu_program_group->SerializeHostComputeMetadata(cache_entry.core_index(),
&host_compute_metadata);
if (!get_host_compute_metadata_status.ok()) {
return errors::Internal("Failed to serialize host compute metadata.");
}
tf2xla::HostComputeMetadata host_compute_metadata_proto =
stream_executor::tpu::DeserializeProto<tf2xla::HostComputeMetadata>(
host_compute_metadata);
*header.mutable_host_compute_metadata() =
std::move(host_compute_metadata_proto);
bool may_modify_variables =
tpu_program_group->may_modify_variables(cache_entry.core_index());
header.set_may_modify_variables(may_modify_variables);
CompilerMetadataSerializedProto compiler_metadata;
auto cleanup_compiler_metadata = xla::MakeCleanup([&compiler_metadata]() {
if (compiler_metadata.size > 0) {
stream_executor::tpu::SerializedProto_Free(compiler_metadata);
}
});
Status get_compiler_metadata_status =
tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
&compiler_metadata);
if (!get_compiler_metadata_status.ok()) {
return errors::Internal("Failed to serialize compiler metadata.");
}
header.mutable_compiler_metadata()->mutable_data()->assign(
compiler_metadata.bytes, compiler_metadata.size);
std::string encoded_header;
if (!header.AppendToString(&encoded_header)) {
return errors::Internal("Failed to serialize TPU program metadata.");
}
return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
}
#endif // LIBTFTPU
} // namespace tpu

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "grpcpp/support/slice.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
@ -82,14 +84,13 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials();
// Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The
// `cache_entry` will be instantiated by the function.
template <typename ResponseType>
Status FillCacheEntryFromGetTpuProgramResponse(
Status DeserializeRpcResponseToCacheEntry(
const absl::string_view local_proto_key, ResponseType* response,
std::shared_ptr<CacheEntry>* cache_entry);
// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel.
void SendGetTpuProgramResponseHelper(
const TpuCompilationCacheEntry& cache_entry,
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn);
// Serializes `TpuCompilationCacheEntry` to gRPC bufer slices.
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
const TpuCompilationCacheEntry& cache_entry);
} // namespace tpu
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <chrono> // NOLINT
#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
@ -125,15 +126,17 @@ void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) {
CHECK_NE(call->request.fetch_target(),
tpu::CompilationCacheFetchTarget::MAIN);
}
return SendGetTpuProgramResponseHelper(
cache_entry,
[&call](::grpc::ByteBuffer* buffer, ::grpc::Status error_status) {
if (buffer == nullptr) {
return call->SendResponse(error_status);
}
call->response = *buffer;
return call->SendResponse(::grpc::Status());
});
xla::StatusOr<std::vector<::grpc::Slice>> buffer_slices =
tpu::SerializeCacheEntryToBufferSlices(cache_entry);
if (!buffer_slices.ok()) {
return call->SendResponse(ToGrpcStatus(buffer_slices.status()));
}
call->response =
::grpc::ByteBuffer{&buffer_slices.ValueOrDie()[0], buffer_slices->size()};
return call->SendResponse(::grpc::Status());
}
void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
typedef struct XLA_TpuProgram XLA_TpuProgram;
@ -24,6 +25,21 @@ typedef struct XLA_TpuProgram XLA_TpuProgram;
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
struct TpuExecutableSerializedProto {
const char* bytes;
size_t size;
};
struct CompilerMetadataSerializedProto {
const char* bytes;
size_t size;
};
struct HostComputeMetadataSerializedProto {
const char* bytes;
size_t size;
};
extern "C" {
// Creates a new TPU program.
@ -67,7 +83,7 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
// Check if TPU program has sharding.
// Checks if TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
const XLA_TpuProgram* tpu_program);
@ -76,6 +92,27 @@ TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
// Gets TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
SE_Status* status);
// Gets compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
const XLA_TpuProgram* tpu_program,
CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);
// Gets host transfer metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeHostComputeMetadata(
const XLA_TpuProgram* tpu_program,
HostComputeMetadataSerializedProto* host_compute_metadata,
SE_Status* status);
// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
SE_Status* status);
struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
@ -90,6 +127,10 @@ struct TfTpu_TpuProgramApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeHostComputeMetadata);
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
};
} // extern "C"

View File

@ -29,11 +29,8 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
namespace {
namespace se_tpu = ::stream_executor::tpu;
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
using xla::Shape;
@ -195,7 +192,20 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
tpu_programs_.clear();
}
/*static*/ Status TpuProgramGroup::Build(
/*static*/
std::unique_ptr<TpuProgramGroup> TpuProgramGroup::Create(int count) {
auto tpu_program_group = std::make_unique<TpuProgramGroup>();
std::vector<XLA_TpuProgram*> tpu_programs;
tpu_programs.resize(count);
for (int i = 0; i < count; ++i) {
tpu_programs[i] = TpuProgramApiFn()->TpuProgram_NewFn();
}
tpu_program_group->set_tpu_programs(tpu_programs);
return tpu_program_group;
}
/*static*/
Status TpuProgramGroup::Build(
const TPUCompileMetadataProto& metadata,
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
const std::vector<ShardingAndIndex>& arg_core_mapping,
@ -283,7 +293,7 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
return Status::OK();
}
const std::vector<bool>& TpuProgramGroup::may_modify_variables() const {
const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
return may_modify_variables_;
}
@ -292,6 +302,15 @@ void TpuProgramGroup::set_may_modify_variables(
may_modify_variables_ = may_modify_variables;
}
bool TpuProgramGroup::may_modify_variables(int index) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
bool may_modify_variables;
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
&may_modify_variables);
return may_modify_variables;
}
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
return tpu_programs_;
}
@ -371,5 +390,47 @@ std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
}
return tpu_programs;
}
Status TpuProgramGroup::DeserializeFromProto(int index,
TpuSerializedProto proto) {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
CHECK_NE(tpu_programs_[index], nullptr);
TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
proto, tpu_programs_[index], status.c_status);
return status.status();
}
Status TpuProgramGroup::SerializeExecutable(
int index, TpuExecutableSerializedProto* executable) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn(
tpu_programs_[index], executable, status.c_status);
return status.status();
}
Status TpuProgramGroup::SerializeCompilerMetadata(
int index, CompilerMetadataSerializedProto* compiler_metadata) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn(
tpu_programs_[index], compiler_metadata, status.c_status);
return status.status();
}
Status TpuProgramGroup::SerializeHostComputeMetadata(
int index,
HostComputeMetadataSerializedProto* host_compute_metadata) const {
CHECK_GE(index, 0);
CHECK_LT(index, tpu_programs_.size());
StatusHelper status;
TpuProgramApiFn()->TpuProgram_SerializeHostComputeMetadataFn(
tpu_programs_[index], host_compute_metadata, status.c_status);
return status.status();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -15,9 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
@ -102,6 +104,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
TpuProgramGroupInterface* tpu_program_group_interface);
// Creates the `count` instances of uninitialized `XLA_TpuPrograms`.
static std::unique_ptr<TpuProgramGroup> Create(int count);
// Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
@ -122,8 +127,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
Status LogCompilationStats(const TpuCompilationCacheKey& key,
absl::Duration duration) override;
const std::vector<bool>& may_modify_variables() const override;
const std::vector<bool>& may_modify_variables_list() const override;
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
bool may_modify_variables(int index) const override;
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
@ -137,6 +143,25 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
const xla::HloProto* hlo_metadata(int index) const;
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
// Deserializes `GetTpuProgramResponse` proto into an `XLA_TpuProgram` for
// the given core `index`.
Status DeserializeFromProto(int index, TpuSerializedProto proto);
// Serializes executable proto from the TPU program for the given core
// `index`.
Status SerializeExecutable(int index,
TpuExecutableSerializedProto* executable) const;
// Serializes compiler metadata of the TPU program for the given core `index`.
Status SerializeCompilerMetadata(
int index, CompilerMetadataSerializedProto* compiler_metadata) const;
// Serializes host compute metadata of the TPU program for the given core
// `index`.
Status SerializeHostComputeMetadata(
int index,
HostComputeMetadataSerializedProto* host_compute_metadata) const;
private:
void RefreshHloMetadatasPtrs();

View File

@ -61,7 +61,11 @@ class TpuProgramGroupInterface {
// Boolean array to indicate if the modification of variables are
// allowed.
virtual const std::vector<bool>& may_modify_variables() const = 0;
virtual const std::vector<bool>& may_modify_variables_list() const = 0;
// Gets may modify variables value of the TPU program for the given core
// `index`.
virtual bool may_modify_variables(int index) const = 0;
};
} // namespace tpu

View File

@ -72,6 +72,11 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata);
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeHostComputeMetadata);
TFTPU_SET_FN(tpu_program_fn,
TpuProgram_DeserializeFromGetTpuProgramResponseProto);
return tensorflow::Status::OK();
}

View File

@ -32,12 +32,13 @@ namespace tpu {
using SerializedProto = TpuSerializedProto;
// Serializes a proto and put the result in the given SerializedProto* argument.
// Serializes a `proto` and put the result in the given `SerializedProtoType*`
// argument.
//
// Users should call SerializedProto_Free on `serialized_proto` afterwards.
template <class Proto>
inline void SerializeProto(const Proto& proto,
SerializedProto* serialized_proto) {
template <class ProtoType, class SerializedProtoType>
inline void SerializeProto(const ProtoType& proto,
SerializedProtoType* serialized_proto) {
auto size = proto.ByteSizeLong();
auto bytes = new char[size];
CHECK(proto.SerializeToArray(bytes, size));
@ -48,8 +49,8 @@ inline void SerializeProto(const Proto& proto,
// Serializes a proto and return the result as a SerializedProto value.
//
// Users should call SerializedProto_Free on the return value afterwards.
template <class Proto>
inline SerializedProto SerializeProto(const Proto& proto) {
template <class ProtoType>
inline SerializedProto SerializeProto(const ProtoType& proto) {
SerializedProto serialized_proto;
SerializeProto(proto, &serialized_proto);
return serialized_proto;
@ -57,9 +58,9 @@ inline SerializedProto SerializeProto(const Proto& proto) {
// Deserializes a buffer and return the corresponding proto. If the buffer is
// empty, return an empty proto.
template <class Proto>
inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
Proto proto;
template <class ProtoType, class SerializedProtoType>
inline ProtoType DeserializeProto(const SerializedProtoType& serialized_proto) {
ProtoType proto;
if (serialized_proto.bytes != nullptr) {
CHECK_GT(serialized_proto.size, 0);
CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
@ -69,7 +70,8 @@ inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
}
// Releases the memory allocated for serialized protos.
inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
template <class SerializedProtoType>
inline void SerializedProto_Free(const SerializedProtoType& serialized_proto) {
CHECK_NE(serialized_proto.bytes, nullptr);
CHECK_GT(serialized_proto.size, 0);
delete[] serialized_proto.bytes;