Refactor tpu function fingerprint to be a separate library
PiperOrigin-RevId: 354124962 Change-Id: I0f159bae5955b04941c20a61b5c916678cfc25cb
This commit is contained in:
parent
e1f20e517c
commit
100e0d8ed0
@ -298,3 +298,16 @@ cc_library(
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_fingerprint_utils",
|
||||
srcs = ["tpu_fingerprint_utils.cc"],
|
||||
hdrs = ["tpu_fingerprint_utils.h"],
|
||||
deps = [
|
||||
":tpu_compile_interface",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal_impl",
|
||||
"//tensorflow/core/lib/core:status",
|
||||
],
|
||||
)
|
||||
|
@ -159,6 +159,7 @@ cc_library(
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||
"//tensorflow/core/tpu:tpu_compile_interface",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"//tensorflow/core/tpu:tpu_fingerprint_utils",
|
||||
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
|
||||
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
||||
"//tensorflow/stream_executor/tpu:tpu_topology_external",
|
||||
|
@ -66,6 +66,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
|
||||
#include "tensorflow/core/tpu/tpu_compile_interface.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
|
||||
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
@ -4120,23 +4121,6 @@ DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status DistributedTPURewritePass::FingerprintFunctionLibrary(
|
||||
const FunctionLibraryDefinition& library, uint64* fingerprint) {
|
||||
// TODO(phawkins): rather than fingerprinting the entire function library,
|
||||
// consider fingerprinting just the transitive dependencies of a
|
||||
// computation.
|
||||
std::string serialized;
|
||||
FunctionDefLibrary library_proto = library.ToProto();
|
||||
if (library_proto.ByteSizeLong() >= 1.5 * 1024 * 1024 * 1024) {
|
||||
LOG(WARNING) << "Serializing large proto, size: "
|
||||
<< library_proto.ByteSizeLong();
|
||||
}
|
||||
TF_RET_CHECK(SerializeToStringDeterministic(library_proto, &serialized));
|
||||
*fingerprint = TpuCompileInterface::Get()->FingerprintString(serialized);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Performs the rewrite on a single TPUReplicate node.
|
||||
/* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
|
||||
const string& session_handle, const DeviceSet& device_set,
|
||||
|
@ -314,10 +314,6 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
|
||||
std::vector<::xla::OpSharding>* retval_sharding,
|
||||
std::vector<std::string>* arg_names);
|
||||
|
||||
// Computes a fingerprint of the contents of `library`.
|
||||
static Status FingerprintFunctionLibrary(
|
||||
const FunctionLibraryDefinition& library, uint64* fingerprint);
|
||||
|
||||
// Populates `*variables` with the "variables" inputs to `index`-th output of
|
||||
// `node`.
|
||||
struct VariableInput {
|
||||
|
39
tensorflow/core/tpu/tpu_fingerprint_utils.cc
Normal file
39
tensorflow/core/tpu/tpu_fingerprint_utils.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/tpu/tpu_compile_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status FingerprintFunctionLibrary(const FunctionLibraryDefinition& library,
|
||||
uint64* fingerprint) {
|
||||
// TODO(phawkins): rather than fingerprinting the entire function library,
|
||||
// consider fingerprinting just the transitive dependencies of a
|
||||
// computation.
|
||||
std::string serialized;
|
||||
FunctionDefLibrary library_proto = library.ToProto();
|
||||
if (library_proto.ByteSizeLong() >= 1.5 * 1024 * 1024 * 1024) {
|
||||
LOG(WARNING) << "Serializing large proto, size: "
|
||||
<< library_proto.ByteSizeLong();
|
||||
}
|
||||
TF_RET_CHECK(SerializeToStringDeterministic(library_proto, &serialized));
|
||||
*fingerprint = TpuCompileInterface::Get()->FingerprintString(serialized);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
27
tensorflow/core/tpu/tpu_fingerprint_utils.h
Normal file
27
tensorflow/core/tpu/tpu_fingerprint_utils.h
Normal file
@ -0,0 +1,27 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_
|
||||
#define TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Computes a fingerprint of the contents of `library`.
|
||||
Status FingerprintFunctionLibrary(const FunctionLibraryDefinition& library,
|
||||
uint64* fingerprint);
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_
|
Loading…
Reference in New Issue
Block a user