From cd463b2bab5eff7eced13ad1a0d95421ccfc8985 Mon Sep 17 00:00:00 2001 From: Henry Tan Date: Fri, 26 Jun 2020 17:00:06 -0700 Subject: [PATCH] TPU op internal refactor. PiperOrigin-RevId: 318569224 Change-Id: I93bb934321b166c9ce8d7dc3b8275abe215d0e1e --- .../core/tpu/kernels/tpu_compile_c_api.h | 2 +- tensorflow/core/tpu/kernels/tpu_op_util.cc | 33 +++++-------------- tensorflow/core/tpu/kernels/tpu_op_util.h | 7 ---- 3 files changed, 9 insertions(+), 33 deletions(-) diff --git a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h index e82df78b3bd..37de24c339b 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h @@ -35,7 +35,7 @@ struct CompilationCacheKeyProperty { const char* mlir_module; const int32_t* device_ids; size_t device_ids_size; - size_t guaranteed_constants_size; + int32_t guaranteed_constants_size; uint64_t function_library_fingerprint; int32_t num_cores_per_replica; int32_t num_replicas; diff --git a/tensorflow/core/tpu/kernels/tpu_op_util.cc b/tensorflow/core/tpu/kernels/tpu_op_util.cc index 477afac6491..b3b675e2734 100644 --- a/tensorflow/core/tpu/kernels/tpu_op_util.cc +++ b/tensorflow/core/tpu/kernels/tpu_op_util.cc @@ -70,14 +70,12 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) { // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor // data to compute the fingerprint. -std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata, - const Tensor* guaranteed_constants, - size_t guaranteed_constants_size) { +std::string GuaranteedConstFingerprint( + const string& fingerprint_in_metadata, + const OpInputList& guaranteed_constants) { if (fingerprint_in_metadata.empty()) { uint64_t fingerprint = 0; - for (size_t i = 0; i < guaranteed_constants_size; ++i) { - const Tensor& constant = guaranteed_constants[i]; - // TODO(henrytan): constant.tensor_data() may be uninitialized. + for (const Tensor& constant : guaranteed_constants) { fingerprint = TpuCompile_CreateGuaranteedConstFingerprint( fingerprint, constant.tensor_data().data(), constant.tensor_data().size()); @@ -92,8 +90,7 @@ std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata, // evaluation of `guaranteed_const_fingerprint()` callback. TpuCompilationCacheKey CreateCompilationCacheKey( absl::string_view function_name, uint64 function_library_fingerprint, - absl::string_view mlir_module, const Tensor* guaranteed_constants, - size_t guaranteed_constants_size, + absl::string_view mlir_module, const OpInputList& guaranteed_constants, const std::vector& dynamic_shapes, const TPUCompileMetadataProto& metadata, const TpuMeshStateInterface& mesh_state) { @@ -119,7 +116,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey( mlir_module.data(), flattened_device_ids.data(), flattened_device_ids.size(), - guaranteed_constants_size, + guaranteed_constants.size(), function_library_fingerprint, metadata.num_cores_per_replica(), metadata.num_replicas(), @@ -133,7 +130,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey( // Guaranteed constants can be different across sessions. Use session_handle // and guaranteed_const fingerprint to guarantee no collision. - if (guaranteed_constants != nullptr && guaranteed_constants_size > 0) { + if (guaranteed_constants.size() > 0) { key.has_guaranteed_const = true; key.session_handle = metadata.session_handle(); // Both `metadata` and `guaranteed_constants` lifetime are captured by @@ -142,29 +139,15 @@ TpuCompilationCacheKey CreateCompilationCacheKey( // lifetime of the compilation cache lookups. string fingerprint; key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants, - guaranteed_constants_size, fingerprint]() mutable { if (fingerprint.empty()) { fingerprint = GuaranteedConstFingerprint( - metadata.guaranteed_const_fingerprint(), guaranteed_constants, - guaranteed_constants_size); + metadata.guaranteed_const_fingerprint(), guaranteed_constants); } return fingerprint; }; } return key; } - -TpuCompilationCacheKey CreateCompilationCacheKey( - absl::string_view function_name, uint64 function_library_fingerprint, - absl::string_view mlir_module, const OpInputList& guaranteed_constants, - const std::vector& dynamic_shapes, - const TPUCompileMetadataProto& metadata, - const TpuMeshStateInterface& mesh_state) { - return CreateCompilationCacheKey( - function_name, function_library_fingerprint, mlir_module, - (guaranteed_constants.size() > 0 ? &guaranteed_constants[0] : nullptr), - guaranteed_constants.size(), dynamic_shapes, metadata, mesh_state); -} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_op_util.h b/tensorflow/core/tpu/kernels/tpu_op_util.h index bbaa05682e6..0a9657ca05e 100644 --- a/tensorflow/core/tpu/kernels/tpu_op_util.h +++ b/tensorflow/core/tpu/kernels/tpu_op_util.h @@ -34,13 +34,6 @@ TpuCompilationCacheKey CreateCompilationCacheKey( const std::vector& dynamic_shapes, const TPUCompileMetadataProto& metadata, const TpuMeshStateInterface& mesh_state); -TpuCompilationCacheKey CreateCompilationCacheKey( - absl::string_view function_name, uint64 function_library_fingerprint, - absl::string_view mlir_module, const Tensor* guaranteed_constants, - size_t guaranteed_constants_size, - const std::vector& dynamic_shapes, - const TPUCompileMetadataProto& metadata, - const TpuMeshStateInterface& mesh_state); } // namespace tpu } // namespace tensorflow