From b4ee2c429428da3a3f7023ac692997e9d42fe910 Mon Sep 17 00:00:00 2001 From: Henry Tan Date: Tue, 1 Sep 2020 12:45:00 -0700 Subject: [PATCH] Implement Serialization/Deserialization for remote compilation cache. PiperOrigin-RevId: 329561676 Change-Id: If9df54e30be753420ff70e88eb179751fc3de4ae --- tensorflow/core/tpu/kernels/BUILD | 27 ++-- .../tpu_compilation_cache_interface.cc | 2 +- .../tpu_compilation_cache_rpc_lookup.cc | 2 +- .../tpu_compilation_cache_rpc_support.cc | 133 ++++++++++++++++-- .../tpu_compilation_cache_rpc_support.h | 11 +- .../kernels/tpu_compilation_cache_service.cc | 21 +-- .../core/tpu/kernels/tpu_program_c_api.h | 43 +++++- .../core/tpu/kernels/tpu_program_group.cc | 71 +++++++++- .../core/tpu/kernels/tpu_program_group.h | 27 +++- .../tpu/kernels/tpu_program_group_interface.h | 6 +- tensorflow/core/tpu/tpu_library_init_fns.inc | 5 + tensorflow/stream_executor/tpu/proto_helper.h | 22 +-- 12 files changed, 316 insertions(+), 54 deletions(-) diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 8f97b2e45fe..c47fdc0f9d2 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -317,6 +317,7 @@ cc_library( ":tpu_mesh_state_interface", ":tpu_program_c_api_hdrs", ":tpu_program_group_interface", + "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/client:compile_only_client", @@ -467,6 +468,7 @@ cc_library( deps = [ ":tpu_util_c_api_hdrs", "//tensorflow/core/tpu:libtftpu_header", + "//tensorflow/stream_executor/tpu:c_api_decl", "//tensorflow/stream_executor/tpu:proto_helper", ], alwayslink = True, @@ -515,8 +517,8 @@ cc_library( DEFAULT: [], }), deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], + WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], # build_cleaner: keep }) + [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_interface", @@ -536,8 +538,16 @@ cc_library( DEFAULT: [], }), deps = [ + ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_proto_cc", ":tpu_compilation_cache_rpc_support_hdrs", + ":tpu_program_group", + "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "//tensorflow/core/tpu:tpu_config_c_api", + "//tensorflow/stream_executor/tpu:proto_helper", ], ) @@ -613,27 +623,22 @@ cc_library( }), deps = select({ WITH_TPU_SUPPORT: [ - ":tpu_compilation_cache_rpc_support", - ":tpu_compilation_cache_proto_cc", + ":tpu_compilation_cache_rpc_support", # build_cleaner: keep + ":tpu_compilation_cache_proto_cc", # build_cleaner: keep ], DEFAULT: [ - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep + "//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc", # build_cleaner: keep ], }) + [ ":tpu_compilation_cache_common_proto_cc", - ":tpu_compilation_cache_entry", ":tpu_compilation_cache_grpc", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_rpc_support_hdrs", - "@com_google_absl//absl/base:core_headers", - 
"@com_google_absl//absl/synchronization", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "//tensorflow/core/lib/core:refcount", "//tensorflow/core/lib/core:threadpool", "//tensorflow/core/platform:coding", - "//tensorflow/core:protos_all_cc", tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc index 4cd2b864203..1928303b21c 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc @@ -497,7 +497,7 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( *uid = entry->uid; // Let the caller know the keys for each of the cached protos. *proto_key = entry->proto_key; - *may_modify_variables = entry->tpu_program_group->may_modify_variables(); + *may_modify_variables = entry->tpu_program_group->may_modify_variables_list(); *hlo_metadatas = entry->tpu_program_group->hlo_metadatas(); // If the caller didn't supply a per_step_ref_holder then the caller is going diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc index e3560de0c44..8b0fb674682 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc @@ -160,7 +160,7 @@ Status TpuCompilationCacheRpcLookup::RemoteLookupLocked( << " in remote subgraph cache status " << s; TF_RETURN_IF_ERROR(s); - TF_RETURN_IF_ERROR(FillCacheEntryFromGetTpuProgramResponse( + TF_RETURN_IF_ERROR(DeserializeRpcResponseToCacheEntry( local_proto_key, &response, cache_entry)); cache_.emplace(local_proto_key, (*cache_entry)); cache_size_ += (*cache_entry)->size; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc index 0e77edf4ecf..60f71a18f1c 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc @@ -14,9 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/platform/casts.h" #if defined(LIBTFTPU) #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" -#endif // LIBTFTPU +#endif +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group.h" +#include "tensorflow/stream_executor/tpu/proto_helper.h" namespace tensorflow { namespace tpu { @@ -26,18 +32,127 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() { #if defined(LIBTFTPU) template <> -Status FillCacheEntryFromGetTpuProgramResponse( +Status DeserializeRpcResponseToCacheEntry( absl::string_view local_proto_key, GetTpuProgramResponseExternal* response, std::shared_ptr* cache_entry) { - // TODO(b/162904194): implement this method. 
- LOG(FATAL) << "Not implemented yet."; + CHECK_NE(response, nullptr); + CHECK_NE(cache_entry, nullptr); + *cache_entry = std::make_shared(); + CacheEntry& entry = **cache_entry; + entry.key = local_proto_key; + + if (response->is_empty()) { + entry.size = 0; + } else { + // When we lookup from remote cache, we fetch a TPU program for a specific + // core hence we allocate TPU program group for a single program. + entry.tpu_program_group = TpuProgramGroup::Create(/*count=*/1); + + TpuSerializedProto serialized_response_proto = + stream_executor::tpu::SerializeProto(*response); + auto cleanup = xla::MakeCleanup([&serialized_response_proto]() { + stream_executor::tpu::SerializedProto_Free(serialized_response_proto); + }); + // TODO(b/166575150): can be optimized by sending the buffer over the gRPC + // without an extra deserializing. + TpuProgramGroup* tpu_program_group = + tensorflow::down_cast(entry.tpu_program_group.get()); + TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromProto( + /*index=*/0, serialized_response_proto)); + entry.size = entry.tpu_program_group->program_size(); + } + + return Status::OK(); } -void SendGetTpuProgramResponseHelper( - const TpuCompilationCacheEntry& cache_entry, - std::function call_fn) { - // TODO(b/162904194): implement this method. - LOG(FATAL) << "Not implemented yet."; +xla::StatusOr> SerializeCacheEntryToBufferSlices( + const TpuCompilationCacheEntry& cache_entry) { + if (cache_entry.tpu_program_group() == nullptr) { + // It's possible that the sharding/unsharding entry does not exist, but the + // main entry must exist. + GetTpuProgramResponseExternal header; + header.set_is_empty(true); + std::string encoded_header; + if (!header.AppendToString(&encoded_header)) { + return errors::Internal("Failed to serialize TPU program metadata."); + } + ::grpc::Slice slice(encoded_header); + return std::vector<::grpc::Slice>{slice}; + } + + const TpuProgramGroup* tpu_program_group = + tensorflow::down_cast( + cache_entry.tpu_program_group()); + CHECK_NE(tpu_program_group, nullptr); + CHECK_GE(tpu_program_group->program_count(), 0); + CHECK_GE(cache_entry.core_index(), 0); + CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count()); + const int64 program_size = tpu_program_group->program_size(); + if (program_size > INT_MAX) { + return errors::Internal("TPU program exceeded 2 GiB."); + } + + TpuExecutableSerializedProto executable; + auto cleanup_executable = xla::MakeCleanup([&executable]() { + if (executable.size > 0) { + stream_executor::tpu::SerializedProto_Free(executable); + } + }); + auto get_executable_status = tpu_program_group->SerializeExecutable( + cache_entry.core_index(), &executable); + if (!get_executable_status.ok()) { + return errors::Internal("Failed to serialize TPU program."); + } + + // Encode and serialize header fields. 
+ GetTpuProgramResponseExternal header; + header.mutable_proto()->ParseFromArray(executable.bytes, executable.size); + header.set_is_empty(false); + + HostComputeMetadataSerializedProto host_compute_metadata; + auto cleanup_host_compute_metadata = + xla::MakeCleanup([&host_compute_metadata]() { + if (host_compute_metadata.size > 0) { + stream_executor::tpu::SerializedProto_Free(host_compute_metadata); + } + }); + Status get_host_compute_metadata_status = + tpu_program_group->SerializeHostComputeMetadata(cache_entry.core_index(), + &host_compute_metadata); + if (!get_host_compute_metadata_status.ok()) { + return errors::Internal("Failed to serialize host compute metadata."); + } + tf2xla::HostComputeMetadata host_compute_metadata_proto = + stream_executor::tpu::DeserializeProto( + host_compute_metadata); + *header.mutable_host_compute_metadata() = + std::move(host_compute_metadata_proto); + + bool may_modify_variables = + tpu_program_group->may_modify_variables(cache_entry.core_index()); + header.set_may_modify_variables(may_modify_variables); + + CompilerMetadataSerializedProto compiler_metadata; + auto cleanup_compiler_metadata = xla::MakeCleanup([&compiler_metadata]() { + if (compiler_metadata.size > 0) { + stream_executor::tpu::SerializedProto_Free(compiler_metadata); + } + }); + Status get_compiler_metadata_status = + tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(), + &compiler_metadata); + if (!get_compiler_metadata_status.ok()) { + return errors::Internal("Failed to serialize compiler metadata."); + } + header.mutable_compiler_metadata()->mutable_data()->assign( + compiler_metadata.bytes, compiler_metadata.size); + + std::string encoded_header; + if (!header.AppendToString(&encoded_header)) { + return errors::Internal("Failed to serialize TPU program metadata."); + } + + return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)}; } #endif // LIBTFTPU } // namespace tpu diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h index 6749138d710..c9099ec7a27 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h @@ -20,7 +20,9 @@ limitations under the License. #include #include #include +#include +#include "grpcpp/support/slice.h" #include "absl/strings/string_view.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" @@ -82,14 +84,13 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials(); // Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The // `cache_entry` will be instantiated by the function. template -Status FillCacheEntryFromGetTpuProgramResponse( +Status DeserializeRpcResponseToCacheEntry( const absl::string_view local_proto_key, ResponseType* response, std::shared_ptr* cache_entry); -// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel. -void SendGetTpuProgramResponseHelper( - const TpuCompilationCacheEntry& cache_entry, - std::function call_fn); +// Serializes `TpuCompilationCacheEntry` to gRPC bufer slices. 
+xla::StatusOr> SerializeCacheEntryToBufferSlices( + const TpuCompilationCacheEntry& cache_entry); } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc index 5abd0c7f26b..f7a87c266e5 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.cc @@ -16,6 +16,7 @@ limitations under the License. #include // NOLINT +#include "grpcpp/support/byte_buffer.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/platform/coding.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h" @@ -125,15 +126,17 @@ void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) { CHECK_NE(call->request.fetch_target(), tpu::CompilationCacheFetchTarget::MAIN); } - return SendGetTpuProgramResponseHelper( - cache_entry, - [&call](::grpc::ByteBuffer* buffer, ::grpc::Status error_status) { - if (buffer == nullptr) { - return call->SendResponse(error_status); - } - call->response = *buffer; - return call->SendResponse(::grpc::Status()); - }); + + xla::StatusOr> buffer_slices = + tpu::SerializeCacheEntryToBufferSlices(cache_entry); + + if (!buffer_slices.ok()) { + return call->SendResponse(ToGrpcStatus(buffer_slices.status())); + } + + call->response = + ::grpc::ByteBuffer{&buffer_slices.ValueOrDie()[0], buffer_slices->size()}; + return call->SendResponse(::grpc::Status()); } void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) { diff --git a/tensorflow/core/tpu/kernels/tpu_program_c_api.h b/tensorflow/core/tpu/kernels/tpu_program_c_api.h index 41c7d47cf97..1918546491a 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_c_api.h +++ b/tensorflow/core/tpu/kernels/tpu_program_c_api.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" #include "tensorflow/core/tpu/libtftpu.h" +#include "tensorflow/stream_executor/tpu/c_api_decl.h" #include "tensorflow/stream_executor/tpu/proto_helper.h" typedef struct XLA_TpuProgram XLA_TpuProgram; @@ -24,6 +25,21 @@ typedef struct XLA_TpuProgram XLA_TpuProgram; // Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj. enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; +struct TpuExecutableSerializedProto { + const char* bytes; + size_t size; +}; + +struct CompilerMetadataSerializedProto { + const char* bytes; + size_t size; +}; + +struct HostComputeMetadataSerializedProto { + const char* bytes; + size_t size; +}; + extern "C" { // Creates a new TPU program. @@ -67,7 +83,7 @@ TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( const XLA_TpuProgram* tpu_program, bool* may_modify_variables); -// Check if TPU program has sharding. +// Checks if TPU program has sharding. TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( const XLA_TpuProgram* tpu_program); @@ -76,6 +92,27 @@ TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( XLA_TpuProgram* tpu_program, TpuProgramShardingType type); +// Gets TPU executable proto from a `tpu_program`. +TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( + const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, + SE_Status* status); + +// Gets compilation metadata proto from a `tpu_program`. 
+TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( + const XLA_TpuProgram* tpu_program, + CompilerMetadataSerializedProto* compiler_metadata, SE_Status* status); + +// Gets host transfer metadata proto from a `tpu_program`. +TFTPU_CAPI_EXPORT void TpuProgram_SerializeHostComputeMetadata( + const XLA_TpuProgram* tpu_program, + HostComputeMetadataSerializedProto* host_compute_metadata, + SE_Status* status); + +// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. +TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( + TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, + SE_Status* status); + struct TfTpu_TpuProgramApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); @@ -90,6 +127,10 @@ struct TfTpu_TpuProgramApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeHostComputeMetadata); + TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); }; } // extern "C" diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.cc b/tensorflow/core/tpu/kernels/tpu_program_group.cc index ff7a526cd45..f502426fdaa 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.cc +++ b/tensorflow/core/tpu/kernels/tpu_program_group.cc @@ -29,11 +29,8 @@ limitations under the License. namespace tensorflow { namespace tpu { - namespace { - namespace se_tpu = ::stream_executor::tpu; - using stream_executor::port::Status; using stream_executor::port::StatusOr; using xla::Shape; @@ -195,7 +192,20 @@ void TpuProgramGroup::UnloadAndDestroyPrograms() { tpu_programs_.clear(); } -/*static*/ Status TpuProgramGroup::Build( +/*static*/ +std::unique_ptr TpuProgramGroup::Create(int count) { + auto tpu_program_group = std::make_unique(); + std::vector tpu_programs; + tpu_programs.resize(count); + for (int i = 0; i < count; ++i) { + tpu_programs[i] = TpuProgramApiFn()->TpuProgram_NewFn(); + } + tpu_program_group->set_tpu_programs(tpu_programs); + return tpu_program_group; +} + +/*static*/ +Status TpuProgramGroup::Build( const TPUCompileMetadataProto& metadata, const tensorflow::XlaCompiler::CompilationResult& compilation_result, const std::vector& arg_core_mapping, @@ -283,7 +293,7 @@ Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key, return Status::OK(); } -const std::vector& TpuProgramGroup::may_modify_variables() const { +const std::vector& TpuProgramGroup::may_modify_variables_list() const { return may_modify_variables_; } @@ -292,6 +302,15 @@ void TpuProgramGroup::set_may_modify_variables( may_modify_variables_ = may_modify_variables; } +bool TpuProgramGroup::may_modify_variables(int index) const { + CHECK_GE(index, 0); + CHECK_LT(index, tpu_programs_.size()); + bool may_modify_variables; + TpuProgramApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index], + &may_modify_variables); + return may_modify_variables; +} + const std::vector& TpuProgramGroup::tpu_programs() const { return tpu_programs_; } @@ -371,5 +390,47 @@ std::vector TpuProgramGroup::tpu_programs( } return tpu_programs; } + +Status TpuProgramGroup::DeserializeFromProto(int index, + TpuSerializedProto proto) { + CHECK_GE(index, 0); + CHECK_LT(index, tpu_programs_.size()); + StatusHelper status; + 
CHECK_NE(tpu_programs_[index], nullptr); + TpuProgramApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn( + proto, tpu_programs_[index], status.c_status); + return status.status(); +} + +Status TpuProgramGroup::SerializeExecutable( + int index, TpuExecutableSerializedProto* executable) const { + CHECK_GE(index, 0); + CHECK_LT(index, tpu_programs_.size()); + StatusHelper status; + TpuProgramApiFn()->TpuProgram_SerializeTpuExecutableFn( + tpu_programs_[index], executable, status.c_status); + return status.status(); +} + +Status TpuProgramGroup::SerializeCompilerMetadata( + int index, CompilerMetadataSerializedProto* compiler_metadata) const { + CHECK_GE(index, 0); + CHECK_LT(index, tpu_programs_.size()); + StatusHelper status; + TpuProgramApiFn()->TpuProgram_SerializeCompilerMetadataFn( + tpu_programs_[index], compiler_metadata, status.c_status); + return status.status(); +} + +Status TpuProgramGroup::SerializeHostComputeMetadata( + int index, + HostComputeMetadataSerializedProto* host_compute_metadata) const { + CHECK_GE(index, 0); + CHECK_LT(index, tpu_programs_.size()); + StatusHelper status; + TpuProgramApiFn()->TpuProgram_SerializeHostComputeMetadataFn( + tpu_programs_[index], host_compute_metadata, status.c_status); + return status.status(); +} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h index 0fc8bff08de..121596aa2d1 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -15,9 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ #define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ +#include #include #include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -102,6 +104,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface { const absl::optional& xla_device_assignment, TpuProgramGroupInterface* tpu_program_group_interface); + // Creates the `count` instances of uninitialized `XLA_TpuPrograms`. + static std::unique_ptr Create(int count); + // Initializes `TpuProgramGroup` object with `xla_tpu_programs`. void Initialize(absl::Span xla_tpu_programs); @@ -122,8 +127,9 @@ class TpuProgramGroup : public TpuProgramGroupInterface { Status LogCompilationStats(const TpuCompilationCacheKey& key, absl::Duration duration) override; - const std::vector& may_modify_variables() const override; + const std::vector& may_modify_variables_list() const override; void set_may_modify_variables(const std::vector& may_modify_variables); + bool may_modify_variables(int index) const override; const std::vector& tpu_programs() const; std::vector tpu_programs(TpuProgramShardingType type) const; @@ -137,6 +143,25 @@ class TpuProgramGroup : public TpuProgramGroupInterface { const xla::HloProto* hlo_metadata(int index) const; absl::Span hlo_metadatas() const override; + // Deserializes `GetTpuProgramResponse` proto into an `XLA_TpuProgram` for + // the given core `index`. + Status DeserializeFromProto(int index, TpuSerializedProto proto); + + // Serializes executable proto from the TPU program for the given core + // `index`. 
+  Status SerializeExecutable(int index,
+                             TpuExecutableSerializedProto* executable) const;
+
+  // Serializes compiler metadata of the TPU program for the given core `index`.
+  Status SerializeCompilerMetadata(
+      int index, CompilerMetadataSerializedProto* compiler_metadata) const;
+
+  // Serializes host compute metadata of the TPU program for the given core
+  // `index`.
+  Status SerializeHostComputeMetadata(
+      int index,
+      HostComputeMetadataSerializedProto* host_compute_metadata) const;
+
  private:
   void RefreshHloMetadatasPtrs();
 
diff --git a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h
index 4af94f8e1ad..8bf4404859f 100644
--- a/tensorflow/core/tpu/kernels/tpu_program_group_interface.h
+++ b/tensorflow/core/tpu/kernels/tpu_program_group_interface.h
@@ -61,7 +61,11 @@ class TpuProgramGroupInterface {
 
   // Boolean array to indicate if the modification of variables are
   // allowed.
-  virtual const std::vector<bool>& may_modify_variables() const = 0;
+  virtual const std::vector<bool>& may_modify_variables_list() const = 0;
+
+  // Returns whether the TPU program for the given core `index` may modify
+  // variables.
+  virtual bool may_modify_variables(int index) const = 0;
 };
 
 }  // namespace tpu
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 6a1432e27fa..cb8871a60c5 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -72,6 +72,11 @@ tensorflow::Status SetTpuProgramStructFn(void* library_handle) {
   TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetMayModifyVariables);
   TFTPU_SET_FN(tpu_program_fn, TpuProgram_HasSharding);
   TFTPU_SET_FN(tpu_program_fn, TpuProgram_GetTpuProgram);
+  TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeTpuExecutable);
+  TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeCompilerMetadata);
+  TFTPU_SET_FN(tpu_program_fn, TpuProgram_SerializeHostComputeMetadata);
+  TFTPU_SET_FN(tpu_program_fn,
+               TpuProgram_DeserializeFromGetTpuProgramResponseProto);
 
   return tensorflow::Status::OK();
 }
diff --git a/tensorflow/stream_executor/tpu/proto_helper.h b/tensorflow/stream_executor/tpu/proto_helper.h
index 29c322b0e9e..cd231e06c22 100644
--- a/tensorflow/stream_executor/tpu/proto_helper.h
+++ b/tensorflow/stream_executor/tpu/proto_helper.h
@@ -32,12 +32,13 @@ namespace tpu {
 
 using SerializedProto = TpuSerializedProto;
 
-// Serializes a proto and put the result in the given SerializedProto* argument.
+// Serializes a `proto` and puts the result in the given `SerializedProtoType*`
+// argument.
 //
 // Users should call SerializedProto_Free on `serialized_proto` afterwards.
-template <typename Proto>
-inline void SerializeProto(const Proto& proto,
-                           SerializedProto* serialized_proto) {
+template <typename ProtoType, typename SerializedProtoType>
+inline void SerializeProto(const ProtoType& proto,
+                           SerializedProtoType* serialized_proto) {
   auto size = proto.ByteSizeLong();
   auto bytes = new char[size];
   CHECK(proto.SerializeToArray(bytes, size));
@@ -48,8 +49,8 @@ inline void SerializeProto(const Proto& proto,
 // Serializes a proto and return the result as a SerializedProto value.
 //
 // Users should call SerializedProto_Free on the return value afterwards.
-template <typename Proto>
-inline SerializedProto SerializeProto(const Proto& proto) {
+template <typename ProtoType>
+inline SerializedProto SerializeProto(const ProtoType& proto) {
   SerializedProto serialized_proto;
   SerializeProto(proto, &serialized_proto);
   return serialized_proto;
@@ -57,9 +58,9 @@ inline SerializedProto SerializeProto(const Proto& proto) {
 
 // Deserializes a buffer and return the corresponding proto. If the buffer is
 // empty, return an empty proto.
-template <typename Proto>
-inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
-  Proto proto;
+template <typename ProtoType, typename SerializedProtoType>
+inline ProtoType DeserializeProto(const SerializedProtoType& serialized_proto) {
+  ProtoType proto;
   if (serialized_proto.bytes != nullptr) {
     CHECK_GT(serialized_proto.size, 0);
     CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
@@ -69,7 +70,8 @@ inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
 }
 
 // Releases the memory allocated for serialized protos.
-inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
+template <typename SerializedProtoType>
+inline void SerializedProto_Free(const SerializedProtoType& serialized_proto) {
   CHECK_NE(serialized_proto.bytes, nullptr);
   CHECK_GT(serialized_proto.size, 0);
   delete[] serialized_proto.bytes;
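For reference, a minimal usage sketch of the templated helpers above. This is illustrative only and not part of the change; the `RoundTrip` function and the choice of `tf2xla::HostComputeMetadata` as the payload are assumptions.

#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace se_tpu = ::stream_executor::tpu;

// Serializes a proto into the C-layout struct that crosses the C API boundary,
// parses it back, and releases the heap buffer owned by the helper. Assumes a
// non-empty proto, since the helpers CHECK-fail on zero-size buffers.
tf2xla::HostComputeMetadata RoundTrip(const tf2xla::HostComputeMetadata& in) {
  TpuSerializedProto serialized = se_tpu::SerializeProto(in);
  tf2xla::HostComputeMetadata out =
      se_tpu::DeserializeProto<tf2xla::HostComputeMetadata>(serialized);
  // SerializedProto_Free is now a template, so the same call also frees the
  // TpuExecutableSerializedProto-style structs introduced in
  // tpu_program_c_api.h; the cache code above wraps it in xla::MakeCleanup.
  se_tpu::SerializedProto_Free(serialized);
  return out;
}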