Avoid serializing large protos in OptimizerCSE::NodeHash.

PiperOrigin-RevId: 289688850
Change-Id: I12076a9b6168f9909f9d045a445fa255e7faac55
commit 8aed2672fc
parent 83df634c7e
Author:    A. Unique TensorFlower
Committer: TensorFlower Gardener
Date:      2020-01-14 11:07:38 -08:00


@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/graph/optimizer_cse.h"
 
 #include <iostream>
 #include <unordered_map>
+#include <utility>
 #include <vector>
@@ -49,6 +50,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
@@ -89,38 +91,142 @@ static void FillInputs(const Node* n,
 static size_t kIllegalNodeHash = 0;
 
-size_t OptimizerCSE::NodeHash(const Node* n) {
-  const DataTypeVector& out = n->output_types();
-  string str_to_hash = strings::StrCat(n->type_string(), out.size());
-  for (DataType dt : out) {
-    strings::StrAppend(&str_to_hash, dt);
+class Hasher {
+ public:
+  uint64 hash() { return h_ == kIllegalNodeHash ? kIllegalNodeHash + 1 : h_; }
+
+  void MixString(const string& s) { h_ = Hash64(s.data(), s.size(), h_); }
+
+  void MixInteger(size_t z) { h_ = Hash64Combine(h_, z); }
+
+  void MixProto(const protobuf::MessageLite& msg) {
+    msg.ByteSizeLong();  // Ensure sizes are cached accurately.
+    HashingOutputStream hasher;
+    {
+      // CodedOutputStream doesn't call BackUp until it's destroyed, so we
+      // need it to be destroyed before we call hasher.hash().
+      protobuf::io::CodedOutputStream stream(&hasher);
+      stream.EnableAliasing(true);
+      stream.SetSerializationDeterministic(true);
+      msg.SerializeWithCachedSizes(&stream);
+    }
+    h_ = Hash64Combine(h_, hasher.hash());
   }
 
-  const int N_in = n->num_inputs();
-  strings::StrAppend(&str_to_hash, N_in);
+ private:
+  // HashingOutputStream produces the same exact hash as if you serialized the
+  // proto and hashed it sequentially in kBufSize chunks, except it doesn't
+  // manifest the entire proto into memory at any point.
+  class HashingOutputStream : public protobuf::io::ZeroCopyOutputStream {
+   public:
+    // This kBufSize makes sizeof(HashingOutputStream) == 256.  It's not
+    // chosen for any particular reason except it's a nice even number of
+    // cache lines.
+    static constexpr size_t kBufSize = 228;
+    static constexpr uint64 kDefaultSeed = 2570847921467975139ULL;
+
+    bool Next(void** data, int* size) override {
+      if (i_ == kBufSize) {
+        // Mix the chunk in.
+        Mix(buf_, kBufSize);
+        *data = buf_;
+        *size = kBufSize;
+      } else {
+        *data = buf_ + i_;
+        *size = kBufSize - i_;
+      }
+      // We always set i_ to be past the end, since we've given the rest of
+      // buf_ out.
+      i_ = kBufSize;
+      return true;
+    }
+
+    void BackUp(int count) override { i_ -= count; }
+
+    int64_t ByteCount() const override { return byte_count_; }
+
+    bool WriteAliasedRaw(const void* void_data, int size) override {
+      // We can't do math on void*.
+      const char* data = static_cast<const char*>(void_data);
+      const auto remaining = kBufSize - i_;
+      if (remaining > 0) {
+        if (size < remaining) {
+          memcpy(buf_ + i_, data, size);
+          i_ += size;
+          return true;
+        }
+        memcpy(buf_ + i_, data, remaining);
+        i_ = kBufSize;
+        data += remaining;
+        size -= remaining;
+      }
+      if (i_ == kBufSize) {
+        Mix(buf_, kBufSize);
+        i_ = 0;
+      }
+      while (size >= kBufSize) {
+        Mix(data, kBufSize);
+        data += kBufSize;
+        size -= kBufSize;
+      }
+      memcpy(buf_, data, size);
+      i_ = size;
+      return true;
+    }
+
+    bool AllowsAliasing() const override { return true; }
+
+    uint64 hash() {
+      if (i_ != 0) {
+        Mix(buf_, i_);
+        i_ = 0;
+      }
+      return h_;
+    }
+
+   private:
+    void Mix(const char* p, size_t n) {
+      byte_count_ += n;
+      h_ = Hash64(p, n, h_);
+    }
+
+    char buf_[kBufSize];
+    int i_ = 0;
+    int64_t byte_count_ = 0;
+    uint64 h_ = kDefaultSeed;
+  };
+
+  uint64 h_ = HashingOutputStream::kDefaultSeed;
+};
+
+size_t OptimizerCSE::NodeHash(const Node* n) {
+  Hasher hasher;
+
+  hasher.MixString(n->type_string());
+  hasher.MixInteger(n->output_types().size());
+  for (DataType dt : n->output_types()) {
+    hasher.MixInteger(dt);
+  }
+
+  hasher.MixInteger(n->num_inputs());
   gtl::InlinedVector<const Node*, 4> control_edges;
-  gtl::InlinedVector<std::pair<const Node*, int>, 4> in(N_in);
+  gtl::InlinedVector<std::pair<const Node*, int>, 4> in(n->num_inputs());
   FillInputs(n, &control_edges, &in);
   for (const auto& edge : in) {
-    strings::StrAppend(&str_to_hash, edge.first->id(), edge.second);
+    hasher.MixInteger(edge.first->id());
+    hasher.MixInteger(edge.second);
   }
 
-  size_t h = Hash64(str_to_hash);
-
 #if !defined(__ANDROID__)
   // Hash the attrs.  For example, this makes sure different constants
   // end up in different hash buckets.
-  string tmp;
+  size_t attr_hashes = 0;
   for (const auto& attr : n->attrs()) {
-    tmp = attr.first;
-    attr.second.AppendToString(&tmp);
-    // Add hashes of attrs, so the order of attrs doesn't matter.
-    h += Hash32(tmp.data(), tmp.size(), 0x87341245);
+    Hasher h;
+    h.MixString(attr.first);
+    h.MixProto(attr.second);
+    attr_hashes = Hash64CombineUnordered(attr_hashes, h.hash());
   }
+  hasher.MixInteger(attr_hashes);
 #endif
 
-  if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
-  return h;
+  return hasher.hash();
 }
 
 static bool HasRefInput(const Node* n) {
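
Note: the sketch below illustrates the streaming-hash technique MixProto uses, outside of TensorFlow: serialize a proto through a ZeroCopyOutputStream that mixes each fixed-size chunk into a running hash as it fills, instead of materializing the whole serialization in memory. It is a minimal sketch, not the commit's code: FNV-1a stands in for TensorFlow's Hash64 (both chain their state through a seed argument, which is what makes chunk-by-chunk hashing equal whole-buffer hashing), and FileDescriptorProto is used only because it ships with libprotobuf.

// hashing_stream_demo.cc -- minimal sketch; assumes libprotobuf is available.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"

// Chainable chunk hash: feeding data in pieces yields the same result as
// feeding it all at once, the property HashingOutputStream relies on.
uint64_t Fnv1a(const char* p, size_t n, uint64_t h) {
  for (size_t i = 0; i < n; ++i) {
    h ^= static_cast<unsigned char>(p[i]);
    h *= 1099511628211ULL;
  }
  return h;
}

// Stripped-down analogue of the commit's HashingOutputStream: hand the
// serializer a small scratch buffer and mix each filled chunk into the hash.
class ChunkHashStream : public google::protobuf::io::ZeroCopyOutputStream {
 public:
  static constexpr int kBufSize = 64;

  bool Next(void** data, int* size) override {
    if (i_ == kBufSize) {  // Buffer full: mix it in and reuse it.
      h_ = Fnv1a(buf_, kBufSize, h_);
      i_ = 0;
    }
    *data = buf_ + i_;
    *size = kBufSize - i_;
    byte_count_ += kBufSize - i_;
    i_ = kBufSize;  // Assume the caller uses everything until BackUp().
    return true;
  }
  void BackUp(int count) override {
    i_ -= count;
    byte_count_ -= count;
  }
  int64_t ByteCount() const override { return byte_count_; }

  uint64_t hash() {
    if (i_ > 0) {  // Flush the final partial chunk.
      h_ = Fnv1a(buf_, i_, h_);
      i_ = 0;
    }
    return h_;
  }

 private:
  char buf_[kBufSize];
  int i_ = 0;
  int64_t byte_count_ = 0;
  uint64_t h_ = 14695981039346656037ULL;  // FNV-1a offset basis.
};

int main() {
  google::protobuf::FileDescriptorProto msg;
  msg.set_name("example.proto");

  // Old approach: materialize the full serialization, then hash it.
  std::string full;
  msg.SerializeToString(&full);
  uint64_t full_hash =
      Fnv1a(full.data(), full.size(), 14695981039346656037ULL);

  // New approach: stream the serialization through the hashing stream.
  msg.ByteSizeLong();  // Cache sizes, as MixProto does.
  ChunkHashStream hashing;
  {
    google::protobuf::io::CodedOutputStream coded(&hashing);
    coded.SetSerializationDeterministic(true);
    msg.SerializeWithCachedSizes(&coded);
  }  // CodedOutputStream backs up unused buffer space in its destructor.

  std::cout << (full_hash == hashing.hash() ? "equal" : "different") << "\n";
}

The block scope around the CodedOutputStream mirrors the commit: the stream only returns its unused buffer tail via BackUp() when destroyed, so it must die before hash() flushes the final partial chunk.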
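
A second detail worth spelling out: NodeHash folds the per-attribute hashes together with Hash64CombineUnordered so the result does not depend on the iteration order of the attr map. Any commutative, associative combine has this property. A tiny sketch of the idea (assumption: wrapping uint64 addition as the combine; the real Hash64CombineUnordered lives in tensorflow/core/lib/hash/hash.h and std::hash stands in for the per-attr Hasher):

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Commutative combine: CombineUnordered(a, b) == CombineUnordered(b, a).
uint64_t CombineUnordered(uint64_t a, uint64_t b) { return a + b; }

int main() {
  std::hash<std::string> h;  // Stand-in for hashing one attr name + value.
  const std::string attrs[] = {"T", "value", "dtype"};

  // Fold the same per-attr hashes in two different orders.
  uint64_t forward = 0, backward = 0;
  for (int i = 0; i < 3; ++i) forward = CombineUnordered(forward, h(attrs[i]));
  for (int i = 2; i >= 0; --i)
    backward = CombineUnordered(backward, h(attrs[i]));

  // Same set of attrs => same combined hash, regardless of order.
  std::cout << (forward == backward ? "order-independent" : "order-dependent")
            << "\n";
}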