Add a strong keyed hash function based on highwayhash's siphash.

Add string_to_hash_bucket_strong to assign hash buckets using the strong keyed hash function.
Change: 123080459
This commit is contained in:
Yutaka Leon 2016-05-24 00:00:43 -08:00 committed by TensorFlower Gardener
parent 989166223c
commit 5040d0daaa
10 changed files with 194 additions and 3 deletions

View File

@ -211,6 +211,7 @@ cc_library(
"platform/mutex.h",
"platform/protobuf.h", # TODO(josh11b): make internal
"platform/regexp.h",
"platform/strong_hash.h",
"platform/thread_annotations.h",
"platform/types.h",
],

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strong_hash.h"
namespace tensorflow {
@ -57,11 +58,14 @@ class LegacyStringToHashBuckeOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp);
};
// StringToHashBucket is deprecated in favor of StringToHashBucketStable.
// StringToHashBucket is deprecated in favor of StringToHashBucketFast/Strong.
// CPU kernel for the deprecated "StringToHashBucket" op.
REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
LegacyStringToHashBuckeOp);
// CPU kernel for "StringToHashBucketFast": unkeyed hashing via Fingerprint64.
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
StringToHashBucketOp<Fingerprint64>);
// CPU kernel for "StringToHashBucketStrong": keyed hashing via
// StrongKeyedHash, with the key taken from the op's "key" attribute.
REGISTER_KERNEL_BUILDER(Name("StringToHashBucketStrong").Device(DEVICE_CPU),
StringToKeyedHashBucketOp<StrongKeyedHash>);
} // namespace tensorflow

View File

@ -61,6 +61,49 @@ class StringToHashBucketOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
};
// OpKernel that assigns every string of the "input" tensor to a bucket id by
// hashing it with the keyed hash function `hash` (the template argument) and
// reducing the result modulo the "num_buckets" attribute. The two-word hash
// key is read from the "key" attribute when the kernel is constructed.
template <uint64 hash(const uint64 (&)[2], const string&)>
class StringToKeyedHashBucketOp : public OpKernel {
 public:
  explicit StringToKeyedHashBucketOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));

    // The attribute is delivered as a list of int64 values; exactly two of
    // them are needed to fill the key.
    std::vector<int64> key_attr;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key_attr));
    OP_REQUIRES(ctx, key_attr.size() == 2,
                errors::InvalidArgument("Key must have 2 elements"));

    // Bitwise copy: the signed attribute values become the unsigned key words
    // without any value conversion.
    std::memcpy(key_, key_attr.data(), sizeof(key_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* input_tensor;
    OP_REQUIRES_OK(context, context->input("input", &input_tensor));
    const auto& strings = input_tensor->flat<string>();

    // The output has the same shape as the input, one bucket id per string.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("output", input_tensor->shape(),
                                            &output_tensor));
    auto buckets = output_tensor->flat<int64>();

    using Index = decltype(strings.size());
    for (Index pos = 0; pos < strings.size(); ++pos) {
      const uint64 hashed = hash(key_, strings(pos));
      // The bucket count lies in the positive range of int64, so the
      // remainder does too and the narrowing cast below cannot overflow.
      const uint64 bucket = hashed % num_buckets_;
      buckets(pos) = static_cast<int64>(bucket);
    }
  }

 private:
  int64 num_buckets_;  // Value of the "num_buckets" attribute.
  uint64 key_[2];      // Hash key, filled from the "key" attribute.

  TF_DISALLOW_COPY_AND_ASSIGN(StringToKeyedHashBucketOp);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_

View File

