Avoid serializing large protos in OptimizerCSE::NodeHash.
PiperOrigin-RevId: 289688850 Change-Id: I12076a9b6168f9909f9d045a445fa255e7faac55
This commit is contained in:
parent
83df634c7e
commit
8aed2672fc
@ -38,6 +38,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/graph/optimizer_cse.h"
|
#include "tensorflow/core/graph/optimizer_cse.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -49,6 +50,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -89,38 +91,142 @@ static void FillInputs(const Node* n,
|
|||||||
|
|
||||||
static size_t kIllegalNodeHash = 0;
|
static size_t kIllegalNodeHash = 0;
|
||||||
|
|
||||||
size_t OptimizerCSE::NodeHash(const Node* n) {
|
class Hasher {
|
||||||
const DataTypeVector& out = n->output_types();
|
public:
|
||||||
string str_to_hash = strings::StrCat(n->type_string(), out.size());
|
uint64 hash() { return h_ == kIllegalNodeHash ? kIllegalNodeHash + 1 : h_; }
|
||||||
for (DataType dt : out) {
|
|
||||||
strings::StrAppend(&str_to_hash, dt);
|
void MixString(const string& s) { h_ = Hash64(s.data(), s.size(), h_); }
|
||||||
|
|
||||||
|
void MixInteger(size_t z) { h_ = Hash64Combine(h_, z); }
|
||||||
|
|
||||||
|
void MixProto(const protobuf::MessageLite& msg) {
|
||||||
|
msg.ByteSizeLong(); // Ensure sizes are cached accurately.
|
||||||
|
HashingOutputStream hasher;
|
||||||
|
{
|
||||||
|
// CodedOutputStream doesn't call BackUp until it's destroyed, so we need
|
||||||
|
// it to be destroyed before we call hasher.hash().
|
||||||
|
protobuf::io::CodedOutputStream stream(&hasher);
|
||||||
|
stream.EnableAliasing(true);
|
||||||
|
stream.SetSerializationDeterministic(true);
|
||||||
|
msg.SerializeWithCachedSizes(&stream);
|
||||||
|
}
|
||||||
|
h_ = Hash64Combine(h_, hasher.hash());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int N_in = n->num_inputs();
|
private:
|
||||||
strings::StrAppend(&str_to_hash, N_in);
|
// HashingOutputStream produces the same exact hash as if you serialized the
|
||||||
|
// proto and hashed it sequentially in kBufSize chunks, except it doesn't
|
||||||
|
// manifest the entire proto into memory at any point.
|
||||||
|
class HashingOutputStream : public protobuf::io::ZeroCopyOutputStream {
|
||||||
|
public:
|
||||||
|
// This kBufSize makes sizeof(HashingOutputStream) == 256. It's not chosen
|
||||||
|
// for any particular reason except it's a nice even number of cache lines.
|
||||||
|
static constexpr size_t kBufSize = 228;
|
||||||
|
static constexpr uint64 kDefaultSeed = 2570847921467975139ULL;
|
||||||
|
bool Next(void** data, int* size) override {
|
||||||
|
if (i_ == kBufSize) {
|
||||||
|
// Mix the chunk in.
|
||||||
|
Mix(buf_, kBufSize);
|
||||||
|
*data = buf_;
|
||||||
|
*size = kBufSize;
|
||||||
|
} else {
|
||||||
|
*data = buf_ + i_;
|
||||||
|
*size = kBufSize - i_;
|
||||||
|
}
|
||||||
|
// We always set i_ to be past the end, since we've given the rest of buf_
|
||||||
|
// out.
|
||||||
|
i_ = kBufSize;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BackUp(int count) override { i_ -= count; }
|
||||||
|
|
||||||
|
int64_t ByteCount() const override { return byte_count_; }
|
||||||
|
|
||||||
|
bool WriteAliasedRaw(const void* void_data, int size) override {
|
||||||
|
// We can't do math on void*.
|
||||||
|
const char* data = static_cast<const char*>(void_data);
|
||||||
|
const auto remaining = kBufSize - i_;
|
||||||
|
if (remaining > 0) {
|
||||||
|
if (size < remaining) {
|
||||||
|
memcpy(buf_ + i_, data, size);
|
||||||
|
i_ += size;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
memcpy(buf_ + i_, data, remaining);
|
||||||
|
i_ = kBufSize;
|
||||||
|
data += remaining;
|
||||||
|
size -= remaining;
|
||||||
|
}
|
||||||
|
if (i_ == kBufSize) {
|
||||||
|
Mix(buf_, kBufSize);
|
||||||
|
i_ = 0;
|
||||||
|
}
|
||||||
|
while (size >= kBufSize) {
|
||||||
|
Mix(data, kBufSize);
|
||||||
|
data += kBufSize;
|
||||||
|
size -= kBufSize;
|
||||||
|
}
|
||||||
|
memcpy(buf_, data, size);
|
||||||
|
i_ = size;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AllowsAliasing() const override { return true; }
|
||||||
|
|
||||||
|
uint64 hash() {
|
||||||
|
if (i_ != 0) {
|
||||||
|
Mix(buf_, i_);
|
||||||
|
i_ = 0;
|
||||||
|
}
|
||||||
|
return h_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Mix(const char* p, size_t n) {
|
||||||
|
byte_count_ += n;
|
||||||
|
h_ = Hash64(p, n, h_);
|
||||||
|
}
|
||||||
|
char buf_[kBufSize];
|
||||||
|
int i_ = 0;
|
||||||
|
int64_t byte_count_ = 0;
|
||||||
|
uint64 h_ = kDefaultSeed;
|
||||||
|
};
|
||||||
|
|
||||||
|
uint64 h_ = HashingOutputStream::kDefaultSeed;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t OptimizerCSE::NodeHash(const Node* n) {
|
||||||
|
Hasher hasher;
|
||||||
|
hasher.MixString(n->type_string());
|
||||||
|
hasher.MixInteger(n->output_types().size());
|
||||||
|
for (DataType dt : n->output_types()) {
|
||||||
|
hasher.MixInteger(dt);
|
||||||
|
}
|
||||||
|
|
||||||
|
hasher.MixInteger(n->num_inputs());
|
||||||
gtl::InlinedVector<const Node*, 4> control_edges;
|
gtl::InlinedVector<const Node*, 4> control_edges;
|
||||||
gtl::InlinedVector<std::pair<const Node*, int>, 4> in(N_in);
|
gtl::InlinedVector<std::pair<const Node*, int>, 4> in(n->num_inputs());
|
||||||
FillInputs(n, &control_edges, &in);
|
FillInputs(n, &control_edges, &in);
|
||||||
for (const auto& edge : in) {
|
for (const auto& edge : in) {
|
||||||
strings::StrAppend(&str_to_hash, edge.first->id(), edge.second);
|
hasher.MixInteger(edge.first->id());
|
||||||
|
hasher.MixInteger(edge.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t h = Hash64(str_to_hash);
|
|
||||||
|
|
||||||
#if !defined(__ANDROID__)
|
#if !defined(__ANDROID__)
|
||||||
// Hash the attrs. For example, this makes sure different constants
|
// Hash the attrs. For example, this makes sure different constants
|
||||||
// end up in different hash buckets.
|
// end up in different hash buckets.
|
||||||
string tmp;
|
size_t attr_hashes = 0;
|
||||||
for (const auto& attr : n->attrs()) {
|
for (const auto& attr : n->attrs()) {
|
||||||
tmp = attr.first;
|
Hasher h;
|
||||||
attr.second.AppendToString(&tmp);
|
h.MixString(attr.first);
|
||||||
// Add hashes of attrs, so the order of attrs doesn't matter.
|
h.MixProto(attr.second);
|
||||||
h += Hash32(tmp.data(), tmp.size(), 0x87341245);
|
attr_hashes = Hash64CombineUnordered(attr_hashes, h.hash());
|
||||||
}
|
}
|
||||||
|
hasher.MixInteger(attr_hashes);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
|
return hasher.hash();
|
||||||
return h;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool HasRefInput(const Node* n) {
|
static bool HasRefInput(const Node* n) {
|
||||||
|
Loading…
Reference in New Issue
Block a user