Refactor tpu function fingerprint to be a separate library

PiperOrigin-RevId: 354124962
Change-Id: I0f159bae5955b04941c20a61b5c916678cfc25cb
This commit is contained in:
A. Unique TensorFlower 2021-01-27 10:51:00 -08:00 committed by TensorFlower Gardener
parent e1f20e517c
commit 100e0d8ed0
6 changed files with 81 additions and 21 deletions

View File

@ -298,3 +298,16 @@ cc_library(
],
alwayslink = True,
)
cc_library(
name = "tpu_fingerprint_utils",
srcs = ["tpu_fingerprint_utils.cc"],
hdrs = ["tpu_fingerprint_utils.h"],
deps = [
":tpu_compile_interface",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/lib/core:status",
],
)

View File

@ -159,6 +159,7 @@ cc_library(
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/core/tpu:tpu_compile_interface",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu:tpu_fingerprint_utils",
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"//tensorflow/stream_executor/tpu:tpu_topology_external",

View File

@ -66,6 +66,7 @@ limitations under the License.
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
@ -4120,23 +4121,6 @@ DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
return Status::OK();
}
/* static */
Status DistributedTPURewritePass::FingerprintFunctionLibrary(
const FunctionLibraryDefinition& library, uint64* fingerprint) {
// TODO(phawkins): rather than fingerprinting the entire function library,
// consider fingerprinting just the transitive dependencies of a
// computation.
std::string serialized;
FunctionDefLibrary library_proto = library.ToProto();
if (library_proto.ByteSizeLong() >= 1.5 * 1024 * 1024 * 1024) {
LOG(WARNING) << "Serializing large proto, size: "
<< library_proto.ByteSizeLong();
}
TF_RET_CHECK(SerializeToStringDeterministic(library_proto, &serialized));
*fingerprint = TpuCompileInterface::Get()->FingerprintString(serialized);
return Status::OK();
}
// Performs the rewrite on a single TPUReplicate node.
/* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
const string& session_handle, const DeviceSet& device_set,

View File

@ -314,10 +314,6 @@ class DistributedTPURewritePass : public GraphOptimizationPass {
std::vector<::xla::OpSharding>* retval_sharding,
std::vector<std::string>* arg_names);
// Computes a fingerprint of the contents of `library`.
static Status FingerprintFunctionLibrary(
const FunctionLibraryDefinition& library, uint64* fingerprint);
// Populates `*variables` with the "variables" inputs to `index`-th output of
// `node`.
struct VariableInput {

View File

@ -0,0 +1,39 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
namespace tensorflow {
Status FingerprintFunctionLibrary(const FunctionLibraryDefinition& library,
uint64* fingerprint) {
// TODO(phawkins): rather than fingerprinting the entire function library,
// consider fingerprinting just the transitive dependencies of a
// computation.
std::string serialized;
FunctionDefLibrary library_proto = library.ToProto();
if (library_proto.ByteSizeLong() >= 1.5 * 1024 * 1024 * 1024) {
LOG(WARNING) << "Serializing large proto, size: "
<< library_proto.ByteSizeLong();
}
TF_RET_CHECK(SerializeToStringDeterministic(library_proto, &serialized));
*fingerprint = TpuCompileInterface::Get()->FingerprintString(serialized);
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_
#define TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Computes a fingerprint of the contents of `library`.
Status FingerprintFunctionLibrary(const FunctionLibraryDefinition& library,
uint64* fingerprint);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_