Add a strong keyed hash function based on highwayhash's siphash.
Add string_to_hash_bucket_strong to assign hash buckets using the strong keyed hash function. Change: 123080459
This commit is contained in:
parent
989166223c
commit
5040d0daaa
@ -211,6 +211,7 @@ cc_library(
|
|||||||
"platform/mutex.h",
|
"platform/mutex.h",
|
||||||
"platform/protobuf.h", # TODO(josh11b): make internal
|
"platform/protobuf.h", # TODO(josh11b): make internal
|
||||||
"platform/regexp.h",
|
"platform/regexp.h",
|
||||||
|
"platform/strong_hash.h",
|
||||||
"platform/thread_annotations.h",
|
"platform/thread_annotations.h",
|
||||||
"platform/types.h",
|
"platform/types.h",
|
||||||
],
|
],
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/platform/fingerprint.h"
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
|
#include "tensorflow/core/platform/strong_hash.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -57,11 +58,14 @@ class LegacyStringToHashBuckeOp : public OpKernel {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
// StringToHashBucket is deprecated in favor of StringToHashBucketStable.
|
// StringToHashBucket is deprecated in favor of StringToHashBucketFast/Strong.
|
||||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
|
||||||
LegacyStringToHashBuckeOp);
|
LegacyStringToHashBuckeOp);
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
|
||||||
StringToHashBucketOp<Fingerprint64>);
|
StringToHashBucketOp<Fingerprint64>);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketStrong").Device(DEVICE_CPU),
|
||||||
|
StringToKeyedHashBucketOp<StrongKeyedHash>);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -61,6 +61,49 @@ class StringToHashBucketOp : public OpKernel {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <uint64 hash(const uint64 (&)[2], const string&)>
|
||||||
|
class StringToKeyedHashBucketOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit StringToKeyedHashBucketOp(OpKernelConstruction* ctx)
|
||||||
|
: OpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
|
||||||
|
|
||||||
|
std::vector<int64> key;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key));
|
||||||
|
OP_REQUIRES(ctx, key.size() == 2,
|
||||||
|
errors::InvalidArgument("Key must have 2 elements"));
|
||||||
|
std::memcpy(key_, key.data(), sizeof(key_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor* input_tensor;
|
||||||
|
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
|
||||||
|
const auto& input_flat = input_tensor->flat<string>();
|
||||||
|
|
||||||
|
Tensor* output_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output("output", input_tensor->shape(),
|
||||||
|
&output_tensor));
|
||||||
|
auto output_flat = output_tensor->flat<int64>();
|
||||||
|
|
||||||
|
typedef decltype(input_flat.size()) Index;
|
||||||
|
for (Index i = 0; i < input_flat.size(); ++i) {
|
||||||
|
const uint64 input_hash = hash(key_, input_flat(i));
|
||||||
|
const uint64 bucket_id = input_hash % num_buckets_;
|
||||||
|
// The number of buckets is always in the positive range of int64 so is
|
||||||
|
// the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
|
||||||
|
// safe.
|
||||||
|
output_flat(i) = static_cast<int64>(bucket_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 num_buckets_;
|
||||||
|
uint64 key_[2];
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(StringToKeyedHashBucketOp);
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_
|
||||||
|
@ -26,17 +26,49 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
|
|||||||
|
|
||||||
The hash function is deterministic on the content of the string within the
|
The hash function is deterministic on the content of the string within the
|
||||||
process and will never change. However, it is not suitable for cryptography.
|
process and will never change. However, it is not suitable for cryptography.
|
||||||
|
This function may be used when CPU time is scarce and inputs are trusted or
|
||||||
|
unimportant. There is a risk of adversaries constructing inputs that all hash
|
||||||
|
to the same bucket. To prevent this problem, use a strong hash function with
|
||||||
|
`tf.string_to_hash_bucket_strong`.
|
||||||
|
|
||||||
input: The strings to assing a hash bucket.
|
input: The strings to assign a hash bucket.
|
||||||
num_buckets: The number of buckets.
|
num_buckets: The number of buckets.
|
||||||
output: A Tensor of the same shape as the input `string_tensor`.
|
output: A Tensor of the same shape as the input `string_tensor`.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("StringToHashBucketStrong")
|
||||||
|
.Input("input: string")
|
||||||
|
.Output("output: int64")
|
||||||
|
.Attr("num_buckets: int >= 1")
|
||||||
|
.Attr("key: list(int)")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||||
|
|
||||||
|
The hash function is deterministic on the content of the string within the
|
||||||
|
process. The hash function is a keyed hash function, where attribute `key`
|
||||||
|
defines the key of the hash function. `key` is an array of 2 elements.
|
||||||
|
|
||||||
|
A strong hash is important when inputs may be malicious, e.g. URLs with
|
||||||
|
additional components. Adversaries could try to make their inputs hash to the
|
||||||
|
same bucket for a denial-of-service attack or to skew the results. A strong
|
||||||
|
hash prevents this by making it dificult, if not infeasible, to compute inputs
|
||||||
|
that hash to the same bucket. This comes at a cost of roughly 4x higher compute
|
||||||
|
time than tf.string_to_hash_bucket_fast.
|
||||||
|
|
||||||
|
input: The strings to assign a hash bucket.
|
||||||
|
num_buckets: The number of buckets.
|
||||||
|
key: The key for the keyed hash function passed as a list of two uint64
|
||||||
|
elements.
|
||||||
|
output: A Tensor of the same shape as the input `string_tensor`.
|
||||||
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("StringToHashBucket")
|
REGISTER_OP("StringToHashBucket")
|
||||||
.Input("string_tensor: string")
|
.Input("string_tensor: string")
|
||||||
.Output("output: int64")
|
.Output("output: int64")
|
||||||
.Attr("num_buckets: int >= 1")
|
.Attr("num_buckets: int >= 1")
|
||||||
.Deprecated(10, "Use tf.string_to_hash_bucket_fast()")
|
.Deprecated(10,
|
||||||
|
"Use `tf.string_to_hash_bucket_fast()` or "
|
||||||
|
"`tf.string_to_hash_bucket_strong()`")
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Converts each string in the input Tensor to its hash mod by a number of buckets.
|
Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ cc_library(
|
|||||||
"@farmhash_archive//:farmhash",
|
"@farmhash_archive//:farmhash",
|
||||||
"@jpeg_archive//:jpeg",
|
"@jpeg_archive//:jpeg",
|
||||||
"@png_archive//:png",
|
"@png_archive//:png",
|
||||||
|
"@highwayhash//:sip_hash",
|
||||||
"@re2//:re2",
|
"@re2//:re2",
|
||||||
"//tensorflow/core:protos_cc",
|
"//tensorflow/core:protos_cc",
|
||||||
],
|
],
|
||||||
|
30
tensorflow/core/platform/default/strong_hash.h
Normal file
30
tensorflow/core/platform/default/strong_hash.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
|
||||||
|
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
|
||||||
|
|
||||||
|
#include "highwayhash/sip_hash.h"
|
||||||
|
#include "highwayhash/state_helpers.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
inline uint64 StrongKeyedHash(const uint64 (&key)[2], const string& s) {
|
||||||
|
return highwayhash::StringHasher<highwayhash::SipHashState>()(key, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
|
45
tensorflow/core/platform/strong_hash.h
Normal file
45
tensorflow/core/platform/strong_hash.h
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
|
||||||
|
#define TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/platform.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// This is a strong keyed hash function interface for strings.
|
||||||
|
// The hash function is deterministic on the content of the string within the
|
||||||
|
// process. The key of the hash is an array of 2 uint64 elements.
|
||||||
|
// A strong hash make it dificult, if not infeasible, to compute inputs that
|
||||||
|
// hash to the same bucket.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
// uint64 key[2] = {123, 456};
|
||||||
|
// string input = "input string";
|
||||||
|
// uint64 hash_value = StrongKeyedHash(key, input);
|
||||||
|
//
|
||||||
|
uint64 StrongKeyedHash(const uint64 (&)[2], const string&);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
#include "tensorflow/core/platform/google/strong_hash.h"
|
||||||
|
#else
|
||||||
|
#include "tensorflow/core/platform/default/strong_hash.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
@ -66,6 +67,30 @@ class StringToHashBucketOpTest(tf.test.TestCase):
|
|||||||
# Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
|
# Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
|
||||||
self.assertAllEqual([8, 0, 7], result)
|
self.assertAllEqual([8, 0, 7], result)
|
||||||
|
|
||||||
|
def testStringToOneHashBucketStrongOneHashBucket(self):
|
||||||
|
with self.test_session():
|
||||||
|
input_string = tf.constant(['a', 'b', 'c'])
|
||||||
|
output = tf.string_to_hash_bucket_strong(input_string, 1, key=[123, 345])
|
||||||
|
self.assertAllEqual([0, 0, 0], output.eval())
|
||||||
|
|
||||||
|
def testStringToHashBucketsStrong(self):
|
||||||
|
with self.test_session():
|
||||||
|
input_string = tf.constant(['a', 'b', 'c'])
|
||||||
|
output = tf.string_to_hash_bucket_strong(input_string,
|
||||||
|
10,
|
||||||
|
key=[98765, 132])
|
||||||
|
# key = [98765, 132]
|
||||||
|
# StrongKeyedHash(key, 'a') -> 7157389809176466784 -> mod 10 -> 4
|
||||||
|
# StrongKeyedHash(key, 'b') -> 15805638358933211562 -> mod 10 -> 2
|
||||||
|
# StrongKeyedHash(key, 'c') -> 18100027895074076528 -> mod 10 -> 8
|
||||||
|
self.assertAllEqual([4, 2, 8], output.eval())
|
||||||
|
|
||||||
|
def testStringToHashBucketsStrongInvalidKey(self):
|
||||||
|
with self.test_session():
|
||||||
|
input_string = tf.constant(['a', 'b', 'c'])
|
||||||
|
with self.assertRaisesOpError('Key must have 2 elements'):
|
||||||
|
tf.string_to_hash_bucket_strong(input_string, 10, key=[98765]).eval()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -19,6 +19,7 @@ String hashing ops take a string input tensor and map each element to an
|
|||||||
integer.
|
integer.
|
||||||
|
|
||||||
@@string_to_hash_bucket_fast
|
@@string_to_hash_bucket_fast
|
||||||
|
@@string_to_hash_bucket_strong
|
||||||
@@string_to_hash_bucket
|
@@string_to_hash_bucket
|
||||||
|
|
||||||
## Joining
|
## Joining
|
||||||
@ -49,10 +50,12 @@ from tensorflow.python.ops.gen_string_ops import *
|
|||||||
|
|
||||||
ops.NoGradient("StringToHashBucket")
|
ops.NoGradient("StringToHashBucket")
|
||||||
ops.NoGradient("StringToHashBucketFast")
|
ops.NoGradient("StringToHashBucketFast")
|
||||||
|
ops.NoGradient("StringToHashBucketStrong")
|
||||||
ops.NoGradient("ReduceJoin")
|
ops.NoGradient("ReduceJoin")
|
||||||
|
|
||||||
ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
|
ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
|
||||||
ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
|
ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
|
||||||
|
ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("ReduceJoin")
|
@ops.RegisterShape("ReduceJoin")
|
||||||
|
@ -52,6 +52,13 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
|||||||
actual = "@farmhash//:farmhash",
|
actual = "@farmhash//:farmhash",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
native.git_repository(
|
||||||
|
name = "highwayhash",
|
||||||
|
remote = "https://github.com/google/highwayhash.git",
|
||||||
|
commit = "be5edafc2e1a455768e260ccd68ae7317b6690ee",
|
||||||
|
init_submodules = True,
|
||||||
|
)
|
||||||
|
|
||||||
native.new_http_archive(
|
native.new_http_archive(
|
||||||
name = "jpeg_archive",
|
name = "jpeg_archive",
|
||||||
url = "http://www.ijg.org/files/jpegsrc.v9a.tar.gz",
|
url = "http://www.ijg.org/files/jpegsrc.v9a.tar.gz",
|
||||||
|
Loading…
Reference in New Issue
Block a user