Implement Serialization/Deserialization for remote compilation cache.
PiperOrigin-RevId: 329561676 Change-Id: If9df54e30be753420ff70e88eb179751fc3de4ae
This commit is contained in:
parent
a1e501e957
commit
b4ee2c4294
@ -317,6 +317,7 @@ cc_library(
|
|||||||
":tpu_mesh_state_interface",
|
":tpu_mesh_state_interface",
|
||||||
":tpu_program_c_api_hdrs",
|
":tpu_program_c_api_hdrs",
|
||||||
":tpu_program_group_interface",
|
":tpu_program_group_interface",
|
||||||
|
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
"//tensorflow/compiler/xla:xla_proto_cc",
|
"//tensorflow/compiler/xla:xla_proto_cc",
|
||||||
"//tensorflow/compiler/xla/client:compile_only_client",
|
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||||
@ -467,6 +468,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":tpu_util_c_api_hdrs",
|
":tpu_util_c_api_hdrs",
|
||||||
"//tensorflow/core/tpu:libtftpu_header",
|
"//tensorflow/core/tpu:libtftpu_header",
|
||||||
|
"//tensorflow/stream_executor/tpu:c_api_decl",
|
||||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
],
|
],
|
||||||
alwayslink = True,
|
alwayslink = True,
|
||||||
@ -515,8 +517,8 @@ cc_library(
|
|||||||
DEFAULT: [],
|
DEFAULT: [],
|
||||||
}),
|
}),
|
||||||
deps = select({
|
deps = select({
|
||||||
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
|
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep
|
||||||
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
|
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep
|
||||||
}) + [
|
}) + [
|
||||||
":tpu_compilation_cache_entry",
|
":tpu_compilation_cache_entry",
|
||||||
":tpu_compilation_cache_interface",
|
":tpu_compilation_cache_interface",
|
||||||
@ -536,8 +538,16 @@ cc_library(
|
|||||||
DEFAULT: [],
|
DEFAULT: [],
|
||||||
}),
|
}),
|
||||||
deps = [
|
deps = [
|
||||||
|
":tpu_compilation_cache_common_proto_cc",
|
||||||
":tpu_compilation_cache_proto_cc",
|
":tpu_compilation_cache_proto_cc",
|
||||||
":tpu_compilation_cache_rpc_support_hdrs",
|
":tpu_compilation_cache_rpc_support_hdrs",
|
||||||
|
":tpu_program_group",
|
||||||
|
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
|
"//tensorflow/core/tpu:tpu_config_c_api",
|
||||||
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -613,27 +623,22 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
deps = select({
|
deps = select({
|
||||||
WITH_TPU_SUPPORT: [
|
WITH_TPU_SUPPORT: [
|
||||||
":tpu_compilation_cache_rpc_support",
|
":tpu_compilation_cache_rpc_support", # build_cleaner: keep
|
||||||
":tpu_compilation_cache_proto_cc",
|
":tpu_compilation_cache_proto_cc", # build_cleaner: keep
|
||||||
],
|
],
|
||||||
DEFAULT: [
|
DEFAULT: [
|
||||||
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support",
|
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep
|
||||||
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc",
|
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep
|
||||||
],
|
],
|
||||||
}) + [
|
}) + [
|
||||||
":tpu_compilation_cache_common_proto_cc",
|
":tpu_compilation_cache_common_proto_cc",
|
||||||
":tpu_compilation_cache_entry",
|
|
||||||
":tpu_compilation_cache_grpc",
|
":tpu_compilation_cache_grpc",
|
||||||
":tpu_compilation_cache_interface",
|
":tpu_compilation_cache_interface",
|
||||||
":tpu_compilation_cache_rpc_support_hdrs",
|
":tpu_compilation_cache_rpc_support_hdrs",
|
||||||
"@com_google_absl//absl/base:core_headers",
|
|
||||||
"@com_google_absl//absl/synchronization",
|
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"//tensorflow/core/lib/core:refcount",
|
|
||||||
"//tensorflow/core/lib/core:threadpool",
|
"//tensorflow/core/lib/core:threadpool",
|
||||||
"//tensorflow/core/platform:coding",
|
"//tensorflow/core/platform:coding",
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
tf_grpc_cc_dependency(),
|
tf_grpc_cc_dependency(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -497,7 +497,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
|
|||||||
*uid = entry->uid;
|
*uid = entry->uid;
|
||||||
// Let the caller know the keys for each of the cached protos.
|
// Let the caller know the keys for each of the cached protos.
|
||||||
*proto_key = entry->proto_key;
|
*proto_key = entry->proto_key;
|
||||||
*may_modify_variables = entry->tpu_program_group->may_modify_variables();
|
*may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
|
||||||
*hlo_metadatas = entry->tpu_program_group->hlo_metadatas();
|
*hlo_metadatas = entry->tpu_program_group->hlo_metadatas();
|
||||||
|
|
||||||
// If the caller didn't supply a per_step_ref_holder then the caller is going
|
// If the caller didn't supply a per_step_ref_holder then the caller is going
|
||||||
|
@ -160,7 +160,7 @@ Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
|
|||||||
<< " in remote subgraph cache status " << s;
|
<< " in remote subgraph cache status " << s;
|
||||||
TF_RETURN_IF_ERROR(s);
|
TF_RETURN_IF_ERROR(s);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(FillCacheEntryFromGetTpuProgramResponse(
|
TF_RETURN_IF_ERROR(DeserializeRpcResponseToCacheEntry(
|
||||||
local_proto_key, &response, cache_entry));
|
local_proto_key, &response, cache_entry));
|
||||||
cache_.emplace(local_proto_key, (*cache_entry));
|
cache_.emplace(local_proto_key, (*cache_entry));
|
||||||
cache_size_ += (*cache_entry)->size;
|
cache_size_ += (*cache_entry)->size;
|
||||||
|
@ -14,9 +14,15 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#if defined(LIBTFTPU)
|
#if defined(LIBTFTPU)
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||||
#endif // LIBTFTPU
|
#endif
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
@ -26,18 +32,127 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
|
|||||||
|
|
||||||
#if defined(LIBTFTPU)
|
#if defined(LIBTFTPU)
|
||||||
template <>
|
template <>
|
||||||
Status FillCacheEntryFromGetTpuProgramResponse<GetTpuProgramResponseExternal>(
|
Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
|
||||||
absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
|
absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
|
||||||
std::shared_ptr<CacheEntry>* cache_entry) {
|
std::shared_ptr<CacheEntry>* cache_entry) {
|
||||||
// TODO(b/162904194): implement this method.
|
CHECK_NE(response, nullptr);
|
||||||
LOG(FATAL) << "Not implemented yet.";
|
CHECK_NE(cache_entry, nullptr);
|
||||||
|
*cache_entry = std::make_shared<CacheEntry>();
|
||||||
|
CacheEntry& entry = **cache_entry;
|
||||||
|
entry.key = local_proto_key;
|
||||||
|
|
||||||
|
if (response->is_empty()) {
|
||||||
|
entry.size = 0;
|
||||||
|
} else {
|
||||||
|
// When we lookup from remote cache, we fetch a TPU program for a specific
|
||||||
|
// core hence we allocate TPU program group for a single program.
|
||||||
|
entry.tpu_program_group = TpuProgramGroup::Create(/*count=*/1);
|
||||||
|
|
||||||
|
TpuSerializedProto serialized_response_proto =
|
||||||
|
stream_executor::tpu::SerializeProto(*response);
|
||||||
|
auto cleanup = xla::MakeCleanup([&serialized_response_proto]() {
|
||||||
|
stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
|
||||||
|
});
|
||||||
|
// TODO(b/166575150): can be optimized by sending the buffer over the gRPC
|
||||||
|
// without an extra deserializing.
|
||||||
|
TpuProgramGroup* tpu_program_group =
|
||||||
|
tensorflow::down_cast<TpuProgramGroup*>(entry.tpu_program_group.get());
|
||||||
|
TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromProto(
|
||||||
|
/*index=*/0, serialized_response_proto));
|
||||||
|
entry.size = entry.tpu_program_group->program_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void SendGetTpuProgramResponseHelper(
|
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
|
||||||
const TpuCompilationCacheEntry& cache_entry,
|
const TpuCompilationCacheEntry& cache_entry) {
|
||||||
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn) {
|
if (cache_entry.tpu_program_group() == nullptr) {
|
||||||
// TODO(b/162904194): implement this method.
|
// It's possible that the sharding/unsharding entry does not exist, but the
|
||||||
LOG(FATAL) << "Not implemented yet.";
|
// main entry must exist.
|
||||||
|
GetTpuProgramResponseExternal header;
|
||||||
|
header.set_is_empty(true);
|
||||||
|
std::string encoded_header;
|
||||||
|
if (!header.AppendToString(&encoded_header)) {
|
||||||
|
return errors::Internal("Failed to serialize TPU program metadata.");
|
||||||
|
}
|
||||||
|
::grpc::Slice slice(encoded_header);
|
||||||
|
return std::vector<::grpc::Slice>{slice};
|
||||||
|
}
|
||||||
|
|
||||||
|
const TpuProgramGroup* tpu_program_group =
|
||||||
|
tensorflow::down_cast<const TpuProgramGroup*>(
|
||||||
|
cache_entry.tpu_program_group());
|
||||||
|
CHECK_NE(tpu_program_group, nullptr);
|
||||||
|
CHECK_GE(tpu_program_group->program_count(), 0);
|
||||||
|
CHECK_GE(cache_entry.core_index(), 0);
|
||||||
|
CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
|
||||||
|
const int64 program_size = tpu_program_group->program_size();
|
||||||
|
if (program_size > INT_MAX) {
|
||||||
|
return errors::Internal("TPU program exceeded 2 GiB.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TpuExecutableSerializedProto executable;
|
||||||
|
auto cleanup_executable = xla::MakeCleanup([&executable]() {
|
||||||
|
if (executable.size > 0) {
|
||||||
|
stream_executor::tpu::SerializedProto_Free(executable);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
auto get_executable_status = tpu_program_group->SerializeExecutable(
|
||||||
|
cache_entry.core_index(), &executable);
|
||||||
|
if (!get_executable_status.ok()) {
|
||||||
|
return errors::Internal("Failed to serialize TPU program.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode and serialize header fields.
|
||||||
|
GetTpuProgramResponseExternal header;
|
||||||
|
header.mutable_proto()->ParseFromArray(executable.bytes, executable.size);
|
||||||
|
header.set_is_empty(false);
|
||||||
|
|
||||||
|
HostComputeMetadataSerializedProto host_compute_metadata;
|
||||||
|
auto cleanup_host_compute_metadata =
|
||||||
|
xla::MakeCleanup([&host_compute_metadata]() {
|
||||||
|
if (host_compute_metadata.size > 0) {
|
||||||
|
stream_executor::tpu::SerializedProto_Free(host_compute_metadata);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Status get_host_compute_metadata_status =
|
||||||
|
tpu_program_group->SerializeHostComputeMetadata(cache_entry.core_index(),
|
||||||
|
&host_compute_metadata);
|
||||||
|
if (!get_host_compute_metadata_status.ok()) {
|
||||||
|
return errors::Internal("Failed to serialize host compute metadata.");
|
||||||
|
}
|
||||||
|
tf2xla::HostComputeMetadata host_compute_metadata_proto =
|
||||||
|
stream_executor::tpu::DeserializeProto<tf2xla::HostComputeMetadata>(
|
||||||
|
host_compute_metadata);
|
||||||
|
*header.mutable_host_compute_metadata() =
|
||||||
|
std::move(host_compute_metadata_proto);
|
||||||
|
|
||||||
|
bool may_modify_variables =
|
||||||
|
tpu_program_group->may_modify_variables(cache_entry.core_index());
|
||||||
|
header.set_may_modify_variables(may_modify_variables);
|
||||||
|
|
||||||
|
CompilerMetadataSerializedProto compiler_metadata;
|
||||||
|
auto cleanup_compiler_metadata = xla::MakeCleanup([&compiler_metadata]() {
|
||||||
|
if (compiler_metadata.size > 0) {
|
||||||
|
stream_executor::tpu::SerializedProto_Free(compiler_metadata);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Status get_compiler_metadata_status =
|
||||||
|
tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
|
||||||
|
&compiler_metadata);
|
||||||
|
if (!get_compiler_metadata_status.ok()) {
|
||||||
|
return errors::Internal("Failed to serialize compiler metadata.");
|
||||||
|
}
|
||||||
|
header.mutable_compiler_metadata()->mutable_data()->assign(
|
||||||
|
compiler_metadata.bytes, compiler_metadata.size);
|
||||||
|
|
||||||
|
std::string encoded_header;
|
||||||
|
if (!header.AppendToString(&encoded_header)) {
|
||||||
|
return errors::Internal("Failed to serialize TPU program metadata.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
|
||||||
}
|
}
|
||||||
#endif // LIBTFTPU
|
#endif // LIBTFTPU
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
|
@ -20,7 +20,9 @@ limitations under the License.
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "grpcpp/support/slice.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||||
@ -82,14 +84,13 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials();
|
|||||||
// Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The
|
// Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The
|
||||||
// `cache_entry` will be instantiated by the function.
|
// `cache_entry` will be instantiated by the function.
|
||||||
template <typename ResponseType>
|
template <typename ResponseType>
|
||||||
Status FillCacheEntryFromGetTpuProgramResponse(
|
Status DeserializeRpcResponseToCacheEntry(
|
||||||
const absl::string_view local_proto_key, ResponseType* response,
|
const absl::string_view local_proto_key, ResponseType* response,
|
||||||
std::shared_ptr<CacheEntry>* cache_entry);
|
std::shared_ptr<CacheEntry>* cache_entry);
|
||||||
|
|
||||||
// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel.
|
// Serializes `TpuCompilationCacheEntry` to gRPC bufer slices.
|
||||||
void SendGetTpuProgramResponseHelper(
|
xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
|
||||||
const TpuCompilationCacheEntry& cache_entry,
|
const TpuCompilationCacheEntry& cache_entry);
|
||||||
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn);
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <chrono> // NOLINT
|
#include <chrono> // NOLINT
|
||||||
|
|
||||||
|
#include "grpcpp/support/byte_buffer.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||||
#include "tensorflow/core/platform/coding.h"
|
#include "tensorflow/core/platform/coding.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
|
||||||
@ -125,15 +126,17 @@ void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) {
|
|||||||
CHECK_NE(call->request.fetch_target(),
|
CHECK_NE(call->request.fetch_target(),
|
||||||
tpu::CompilationCacheFetchTarget::MAIN);
|
tpu::CompilationCacheFetchTarget::MAIN);
|
||||||
}
|
}
|
||||||
return SendGetTpuProgramResponseHelper(
|
|
||||||
cache_entry,
|
xla::StatusOr<std::vector<::grpc::Slice>> buffer_slices =
|
||||||
[&call](::grpc::ByteBuffer* buffer, ::grpc::Status error_status) {
|
tpu::SerializeCacheEntryToBufferSlices(cache_entry);
|
||||||
if (buffer == nullptr) {
|
|
||||||
return call->SendResponse(error_status);
|
if (!buffer_slices.ok()) {
|
||||||
|
return call->SendResponse(ToGrpcStatus(buffer_slices.status()));
|
||||||
}
|
}
|
||||||
call->response = *buffer;
|
|
||||||
|
call->response =
|
||||||
|
::grpc::ByteBuffer{&buffer_slices.ValueOrDie()[0], buffer_slices->size()};
|
||||||
return call->SendResponse(::grpc::Status());
|
return call->SendResponse(::grpc::Status());
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {
|
void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||||
#include "tensorflow/core/tpu/libtftpu.h"
|
#include "tensorflow/core/tpu/libtftpu.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
|
||||||
typedef struct XLA_TpuProgram XLA_TpuProgram;
|
typedef struct XLA_TpuProgram XLA_TpuProgram;
|
||||||
@ -24,6 +25,21 @@ typedef struct XLA_TpuProgram XLA_TpuProgram;
|
|||||||
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
|
// Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
|
||||||
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
|
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
|
||||||
|
|
||||||
|
struct TpuExecutableSerializedProto {
|
||||||
|
const char* bytes;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CompilerMetadataSerializedProto {
|
||||||
|
const char* bytes;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HostComputeMetadataSerializedProto {
|
||||||
|
const char* bytes;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
// Creates a new TPU program.
|
// Creates a new TPU program.
|
||||||
@ -67,7 +83,7 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
|
|||||||
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
|
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
|
||||||
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
|
const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
|
||||||
|
|
||||||
// Check if TPU program has sharding.
|
// Checks if TPU program has sharding.
|
||||||
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
|
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
|
||||||
const XLA_TpuProgram* tpu_program);
|
const XLA_TpuProgram* tpu_program);
|
||||||
|
|
||||||
@ -76,6 +92,27 @@ TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
|
|||||||
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
|
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
|
||||||
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
|
XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
|
||||||
|
|
||||||
|
// Gets TPU executable proto from a `tpu_program`.
|
||||||
|
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
|
||||||
|
const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
|
||||||
|
SE_Status* status);
|
||||||
|
|
||||||
|
// Gets compilation metadata proto from a `tpu_program`.
|
||||||
|
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
|
||||||
|
const XLA_TpuProgram* tpu_program,
|
||||||
|
CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status);
|
||||||
|
|
||||||
|
// Gets host transfer metadata proto from a `tpu_program`.
|
||||||
|
TFTPU_CAPI_EXPORT void TpuProgram_SerializeHostComputeMetadata(
|
||||||
|
const XLA_TpuProgram* tpu_program,
|
||||||
|
HostComputeMetadataSerializedProto* host_compute_metadata,
|
||||||
|
SE_Status* status);
|
||||||
|
|
||||||
|
// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
|
||||||
|
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
|
||||||
|
TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
|
||||||
|
SE_Status* status);
|
||||||
|
|
||||||
struct TfTpu_TpuProgramApiFn {
|
struct TfTpu_TpuProgramApiFn {
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
|
||||||
@ -90,6 +127,10 @@ struct TfTpu_TpuProgramApiFn {
|
|||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeHostComputeMetadata);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
@ -29,11 +29,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
namespace se_tpu = ::stream_executor::tpu;
|
namespace se_tpu = ::stream_executor::tpu;
|
||||||
|
|
||||||
using stream_executor::port::Status;
|
using stream_executor::port::Status;
|
||||||
using stream_executor::port::StatusOr;
|
using stream_executor::port::StatusOr;
|
||||||
using xla::Shape;
|
using xla::Shape;
|
||||||
@ -195,7 +192,20 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() {
|
|||||||
tpu_programs_.clear();
|
tpu_programs_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ Status TpuProgramGroup::Build(
|
/*static*/
|
||||||
|
std::unique_ptr<TpuProgramGroup> TpuProgramGroup::Create(int count) {
|
||||||
|
auto tpu_program_group = std::make_unique<TpuProgramGroup>();
|
||||||
|
std::vector<XLA_TpuProgram*> tpu_programs;
|
||||||
|
tpu_programs.resize(count);
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
tpu_programs[i] = TpuProgramApiFn()->TpuProgram_NewFn();
|
||||||
|
}
|
||||||
|
tpu_program_group->set_tpu_programs(tpu_programs);
|
||||||
|
return tpu_program_group;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*static*/
|
||||||
|
Status TpuProgramGroup::Build(
|
||||||
const TPUCompileMetadataProto& metadata,
|
const TPUCompileMetadataProto& metadata,
|
||||||
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
|
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
|
||||||
const std::vector<ShardingAndIndex>& arg_core_mapping,
|
const std::vector<ShardingAndIndex>& arg_core_mapping,
|
||||||
@ -283,7 +293,7 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<bool>& TpuProgramGroup::may_modify_variables() const {
|
const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
|
||||||
return may_modify_variables_;
|
return may_modify_variables_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,6 +302,15 @@ void TpuProgramGroup::set_may_modify_variables(
|
|||||||
may_modify_variables_ = may_modify_variables;
|
may_modify_variables_ = may_modify_variables;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TpuProgramGroup::may_modify_variables(int index) const {
|
||||||
|
CHECK_GE(index, 0);
|
||||||
|
CHECK_LT(index, tpu_programs_.size());
|
||||||
|
bool may_modify_variables;
|
||||||
|
TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
|
||||||
|
&may_modify_variables);
|
||||||
|
return may_modify_variables;
|
||||||
|
}
|
||||||
|
|
||||||
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
|
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
|
||||||
return tpu_programs_;
|
return tpu_programs_;
|
||||||
}
|
}
|
||||||
@ -371,5 +390,47 @@ std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
|
|||||||
}
|
}
|
||||||
return tpu_programs;
|
return tpu_programs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TpuProgramGroup::DeserializeFromProto(int index,
|
||||||
|
TpuSerializedProto proto) {
|
||||||
|
CHECK_GE(index, 0);
|
||||||
|
CHECK_LT(index, tpu_programs_.size());
|
||||||
|
StatusHelper status;
|
||||||
|
CHECK_NE(tpu_programs_[index], nullptr);
|
||||||
|
TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
|
||||||
|
proto, tpu_programs_[index], status.c_status);
|
||||||
|
return status.status();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TpuProgramGroup::SerializeExecutable(
|
||||||
|
int index, TpuExecutableSerializedProto* executable) const {
|
||||||
|
CHECK_GE(index, 0);
|
||||||
|
CHECK_LT(index, tpu_programs_.size());
|
||||||
|
StatusHelper status;
|
||||||
|
TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn(
|
||||||
|
tpu_programs_[index], executable, status.c_status);
|
||||||
|
return status.status();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TpuProgramGroup::SerializeCompilerMetadata(
|
||||||
|
int index, CompilerMetadataSerializedProto* compiler_metadata) const {
|
||||||
|
CHECK_GE(index, 0);
|
||||||
|
CHECK_LT(index, tpu_programs_.size());
|
||||||
|
StatusHelper status;
|
||||||
|
TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn(
|
||||||
|
tpu_programs_[index], compiler_metadata, status.c_status);
|
||||||
|
return status.status();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TpuProgramGroup::SerializeHostComputeMetadata(
|
||||||
|
int index,
|
||||||
|
HostComputeMetadataSerializedProto* host_compute_metadata) const {
|
||||||
|
CHECK_GE(index, 0);
|
||||||
|
CHECK_LT(index, tpu_programs_.size());
|
||||||
|
StatusHelper status;
|
||||||
|
TpuProgramApiFn()->TpuProgram_SerializeHostComputeMetadataFn(
|
||||||
|
tpu_programs_[index], host_compute_metadata, status.c_status);
|
||||||
|
return status.status();
|
||||||
|
}
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -15,9 +15,11 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
|
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
|
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
#include "tensorflow/compiler/xla/client/compile_only_client.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||||
@ -102,6 +104,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
|||||||
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
|
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
|
||||||
TpuProgramGroupInterface* tpu_program_group_interface);
|
TpuProgramGroupInterface* tpu_program_group_interface);
|
||||||
|
|
||||||
|
// Creates the `count` instances of uninitialized `XLA_TpuPrograms`.
|
||||||
|
static std::unique_ptr<TpuProgramGroup> Create(int count);
|
||||||
|
|
||||||
// Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
|
// Initializes `TpuProgramGroup` object with `xla_tpu_programs`.
|
||||||
void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
|
void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);
|
||||||
|
|
||||||
@ -122,8 +127,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
|||||||
Status LogCompilationStats(const TpuCompilationCacheKey& key,
|
Status LogCompilationStats(const TpuCompilationCacheKey& key,
|
||||||
absl::Duration duration) override;
|
absl::Duration duration) override;
|
||||||
|
|
||||||
const std::vector<bool>& may_modify_variables() const override;
|
const std::vector<bool>& may_modify_variables_list() const override;
|
||||||
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
|
void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
|
||||||
|
bool may_modify_variables(int index) const override;
|
||||||
|
|
||||||
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
|
const std::vector<XLA_TpuProgram*>& tpu_programs() const;
|
||||||
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
|
std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
|
||||||
@ -137,6 +143,25 @@ class TpuProgramGroup : public TpuProgramGroupInterface {
|
|||||||
const xla::HloProto* hlo_metadata(int index) const;
|
const xla::HloProto* hlo_metadata(int index) const;
|
||||||
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
|
absl::Span<const xla::HloProto* const> hlo_metadatas() const override;
|
||||||
|
|
||||||
|
// Deserializes `GetTpuProgramResponse` proto into an `XLA_TpuProgram` for
|
||||||
|
// the given core `index`.
|
||||||
|
Status DeserializeFromProto(int index, TpuSerializedProto proto);
|
||||||
|
|
||||||
|
// Serializes executable proto from the TPU program for the given core
|
||||||
|
// `index`.
|
||||||
|
Status SerializeExecutable(int index,
|
||||||
|
TpuExecutableSerializedProto* executable) const;
|
||||||
|
|
||||||
|
// Serializes compiler metadata of the TPU program for the given core `index`.
|
||||||
|
Status SerializeCompilerMetadata(
|
||||||
|
int index, CompilerMetadataSerializedProto* compiler_metadata) const;
|
||||||
|
|
||||||
|
// Serializes host compute metadata of the TPU program for the given core
|
||||||
|
// `index`.
|
||||||
|
Status SerializeHostComputeMetadata(
|
||||||
|
int index,
|
||||||
|
HostComputeMetadataSerializedProto* host_compute_metadata) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void RefreshHloMetadatasPtrs();
|
void RefreshHloMetadatasPtrs();
|
||||||
|
|
||||||
|
@ -61,7 +61,11 @@ class TpuProgramGroupInterface {
|
|||||||
|
|
||||||
// Boolean array to indicate if the modification of variables are
|
// Boolean array to indicate if the modification of variables are
|
||||||
// allowed.
|
// allowed.
|
||||||
virtual const std::vector<bool>& may_modify_variables() const = 0;
|
virtual const std::vector<bool>& may_modify_variables_list() const = 0;
|
||||||
|
|
||||||
|
// Gets may modify variables value of the TPU program for the given core
|
||||||
|
// `index`.
|
||||||
|
virtual bool may_modify_variables(int index) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
|
@ -72,6 +72,11 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
|
|||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
|
||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
|
||||||
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
|
||||||
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable);
|
||||||
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata);
|
||||||
|
TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeHostComputeMetadata);
|
||||||
|
TFTPU_SET_FN(tpu_program_fn,
|
||||||
|
TpuProgram_DeserializeFromGetTpuProgramResponseProto);
|
||||||
|
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -32,12 +32,13 @@ namespace tpu {
|
|||||||
|
|
||||||
using SerializedProto = TpuSerializedProto;
|
using SerializedProto = TpuSerializedProto;
|
||||||
|
|
||||||
// Serializes a proto and put the result in the given SerializedProto* argument.
|
// Serializes a `proto` and put the result in the given `SerializedProtoType*`
|
||||||
|
// argument.
|
||||||
//
|
//
|
||||||
// Users should call SerializedProto_Free on `serialized_proto` afterwards.
|
// Users should call SerializedProto_Free on `serialized_proto` afterwards.
|
||||||
template <class Proto>
|
template <class ProtoType, class SerializedProtoType>
|
||||||
inline void SerializeProto(const Proto& proto,
|
inline void SerializeProto(const ProtoType& proto,
|
||||||
SerializedProto* serialized_proto) {
|
SerializedProtoType* serialized_proto) {
|
||||||
auto size = proto.ByteSizeLong();
|
auto size = proto.ByteSizeLong();
|
||||||
auto bytes = new char[size];
|
auto bytes = new char[size];
|
||||||
CHECK(proto.SerializeToArray(bytes, size));
|
CHECK(proto.SerializeToArray(bytes, size));
|
||||||
@ -48,8 +49,8 @@ inline void SerializeProto(const Proto& proto,
|
|||||||
// Serializes a proto and return the result as a SerializedProto value.
|
// Serializes a proto and return the result as a SerializedProto value.
|
||||||
//
|
//
|
||||||
// Users should call SerializedProto_Free on the return value afterwards.
|
// Users should call SerializedProto_Free on the return value afterwards.
|
||||||
template <class Proto>
|
template <class ProtoType>
|
||||||
inline SerializedProto SerializeProto(const Proto& proto) {
|
inline SerializedProto SerializeProto(const ProtoType& proto) {
|
||||||
SerializedProto serialized_proto;
|
SerializedProto serialized_proto;
|
||||||
SerializeProto(proto, &serialized_proto);
|
SerializeProto(proto, &serialized_proto);
|
||||||
return serialized_proto;
|
return serialized_proto;
|
||||||
@ -57,9 +58,9 @@ inline SerializedProto SerializeProto(const Proto& proto) {
|
|||||||
|
|
||||||
// Deserializes a buffer and return the corresponding proto. If the buffer is
|
// Deserializes a buffer and return the corresponding proto. If the buffer is
|
||||||
// empty, return an empty proto.
|
// empty, return an empty proto.
|
||||||
template <class Proto>
|
template <class ProtoType, class SerializedProtoType>
|
||||||
inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
|
inline ProtoType DeserializeProto(const SerializedProtoType& serialized_proto) {
|
||||||
Proto proto;
|
ProtoType proto;
|
||||||
if (serialized_proto.bytes != nullptr) {
|
if (serialized_proto.bytes != nullptr) {
|
||||||
CHECK_GT(serialized_proto.size, 0);
|
CHECK_GT(serialized_proto.size, 0);
|
||||||
CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
|
CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
|
||||||
@ -69,7 +70,8 @@ inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Releases the memory allocated for serialized protos.
|
// Releases the memory allocated for serialized protos.
|
||||||
inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
|
template <class SerializedProtoType>
|
||||||
|
inline void SerializedProto_Free(const SerializedProtoType& serialized_proto) {
|
||||||
CHECK_NE(serialized_proto.bytes, nullptr);
|
CHECK_NE(serialized_proto.bytes, nullptr);
|
||||||
CHECK_GT(serialized_proto.size, 0);
|
CHECK_GT(serialized_proto.size, 0);
|
||||||
delete[] serialized_proto.bytes;
|
delete[] serialized_proto.bytes;
|
||||||
|
Loading…
Reference in New Issue
Block a user