TPU op internal refactor.

PiperOrigin-RevId: 318569224
Change-Id: I93bb934321b166c9ce8d7dc3b8275abe215d0e1e
This commit is contained in:
Henry Tan 2020-06-26 17:00:06 -07:00 committed by TensorFlower Gardener
parent eba2e3ca1c
commit cd463b2bab
3 changed files with 9 additions and 33 deletions

View File

@ -35,7 +35,7 @@ struct CompilationCacheKeyProperty {
const char* mlir_module; const char* mlir_module;
const int32_t* device_ids; const int32_t* device_ids;
size_t device_ids_size; size_t device_ids_size;
size_t guaranteed_constants_size; int32_t guaranteed_constants_size;
uint64_t function_library_fingerprint; uint64_t function_library_fingerprint;
int32_t num_cores_per_replica; int32_t num_cores_per_replica;
int32_t num_replicas; int32_t num_replicas;

View File

@ -70,14 +70,12 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
// data to compute the fingerprint. // data to compute the fingerprint.
std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata, std::string GuaranteedConstFingerprint(
const Tensor* guaranteed_constants, const string& fingerprint_in_metadata,
size_t guaranteed_constants_size) { const OpInputList& guaranteed_constants) {
if (fingerprint_in_metadata.empty()) { if (fingerprint_in_metadata.empty()) {
uint64_t fingerprint = 0; uint64_t fingerprint = 0;
for (size_t i = 0; i < guaranteed_constants_size; ++i) { for (const Tensor& constant : guaranteed_constants) {
const Tensor& constant = guaranteed_constants[i];
// TODO(henrytan): constant.tensor_data() may be uninitialized.
fingerprint = TpuCompile_CreateGuaranteedConstFingerprint( fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
fingerprint, constant.tensor_data().data(), fingerprint, constant.tensor_data().data(),
constant.tensor_data().size()); constant.tensor_data().size());
@ -92,8 +90,7 @@ std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata,
// evaluation of `guaranteed_const_fingerprint()` callback. // evaluation of `guaranteed_const_fingerprint()` callback.
TpuCompilationCacheKey CreateCompilationCacheKey( TpuCompilationCacheKey CreateCompilationCacheKey(
absl::string_view function_name, uint64 function_library_fingerprint, absl::string_view function_name, uint64 function_library_fingerprint,
absl::string_view mlir_module, const Tensor* guaranteed_constants, absl::string_view mlir_module, const OpInputList& guaranteed_constants,
size_t guaranteed_constants_size,
const std::vector<TensorShape>& dynamic_shapes, const std::vector<TensorShape>& dynamic_shapes,
const TPUCompileMetadataProto& metadata, const TPUCompileMetadataProto& metadata,
const TpuMeshStateInterface& mesh_state) { const TpuMeshStateInterface& mesh_state) {
@ -119,7 +116,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
mlir_module.data(), mlir_module.data(),
flattened_device_ids.data(), flattened_device_ids.data(),
flattened_device_ids.size(), flattened_device_ids.size(),
guaranteed_constants_size, guaranteed_constants.size(),
function_library_fingerprint, function_library_fingerprint,
metadata.num_cores_per_replica(), metadata.num_cores_per_replica(),
metadata.num_replicas(), metadata.num_replicas(),
@ -133,7 +130,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
// Guaranteed constants can be different across sessions. Use session_handle // Guaranteed constants can be different across sessions. Use session_handle
// and guaranteed_const fingerprint to guarantee no collision. // and guaranteed_const fingerprint to guarantee no collision.
if (guaranteed_constants != nullptr && guaranteed_constants_size > 0) { if (guaranteed_constants.size() > 0) {
key.has_guaranteed_const = true; key.has_guaranteed_const = true;
key.session_handle = metadata.session_handle(); key.session_handle = metadata.session_handle();
// Both `metadata` and `guaranteed_constants` lifetime are captured by // Both `metadata` and `guaranteed_constants` lifetime are captured by
@ -142,29 +139,15 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
// lifetime of the compilation cache lookups. // lifetime of the compilation cache lookups.
string fingerprint; string fingerprint;
key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants, key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
guaranteed_constants_size,
fingerprint]() mutable { fingerprint]() mutable {
if (fingerprint.empty()) { if (fingerprint.empty()) {
fingerprint = GuaranteedConstFingerprint( fingerprint = GuaranteedConstFingerprint(
metadata.guaranteed_const_fingerprint(), guaranteed_constants, metadata.guaranteed_const_fingerprint(), guaranteed_constants);
guaranteed_constants_size);
} }
return fingerprint; return fingerprint;
}; };
} }
return key; return key;
} }
TpuCompilationCacheKey CreateCompilationCacheKey(
absl::string_view function_name, uint64 function_library_fingerprint,
absl::string_view mlir_module, const OpInputList& guaranteed_constants,
const std::vector<TensorShape>& dynamic_shapes,
const TPUCompileMetadataProto& metadata,
const TpuMeshStateInterface& mesh_state) {
return CreateCompilationCacheKey(
function_name, function_library_fingerprint, mlir_module,
(guaranteed_constants.size() > 0 ? &guaranteed_constants[0] : nullptr),
guaranteed_constants.size(), dynamic_shapes, metadata, mesh_state);
}
} // namespace tpu } // namespace tpu
} // namespace tensorflow } // namespace tensorflow

View File

@ -34,13 +34,6 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
const std::vector<TensorShape>& dynamic_shapes, const std::vector<TensorShape>& dynamic_shapes,
const TPUCompileMetadataProto& metadata, const TPUCompileMetadataProto& metadata,
const TpuMeshStateInterface& mesh_state); const TpuMeshStateInterface& mesh_state);
TpuCompilationCacheKey CreateCompilationCacheKey(
absl::string_view function_name, uint64 function_library_fingerprint,
absl::string_view mlir_module, const Tensor* guaranteed_constants,
size_t guaranteed_constants_size,
const std::vector<TensorShape>& dynamic_shapes,
const TPUCompileMetadataProto& metadata,
const TpuMeshStateInterface& mesh_state);
} // namespace tpu } // namespace tpu
} // namespace tensorflow } // namespace tensorflow