@ -26,17 +26,49 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
The hash function is deterministic on the content of the string within the
process and will never change. However, it is not suitable for cryptography.
This function may be used when CPU time is scarce and inputs are trusted or
unimportant. There is a risk of adversaries constructing inputs that all hash
to the same bucket. To prevent this problem, use a strong hash function with
`tf.string_to_hash_bucket_strong`.
input: The strings to assing a hash bucket.
input: The strings to assign a hash bucket.
num_buckets: The number of buckets.
output: A Tensor of the same shape as the input `string_tensor`.
)doc");
// Op definition for keyed ("strong") string hashing into buckets. The doc
// text below is the user-facing op documentation; argument names in it must
// match the Input/Attr/Output names declared here.
REGISTER_OP("StringToHashBucketStrong")
    .Input("input: string")
    .Output("output: int64")
    .Attr("num_buckets: int >= 1")
    .Attr("key: list(int)")
    .Doc(R"doc(
Converts each string in the input Tensor to its hash mod by a number of buckets.
The hash function is deterministic on the content of the string within the
process. The hash function is a keyed hash function, where attribute `key`
defines the key of the hash function. `key` is an array of 2 elements.
A strong hash is important when inputs may be malicious, e.g. URLs with
additional components. Adversaries could try to make their inputs hash to the
same bucket for a denial-of-service attack or to skew the results. A strong
hash prevents this by making it difficult, if not infeasible, to compute inputs
that hash to the same bucket. This comes at a cost of roughly 4x higher compute
time than tf.string_to_hash_bucket_fast.
input: The strings to assign a hash bucket.
num_buckets: The number of buckets.
key: The key for the keyed hash function passed as a list of two uint64
elements.
output: A Tensor of the same shape as `input`.
)doc");
REGISTER_OP("StringToHashBucket")
.Input("string_tensor: string")
.Output("output: int64")
.Attr("num_buckets: int >= 1")
.Deprecated(10, "Use tf.string_to_hash_bucket_fast()")
.Deprecated(10,
"Use `tf.string_to_hash_bucket_fast()` or "
"`tf.string_to_hash_bucket_strong()`")
.Doc(R"doc(
Converts each string in the input Tensor to its hash mod by a number of buckets.

View File

@ -50,6 +50,7 @@ cc_library(
"@farmhash_archive//:farmhash",
"@jpeg_archive//:jpeg",
"@png_archive//:png",
"@highwayhash//:sip_hash",
"@re2//:re2",
"//tensorflow/core:protos_cc",
],

View File

@ -0,0 +1,30 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
#include "highwayhash/sip_hash.h"
#include "highwayhash/state_helpers.h"
namespace tensorflow {
// Default-platform implementation of StrongKeyedHash: hashes `s` under the
// two-word `key` using highwayhash's SipHash state.
inline uint64 StrongKeyedHash(const uint64 (&key)[2], const string& s) {
  using SipStringHasher = highwayhash::StringHasher<highwayhash::SipHashState>;
  return SipStringHasher()(key, s);
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_

View File

@ -0,0 +1,45 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
#define TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// This is a strong keyed hash function interface for strings.
// The hash function is deterministic on the content of the string within the
// process. The key of the hash is an array of 2 uint64 elements.
// A strong hash makes it difficult, if not infeasible, to compute inputs that
// hash to the same bucket.
//
// Usage:
// uint64 key[2] = {123, 456};
// string input = "input string";
// uint64 hash_value = StrongKeyedHash(key, input);
//
// The parameters are unnamed in this declaration: the first is the 2-element
// key, the second is the string to hash. The definition is supplied by the
// platform-specific header included below.
uint64 StrongKeyedHash(const uint64 (&)[2], const string&);
} // namespace tensorflow
#if defined(PLATFORM_GOOGLE)
#include "tensorflow/core/platform/google/strong_hash.h"
#else
#include "tensorflow/core/platform/default/strong_hash.h"
#endif
#endif // TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
@ -66,6 +67,30 @@ class StringToHashBucketOpTest(tf.test.TestCase):
# Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
self.assertAllEqual([8, 0, 7], result)
def testStringToOneHashBucketStrongOneHashBucket(self):
  # With a single bucket, every string is expected to map to bucket 0.
  hash_key = [123, 345]
  with self.test_session():
    strings = tf.constant(['a', 'b', 'c'])
    bucket_op = tf.string_to_hash_bucket_strong(strings, 1, key=hash_key)
    bucket_ids = bucket_op.eval()
    self.assertAllEqual([0, 0, 0], bucket_ids)
def testStringToHashBucketsStrong(self):
  # Reference values for key = [98765, 132] and 10 buckets:
  #   StrongKeyedHash(key, 'a') -> 7157389809176466784 -> mod 10 -> 4
  #   StrongKeyedHash(key, 'b') -> 15805638358933211562 -> mod 10 -> 2
  #   StrongKeyedHash(key, 'c') -> 18100027895074076528 -> mod 10 -> 8
  hash_key = [98765, 132]
  expected_buckets = [4, 2, 8]
  with self.test_session():
    strings = tf.constant(['a', 'b', 'c'])
    bucket_op = tf.string_to_hash_bucket_strong(strings, 10, key=hash_key)
    self.assertAllEqual(expected_buckets, bucket_op.eval())
def testStringToHashBucketsStrongInvalidKey(self):
  # A key with only one element must be rejected with the kernel's
  # 'Key must have 2 elements' error.
  one_element_key = [98765]
  with self.test_session():
    strings = tf.constant(['a', 'b', 'c'])
    with self.assertRaisesOpError('Key must have 2 elements'):
      bucket_op = tf.string_to_hash_bucket_strong(
          strings, 10, key=one_element_key)
      bucket_op.eval()
if __name__ == '__main__':
tf.test.main()

View File

@ -19,6 +19,7 @@ String hashing ops take a string input tensor and map each element to an
integer.
@@string_to_hash_bucket_fast
@@string_to_hash_bucket_strong
@@string_to_hash_bucket
## Joining
@ -49,10 +50,12 @@ from tensorflow.python.ops.gen_string_ops import *
ops.NoGradient("StringToHashBucket")
ops.NoGradient("StringToHashBucketFast")
ops.NoGradient("StringToHashBucketStrong")
ops.NoGradient("ReduceJoin")
ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape)
@ops.RegisterShape("ReduceJoin")

View File

@ -52,6 +52,13 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
actual = "@farmhash//:farmhash",
)
# External repository "highwayhash", fetched from GitHub and pinned to a
# specific commit. Its //:sip_hash target supplies the SipHash implementation
# used by the default StrongKeyedHash.
native.git_repository(
name = "highwayhash",
remote = "https://github.com/google/highwayhash.git",
commit = "be5edafc2e1a455768e260ccd68ae7317b6690ee",
init_submodules = True,
)
native.new_http_archive(
name = "jpeg_archive",
url = "http://www.ijg.org/files/jpegsrc.v9a.tar.gz",