TPU op internal refactor.
PiperOrigin-RevId: 318569224 Change-Id: I93bb934321b166c9ce8d7dc3b8275abe215d0e1e
This commit is contained in:
parent
eba2e3ca1c
commit
cd463b2bab
@ -35,7 +35,7 @@ struct CompilationCacheKeyProperty {
|
||||
const char* mlir_module;
|
||||
const int32_t* device_ids;
|
||||
size_t device_ids_size;
|
||||
size_t guaranteed_constants_size;
|
||||
int32_t guaranteed_constants_size;
|
||||
uint64_t function_library_fingerprint;
|
||||
int32_t num_cores_per_replica;
|
||||
int32_t num_replicas;
|
||||
|
@ -70,14 +70,12 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
|
||||
|
||||
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
||||
// data to compute the fingerprint.
|
||||
std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata,
|
||||
const Tensor* guaranteed_constants,
|
||||
size_t guaranteed_constants_size) {
|
||||
std::string GuaranteedConstFingerprint(
|
||||
const string& fingerprint_in_metadata,
|
||||
const OpInputList& guaranteed_constants) {
|
||||
if (fingerprint_in_metadata.empty()) {
|
||||
uint64_t fingerprint = 0;
|
||||
for (size_t i = 0; i < guaranteed_constants_size; ++i) {
|
||||
const Tensor& constant = guaranteed_constants[i];
|
||||
// TODO(henrytan): constant.tensor_data() may be uninitialized.
|
||||
for (const Tensor& constant : guaranteed_constants) {
|
||||
fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
|
||||
fingerprint, constant.tensor_data().data(),
|
||||
constant.tensor_data().size());
|
||||
@ -92,8 +90,7 @@ std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata,
|
||||
// evaluation of `guaranteed_const_fingerprint()` callback.
|
||||
TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
||||
absl::string_view mlir_module, const Tensor* guaranteed_constants,
|
||||
size_t guaranteed_constants_size,
|
||||
absl::string_view mlir_module, const OpInputList& guaranteed_constants,
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
const TPUCompileMetadataProto& metadata,
|
||||
const TpuMeshStateInterface& mesh_state) {
|
||||
@ -119,7 +116,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
mlir_module.data(),
|
||||
flattened_device_ids.data(),
|
||||
flattened_device_ids.size(),
|
||||
guaranteed_constants_size,
|
||||
guaranteed_constants.size(),
|
||||
function_library_fingerprint,
|
||||
metadata.num_cores_per_replica(),
|
||||
metadata.num_replicas(),
|
||||
@ -133,7 +130,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
|
||||
// Guaranteed constants can be different across sessions. Use session_handle
|
||||
// and guaranteed_const fingerprint to guarantee no collision.
|
||||
if (guaranteed_constants != nullptr && guaranteed_constants_size > 0) {
|
||||
if (guaranteed_constants.size() > 0) {
|
||||
key.has_guaranteed_const = true;
|
||||
key.session_handle = metadata.session_handle();
|
||||
// Both `metadata` and `guaranteed_constants` lifetime are captured by
|
||||
@ -142,29 +139,15 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
// lifetime of the compilation cache lookups.
|
||||
string fingerprint;
|
||||
key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
|
||||
guaranteed_constants_size,
|
||||
fingerprint]() mutable {
|
||||
if (fingerprint.empty()) {
|
||||
fingerprint = GuaranteedConstFingerprint(
|
||||
metadata.guaranteed_const_fingerprint(), guaranteed_constants,
|
||||
guaranteed_constants_size);
|
||||
metadata.guaranteed_const_fingerprint(), guaranteed_constants);
|
||||
}
|
||||
return fingerprint;
|
||||
};
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
||||
absl::string_view mlir_module, const OpInputList& guaranteed_constants,
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
const TPUCompileMetadataProto& metadata,
|
||||
const TpuMeshStateInterface& mesh_state) {
|
||||
return CreateCompilationCacheKey(
|
||||
function_name, function_library_fingerprint, mlir_module,
|
||||
(guaranteed_constants.size() > 0 ? &guaranteed_constants[0] : nullptr),
|
||||
guaranteed_constants.size(), dynamic_shapes, metadata, mesh_state);
|
||||
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -34,13 +34,6 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
const TPUCompileMetadataProto& metadata,
|
||||
const TpuMeshStateInterface& mesh_state);
|
||||
TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
||||
absl::string_view mlir_module, const Tensor* guaranteed_constants,
|
||||
size_t guaranteed_constants_size,
|
||||
const std::vector<TensorShape>& dynamic_shapes,
|
||||
const TPUCompileMetadataProto& metadata,
|
||||
const TpuMeshStateInterface& mesh_state);
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user