TPU op internal refactor.
PiperOrigin-RevId: 318569224 Change-Id: I93bb934321b166c9ce8d7dc3b8275abe215d0e1e
This commit is contained in:
parent
eba2e3ca1c
commit
cd463b2bab
@ -35,7 +35,7 @@ struct CompilationCacheKeyProperty {
|
|||||||
const char* mlir_module;
|
const char* mlir_module;
|
||||||
const int32_t* device_ids;
|
const int32_t* device_ids;
|
||||||
size_t device_ids_size;
|
size_t device_ids_size;
|
||||||
size_t guaranteed_constants_size;
|
int32_t guaranteed_constants_size;
|
||||||
uint64_t function_library_fingerprint;
|
uint64_t function_library_fingerprint;
|
||||||
int32_t num_cores_per_replica;
|
int32_t num_cores_per_replica;
|
||||||
int32_t num_replicas;
|
int32_t num_replicas;
|
||||||
|
@ -70,14 +70,12 @@ std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
|
|||||||
|
|
||||||
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
||||||
// data to compute the fingerprint.
|
// data to compute the fingerprint.
|
||||||
std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata,
|
std::string GuaranteedConstFingerprint(
|
||||||
const Tensor* guaranteed_constants,
|
const string& fingerprint_in_metadata,
|
||||||
size_t guaranteed_constants_size) {
|
const OpInputList& guaranteed_constants) {
|
||||||
if (fingerprint_in_metadata.empty()) {
|
if (fingerprint_in_metadata.empty()) {
|
||||||
uint64_t fingerprint = 0;
|
uint64_t fingerprint = 0;
|
||||||
for (size_t i = 0; i < guaranteed_constants_size; ++i) {
|
for (const Tensor& constant : guaranteed_constants) {
|
||||||
const Tensor& constant = guaranteed_constants[i];
|
|
||||||
// TODO(henrytan): constant.tensor_data() may be uninitialized.
|
|
||||||
fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
|
fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
|
||||||
fingerprint, constant.tensor_data().data(),
|
fingerprint, constant.tensor_data().data(),
|
||||||
constant.tensor_data().size());
|
constant.tensor_data().size());
|
||||||
@ -92,8 +90,7 @@ std::string GuaranteedConstFingerprint(const string& fingerprint_in_metadata,
|
|||||||
// evaluation of `guaranteed_const_fingerprint()` callback.
|
// evaluation of `guaranteed_const_fingerprint()` callback.
|
||||||
TpuCompilationCacheKey CreateCompilationCacheKey(
|
TpuCompilationCacheKey CreateCompilationCacheKey(
|
||||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
absl::string_view function_name, uint64 function_library_fingerprint,
|
||||||
absl::string_view mlir_module, const Tensor* guaranteed_constants,
|
absl::string_view mlir_module, const OpInputList& guaranteed_constants,
|
||||||
size_t guaranteed_constants_size,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
const std::vector<TensorShape>& dynamic_shapes,
|
||||||
const TPUCompileMetadataProto& metadata,
|
const TPUCompileMetadataProto& metadata,
|
||||||
const TpuMeshStateInterface& mesh_state) {
|
const TpuMeshStateInterface& mesh_state) {
|
||||||
@ -119,7 +116,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
|||||||
mlir_module.data(),
|
mlir_module.data(),
|
||||||
flattened_device_ids.data(),
|
flattened_device_ids.data(),
|
||||||
flattened_device_ids.size(),
|
flattened_device_ids.size(),
|
||||||
guaranteed_constants_size,
|
guaranteed_constants.size(),
|
||||||
function_library_fingerprint,
|
function_library_fingerprint,
|
||||||
metadata.num_cores_per_replica(),
|
metadata.num_cores_per_replica(),
|
||||||
metadata.num_replicas(),
|
metadata.num_replicas(),
|
||||||
@ -133,7 +130,7 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
|||||||
|
|
||||||
// Guaranteed constants can be different across sessions. Use session_handle
|
// Guaranteed constants can be different across sessions. Use session_handle
|
||||||
// and guaranteed_const fingerprint to guarantee no collision.
|
// and guaranteed_const fingerprint to guarantee no collision.
|
||||||
if (guaranteed_constants != nullptr && guaranteed_constants_size > 0) {
|
if (guaranteed_constants.size() > 0) {
|
||||||
key.has_guaranteed_const = true;
|
key.has_guaranteed_const = true;
|
||||||
key.session_handle = metadata.session_handle();
|
key.session_handle = metadata.session_handle();
|
||||||
// Both `metadata` and `guaranteed_constants` lifetime are captured by
|
// Both `metadata` and `guaranteed_constants` lifetime are captured by
|
||||||
@ -142,29 +139,15 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
|||||||
// lifetime of the compilation cache lookups.
|
// lifetime of the compilation cache lookups.
|
||||||
string fingerprint;
|
string fingerprint;
|
||||||
key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
|
key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
|
||||||
guaranteed_constants_size,
|
|
||||||
fingerprint]() mutable {
|
fingerprint]() mutable {
|
||||||
if (fingerprint.empty()) {
|
if (fingerprint.empty()) {
|
||||||
fingerprint = GuaranteedConstFingerprint(
|
fingerprint = GuaranteedConstFingerprint(
|
||||||
metadata.guaranteed_const_fingerprint(), guaranteed_constants,
|
metadata.guaranteed_const_fingerprint(), guaranteed_constants);
|
||||||
guaranteed_constants_size);
|
|
||||||
}
|
}
|
||||||
return fingerprint;
|
return fingerprint;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return key;
|
return key;
|
||||||
}
|
}
|
||||||
|
|
||||||
TpuCompilationCacheKey CreateCompilationCacheKey(
|
|
||||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
|
||||||
absl::string_view mlir_module, const OpInputList& guaranteed_constants,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
|
||||||
const TPUCompileMetadataProto& metadata,
|
|
||||||
const TpuMeshStateInterface& mesh_state) {
|
|
||||||
return CreateCompilationCacheKey(
|
|
||||||
function_name, function_library_fingerprint, mlir_module,
|
|
||||||
(guaranteed_constants.size() > 0 ? &guaranteed_constants[0] : nullptr),
|
|
||||||
guaranteed_constants.size(), dynamic_shapes, metadata, mesh_state);
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -34,13 +34,6 @@ TpuCompilationCacheKey CreateCompilationCacheKey(
|
|||||||
const std::vector<TensorShape>& dynamic_shapes,
|
const std::vector<TensorShape>& dynamic_shapes,
|
||||||
const TPUCompileMetadataProto& metadata,
|
const TPUCompileMetadataProto& metadata,
|
||||||
const TpuMeshStateInterface& mesh_state);
|
const TpuMeshStateInterface& mesh_state);
|
||||||
TpuCompilationCacheKey CreateCompilationCacheKey(
|
|
||||||
absl::string_view function_name, uint64 function_library_fingerprint,
|
|
||||||
absl::string_view mlir_module, const Tensor* guaranteed_constants,
|
|
||||||
size_t guaranteed_constants_size,
|
|
||||||
const std::vector<TensorShape>& dynamic_shapes,
|
|
||||||
const TPUCompileMetadataProto& metadata,
|
|
||||||
const TpuMeshStateInterface& mesh_state);
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user