Add a strong keyed hash function based on highwayhash's siphash.
Add string_to_hash_bucket_strong to assign hash buckets using the strong keyed hash function. Change: 123080459
This commit is contained in:
parent
989166223c
commit
5040d0daaa
@ -211,6 +211,7 @@ cc_library(
|
||||
"platform/mutex.h",
|
||||
"platform/protobuf.h", # TODO(josh11b): make internal
|
||||
"platform/regexp.h",
|
||||
"platform/strong_hash.h",
|
||||
"platform/thread_annotations.h",
|
||||
"platform/types.h",
|
||||
],
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/strong_hash.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -57,11 +58,14 @@ class LegacyStringToHashBuckeOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp);
|
||||
};
|
||||
|
||||
// StringToHashBucket is deprecated in favor of StringToHashBucketStable.
|
||||
// StringToHashBucket is deprecated in favor of StringToHashBucketFast/Strong.
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
|
||||
LegacyStringToHashBuckeOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
|
||||
StringToHashBucketOp<Fingerprint64>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketStrong").Device(DEVICE_CPU),
|
||||
StringToKeyedHashBucketOp<StrongKeyedHash>);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -61,6 +61,49 @@ class StringToHashBucketOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
|
||||
};
|
||||
|
||||
template <uint64 hash(const uint64 (&)[2], const string&)>
|
||||
class StringToKeyedHashBucketOp : public OpKernel {
|
||||
public:
|
||||
explicit StringToKeyedHashBucketOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
|
||||
|
||||
std::vector<int64> key;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key));
|
||||
OP_REQUIRES(ctx, key.size() == 2,
|
||||
errors::InvalidArgument("Key must have 2 elements"));
|
||||
std::memcpy(key_, key.data(), sizeof(key_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
|
||||
const auto& input_flat = input_tensor->flat<string>();
|
||||
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("output", input_tensor->shape(),
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<int64>();
|
||||
|
||||
typedef decltype(input_flat.size()) Index;
|
||||
for (Index i = 0; i < input_flat.size(); ++i) {
|
||||
const uint64 input_hash = hash(key_, input_flat(i));
|
||||
const uint64 bucket_id = input_hash % num_buckets_;
|
||||
// The number of buckets is always in the positive range of int64 so is
|
||||
// the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
|
||||
// safe.
|
||||
output_flat(i) = static_cast<int64>(bucket_id);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int64 num_buckets_;
|
||||
uint64 key_[2];
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StringToKeyedHashBucketOp);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_
|
||||
|
@ -26,17 +26,49 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||
|
||||
The hash function is deterministic on the content of the string within the
|
||||
process and will never change. However, it is not suitable for cryptography.
|
||||
This function may be used when CPU time is scarce and inputs are trusted or
|
||||
unimportant. There is a risk of adversaries constructing inputs that all hash
|
||||
to the same bucket. To prevent this problem, use a strong hash function with
|
||||
`tf.string_to_hash_bucket_strong`.
|
||||
|
||||
input: The strings to assing a hash bucket.
|
||||
input: The strings to assign a hash bucket.
|
||||
num_buckets: The number of buckets.
|
||||
output: A Tensor of the same shape as the input `string_tensor`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("StringToHashBucketStrong")
|
||||
.Input("input: string")
|
||||
.Output("output: int64")
|
||||
.Attr("num_buckets: int >= 1")
|
||||
.Attr("key: list(int)")
|
||||
.Doc(R"doc(
|
||||
Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||
|
||||
The hash function is deterministic on the content of the string within the
|
||||
process. The hash function is a keyed hash function, where attribute `key`
|
||||
defines the key of the hash function. `key` is an array of 2 elements.
|
||||
|
||||
A strong hash is important when inputs may be malicious, e.g. URLs with
|
||||
additional components. Adversaries could try to make their inputs hash to the
|
||||
same bucket for a denial-of-service attack or to skew the results. A strong
|
||||
hash prevents this by making it dificult, if not infeasible, to compute inputs
|
||||
that hash to the same bucket. This comes at a cost of roughly 4x higher compute
|
||||
time than tf.string_to_hash_bucket_fast.
|
||||
|
||||
input: The strings to assign a hash bucket.
|
||||
num_buckets: The number of buckets.
|
||||
key: The key for the keyed hash function passed as a list of two uint64
|
||||
elements.
|
||||
output: A Tensor of the same shape as the input `string_tensor`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("StringToHashBucket")
|
||||
.Input("string_tensor: string")
|
||||
.Output("output: int64")
|
||||
.Attr("num_buckets: int >= 1")
|
||||
.Deprecated(10, "Use tf.string_to_hash_bucket_fast()")
|
||||
.Deprecated(10,
|
||||
"Use `tf.string_to_hash_bucket_fast()` or "
|
||||
"`tf.string_to_hash_bucket_strong()`")
|
||||
.Doc(R"doc(
|
||||
Converts each string in the input Tensor to its hash mod by a number of buckets.
|
||||
|
||||
|
@ -50,6 +50,7 @@ cc_library(
|
||||
"@farmhash_archive//:farmhash",
|
||||
"@jpeg_archive//:jpeg",
|
||||
"@png_archive//:png",
|
||||
"@highwayhash//:sip_hash",
|
||||
"@re2//:re2",
|
||||
"//tensorflow/core:protos_cc",
|
||||
],
|
||||
|
30
tensorflow/core/platform/default/strong_hash.h
Normal file
30
tensorflow/core/platform/default/strong_hash.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
|
||||
|
||||
#include "highwayhash/sip_hash.h"
|
||||
#include "highwayhash/state_helpers.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
inline uint64 StrongKeyedHash(const uint64 (&key)[2], const string& s) {
|
||||
return highwayhash::StringHasher<highwayhash::SipHashState>()(key, s);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
|
45
tensorflow/core/platform/strong_hash.h
Normal file
45
tensorflow/core/platform/strong_hash.h
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
|
||||
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This is a strong keyed hash function interface for strings.
|
||||
// The hash function is deterministic on the content of the string within the
|
||||
// process. The key of the hash is an array of 2 uint64 elements.
|
||||
// A strong hash make it dificult, if not infeasible, to compute inputs that
|
||||
// hash to the same bucket.
|
||||
//
|
||||
// Usage:
|
||||
// uint64 key[2] = {123, 456};
|
||||
// string input = "input string";
|
||||
// uint64 hash_value = StrongKeyedHash(key, input);
|
||||
//
|
||||
uint64 StrongKeyedHash(const uint64 (&)[2], const string&);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#include "tensorflow/core/platform/google/strong_hash.h"
|
||||
#else
|
||||
#include "tensorflow/core/platform/default/strong_hash.h"
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
@ -66,6 +67,30 @@ class StringToHashBucketOpTest(tf.test.TestCase):
|
||||
# Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
|
||||
self.assertAllEqual([8, 0, 7], result)
|
||||
|
||||
def testStringToOneHashBucketStrongOneHashBucket(self):
|
||||
with self.test_session():
|
||||
input_string = tf.constant(['a', 'b', 'c'])
|
||||
output = tf.string_to_hash_bucket_strong(input_string, 1, key=[123, 345])
|
||||
self.assertAllEqual([0, 0, 0], output.eval())
|
||||
|
||||
def testStringToHashBucketsStrong(self):
|
||||
with self.test_session():
|
||||
input_string = tf.constant(['a', 'b', 'c'])
|
||||
output = tf.string_to_hash_bucket_strong(input_string,
|
||||
10,
|
||||
key=[98765, 132])
|
||||
# key = [98765, 132]
|
||||
# StrongKeyedHash(key, 'a') -> 7157389809176466784 -> mod 10 -> 4
|
||||
# StrongKeyedHash(key, 'b') -> 15805638358933211562 -> mod 10 -> 2
|
||||
# StrongKeyedHash(key, 'c') -> 18100027895074076528 -> mod 10 -> 8
|
||||
self.assertAllEqual([4, 2, 8], output.eval())
|
||||
|
||||
def testStringToHashBucketsStrongInvalidKey(self):
|
||||
with self.test_session():
|
||||
input_string = tf.constant(['a', 'b', 'c'])
|
||||
with self.assertRaisesOpError('Key must have 2 elements'):
|
||||
tf.string_to_hash_bucket_strong(input_string, 10, key=[98765]).eval()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
@ -19,6 +19,7 @@ String hashing ops take a string input tensor and map each element to an
|
||||
integer.
|
||||
|
||||
@@string_to_hash_bucket_fast
|
||||
@@string_to_hash_bucket_strong
|
||||
@@string_to_hash_bucket
|
||||
|
||||
## Joining
|
||||
@ -49,10 +50,12 @@ from tensorflow.python.ops.gen_string_ops import *
|
||||
|
||||
ops.NoGradient("StringToHashBucket")
|
||||
ops.NoGradient("StringToHashBucketFast")
|
||||
ops.NoGradient("StringToHashBucketStrong")
|
||||
ops.NoGradient("ReduceJoin")
|
||||
|
||||
ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape)
|
||||
|
||||
|
||||
@ops.RegisterShape("ReduceJoin")
|
||||
|
@ -52,6 +52,13 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
actual = "@farmhash//:farmhash",
|
||||
)
|
||||
|
||||
native.git_repository(
|
||||
name = "highwayhash",
|
||||
remote = "https://github.com/google/highwayhash.git",
|
||||
commit = "be5edafc2e1a455768e260ccd68ae7317b6690ee",
|
||||
init_submodules = True,
|
||||
)
|
||||
|
||||
native.new_http_archive(
|
||||
name = "jpeg_archive",
|
||||
url = "http://www.ijg.org/files/jpegsrc.v9a.tar.gz",
|
||||
|
Loading…
Reference in New Issue
Block a user