Avoid serializing large protos in OptimizerCSE::NodeHash.

PiperOrigin-RevId: 289688850
Change-Id: I12076a9b6168f9909f9d045a445fa255e7faac55
Parent commit: 83df634c7e
This commit: 8aed2672fc
@ -38,6 +38,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/graph/optimizer_cse.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -49,6 +50,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -89,38 +91,142 @@ static void FillInputs(const Node* n,
|
||||
|
||||
static size_t kIllegalNodeHash = 0;
|
||||
|
||||
size_t OptimizerCSE::NodeHash(const Node* n) {
|
||||
const DataTypeVector& out = n->output_types();
|
||||
string str_to_hash = strings::StrCat(n->type_string(), out.size());
|
||||
for (DataType dt : out) {
|
||||
strings::StrAppend(&str_to_hash, dt);
|
||||
class Hasher {
|
||||
public:
|
||||
uint64 hash() { return h_ == kIllegalNodeHash ? kIllegalNodeHash + 1 : h_; }
|
||||
|
||||
void MixString(const string& s) { h_ = Hash64(s.data(), s.size(), h_); }
|
||||
|
||||
void MixInteger(size_t z) { h_ = Hash64Combine(h_, z); }
|
||||
|
||||
void MixProto(const protobuf::MessageLite& msg) {
|
||||
msg.ByteSizeLong(); // Ensure sizes are cached accurately.
|
||||
HashingOutputStream hasher;
|
||||
{
|
||||
// CodedOutputStream doesn't call BackUp until it's destroyed, so we need
|
||||
// it to be destroyed before we call hasher.hash().
|
||||
protobuf::io::CodedOutputStream stream(&hasher);
|
||||
stream.EnableAliasing(true);
|
||||
stream.SetSerializationDeterministic(true);
|
||||
msg.SerializeWithCachedSizes(&stream);
|
||||
}
|
||||
h_ = Hash64Combine(h_, hasher.hash());
|
||||
}
|
||||
|
||||
const int N_in = n->num_inputs();
|
||||
strings::StrAppend(&str_to_hash, N_in);
|
||||
private:
|
||||
// HashingOutputStream produces the same exact hash as if you serialized the
|
||||
// proto and hashed it sequentially in kBufSize chunks, except it doesn't
|
||||
// manifest the entire proto into memory at any point.
|
||||
class HashingOutputStream : public protobuf::io::ZeroCopyOutputStream {
|
||||
public:
|
||||
// This kBufSize makes sizeof(HashingOutputStream) == 256. It's not chosen
|
||||
// for any particular reason except it's a nice even number of cache lines.
|
||||
static constexpr size_t kBufSize = 228;
|
||||
static constexpr uint64 kDefaultSeed = 2570847921467975139ULL;
|
||||
bool Next(void** data, int* size) override {
|
||||
if (i_ == kBufSize) {
|
||||
// Mix the chunk in.
|
||||
Mix(buf_, kBufSize);
|
||||
*data = buf_;
|
||||
*size = kBufSize;
|
||||
} else {
|
||||
*data = buf_ + i_;
|
||||
*size = kBufSize - i_;
|
||||
}
|
||||
// We always set i_ to be past the end, since we've given the rest of buf_
|
||||
// out.
|
||||
i_ = kBufSize;
|
||||
return true;
|
||||
}
|
||||
|
||||
void BackUp(int count) override { i_ -= count; }
|
||||
|
||||
int64_t ByteCount() const override { return byte_count_; }
|
||||
|
||||
bool WriteAliasedRaw(const void* void_data, int size) override {
|
||||
// We can't do math on void*.
|
||||
const char* data = static_cast<const char*>(void_data);
|
||||
const auto remaining = kBufSize - i_;
|
||||
if (remaining > 0) {
|
||||
if (size < remaining) {
|
||||
memcpy(buf_ + i_, data, size);
|
||||
i_ += size;
|
||||
return true;
|
||||
}
|
||||
memcpy(buf_ + i_, data, remaining);
|
||||
i_ = kBufSize;
|
||||
data += remaining;
|
||||
size -= remaining;
|
||||
}
|
||||
if (i_ == kBufSize) {
|
||||
Mix(buf_, kBufSize);
|
||||
i_ = 0;
|
||||
}
|
||||
while (size >= kBufSize) {
|
||||
Mix(data, kBufSize);
|
||||
data += kBufSize;
|
||||
size -= kBufSize;
|
||||
}
|
||||
memcpy(buf_, data, size);
|
||||
i_ = size;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AllowsAliasing() const override { return true; }
|
||||
|
||||
uint64 hash() {
|
||||
if (i_ != 0) {
|
||||
Mix(buf_, i_);
|
||||
i_ = 0;
|
||||
}
|
||||
return h_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Mix(const char* p, size_t n) {
|
||||
byte_count_ += n;
|
||||
h_ = Hash64(p, n, h_);
|
||||
}
|
||||
char buf_[kBufSize];
|
||||
int i_ = 0;
|
||||
int64_t byte_count_ = 0;
|
||||
uint64 h_ = kDefaultSeed;
|
||||
};
|
||||
|
||||
uint64 h_ = HashingOutputStream::kDefaultSeed;
|
||||
};
|
||||
|
||||
size_t OptimizerCSE::NodeHash(const Node* n) {
|
||||
Hasher hasher;
|
||||
hasher.MixString(n->type_string());
|
||||
hasher.MixInteger(n->output_types().size());
|
||||
for (DataType dt : n->output_types()) {
|
||||
hasher.MixInteger(dt);
|
||||
}
|
||||
|
||||
hasher.MixInteger(n->num_inputs());
|
||||
gtl::InlinedVector<const Node*, 4> control_edges;
|
||||
gtl::InlinedVector<std::pair<const Node*, int>, 4> in(N_in);
|
||||
gtl::InlinedVector<std::pair<const Node*, int>, 4> in(n->num_inputs());
|
||||
FillInputs(n, &control_edges, &in);
|
||||
for (const auto& edge : in) {
|
||||
strings::StrAppend(&str_to_hash, edge.first->id(), edge.second);
|
||||
hasher.MixInteger(edge.first->id());
|
||||
hasher.MixInteger(edge.second);
|
||||
}
|
||||
|
||||
size_t h = Hash64(str_to_hash);
|
||||
|
||||
#if !defined(__ANDROID__)
|
||||
// Hash the attrs. For example, this makes sure different constants
|
||||
// end up in different hash buckets.
|
||||
string tmp;
|
||||
size_t attr_hashes = 0;
|
||||
for (const auto& attr : n->attrs()) {
|
||||
tmp = attr.first;
|
||||
attr.second.AppendToString(&tmp);
|
||||
// Add hashes of attrs, so the order of attrs doesn't matter.
|
||||
h += Hash32(tmp.data(), tmp.size(), 0x87341245);
|
||||
Hasher h;
|
||||
h.MixString(attr.first);
|
||||
h.MixProto(attr.second);
|
||||
attr_hashes = Hash64CombineUnordered(attr_hashes, h.hash());
|
||||
}
|
||||
hasher.MixInteger(attr_hashes);
|
||||
#endif
|
||||
|
||||
if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
|
||||
return h;
|
||||
return hasher.hash();
|
||||
}
|
||||
|
||||
static bool HasRefInput(const Node* n) {
|
||||
|
Loading…
Reference in New Issue
Block a user