From 5040d0daaa69e95f8ef3e7c6807f39443941cf8b Mon Sep 17 00:00:00 2001
From: Yutaka Leon <yutaka.leon@gmail.com>
Date: Tue, 24 May 2016 00:00:43 -0800
Subject: [PATCH] Add a strong keyed hash function based on highwayhash's
 siphash.

Add string_to_hash_bucket_strong to assign hash buckets using the strong keyed hash function.
Change: 123080459
---
 tensorflow/core/BUILD                         |  1 +
 .../core/kernels/string_to_hash_bucket_op.cc  |  6 ++-
 .../core/kernels/string_to_hash_bucket_op.h   | 43 ++++++++++++++++++
 tensorflow/core/ops/string_ops.cc             | 36 ++++++++++++++-
 .../core/platform/default/build_config/BUILD  |  1 +
 .../core/platform/default/strong_hash.h       | 30 +++++++++++++
 tensorflow/core/platform/strong_hash.h        | 45 +++++++++++++++++++
 .../string_to_hash_bucket_op_test.py          | 25 +++++++++++
 tensorflow/python/ops/string_ops.py           |  3 ++
 tensorflow/workspace.bzl                      |  7 +++
 10 files changed, 194 insertions(+), 3 deletions(-)
 create mode 100644 tensorflow/core/platform/default/strong_hash.h
 create mode 100644 tensorflow/core/platform/strong_hash.h

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1be21fa9545..aa226886a80 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -211,6 +211,7 @@ cc_library(
         "platform/mutex.h",
         "platform/protobuf.h",  # TODO(josh11b): make internal
         "platform/regexp.h",
+        "platform/strong_hash.h",
         "platform/thread_annotations.h",
         "platform/types.h",
     ],
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
index 3a2429d4cd0..e00cd25f455 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/strong_hash.h"
 
 namespace tensorflow {
 
@@ -57,11 +58,14 @@ class LegacyStringToHashBuckeOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(LegacyStringToHashBuckeOp);
 };
 
-// StringToHashBucket is deprecated in favor of StringToHashBucketStable.
+// StringToHashBucket is deprecated in favor of StringToHashBucketFast/Strong.
 REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
                         LegacyStringToHashBuckeOp);
 
 REGISTER_KERNEL_BUILDER(Name("StringToHashBucketFast").Device(DEVICE_CPU),
                         StringToHashBucketOp<Fingerprint64>);
 
+REGISTER_KERNEL_BUILDER(Name("StringToHashBucketStrong").Device(DEVICE_CPU),
+                        StringToKeyedHashBucketOp<StrongKeyedHash>);
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.h b/tensorflow/core/kernels/string_to_hash_bucket_op.h
index 9c6c0a89e42..0c3acbebbfb 100644
--- a/tensorflow/core/kernels/string_to_hash_bucket_op.h
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.h
@@ -61,6 +61,49 @@ class StringToHashBucketOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
 };
 
+template <uint64 hash(const uint64 (&)[2], const string&)>
+class StringToKeyedHashBucketOp : public OpKernel {
+ public:
+  explicit StringToKeyedHashBucketOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
+
+    std::vector<int64> key;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key));
+    OP_REQUIRES(ctx, key.size() == 2,
+                errors::InvalidArgument("Key must have 2 elements"));
+    std::memcpy(key_, key.data(), sizeof(key_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output("output", input_tensor->shape(),
+                                            &output_tensor));
+    auto output_flat = output_tensor->flat<int64>();
+
+    typedef decltype(input_flat.size()) Index;
+    for (Index i = 0; i < input_flat.size(); ++i) {
+      const uint64 input_hash = hash(key_, input_flat(i));
+      const uint64 bucket_id = input_hash % num_buckets_;
+      // The number of buckets is always in the positive range of int64 so is
+      // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is
+      // safe.
+      output_flat(i) = static_cast<int64>(bucket_id);
+    }
+  }
+
+ private:
+  int64 num_buckets_;
+  uint64 key_[2];
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StringToKeyedHashBucketOp);
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 1a274f1e68a..526fc35eb25 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -26,17 +26,49 @@ Converts each string in the input Tensor to its hash mod by a number of buckets.
 
 The hash function is deterministic on the content of the string within the
 process and will never change. However, it is not suitable for cryptography.
+This function may be used when CPU time is scarce and inputs are trusted or
+unimportant. There is a risk of adversaries constructing inputs that all hash
+to the same bucket. To prevent this problem, use a strong hash function with
+`tf.string_to_hash_bucket_strong`.
 
-input: The strings to assing a hash bucket.
+input: The strings to assign a hash bucket.
 num_buckets: The number of buckets.
 output: A Tensor of the same shape as the input `string_tensor`.
 )doc");
 
+REGISTER_OP("StringToHashBucketStrong")
+    .Input("input: string")
+    .Output("output: int64")
+    .Attr("num_buckets: int >= 1")
+    .Attr("key: list(int)")
+    .Doc(R"doc(
+Converts each string in the input Tensor to its hash mod by a number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process. The hash function is a keyed hash function, where attribute `key`
+defines the key of the hash function. `key` is an array of 2 elements.
+
+A strong hash is important when inputs may be malicious, e.g. URLs with
+additional components. Adversaries could try to make their inputs hash to the
+same bucket for a denial-of-service attack or to skew the results. A strong
+hash prevents this by making it difficult, if not infeasible, to compute inputs
+that hash to the same bucket. This comes at a cost of roughly 4x higher compute
+time than tf.string_to_hash_bucket_fast.
+
+input: The strings to assign a hash bucket.
+num_buckets: The number of buckets.
+key: The key for the keyed hash function passed as a list of two uint64
+  elements.
+output: A Tensor of the same shape as `input`.
+)doc");
+
 REGISTER_OP("StringToHashBucket")
     .Input("string_tensor: string")
     .Output("output: int64")
     .Attr("num_buckets: int >= 1")
-    .Deprecated(10, "Use tf.string_to_hash_bucket_fast()")
+    .Deprecated(10,
+                "Use `tf.string_to_hash_bucket_fast()` or "
+                "`tf.string_to_hash_bucket_strong()`")
     .Doc(R"doc(
 Converts each string in the input Tensor to its hash mod by a number of buckets.
 
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 109fd18e6b5..66e2c75934f 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -50,6 +50,7 @@ cc_library(
         "@farmhash_archive//:farmhash",
         "@jpeg_archive//:jpeg",
         "@png_archive//:png",
+        "@highwayhash//:sip_hash",
         "@re2//:re2",
         "//tensorflow/core:protos_cc",
     ],
diff --git a/tensorflow/core/platform/default/strong_hash.h b/tensorflow/core/platform/default/strong_hash.h
new file mode 100644
index 00000000000..53d1dae98dd
--- /dev/null
+++ b/tensorflow/core/platform/default/strong_hash.h
@@ -0,0 +1,30 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
+
+#include "highwayhash/sip_hash.h"
+#include "highwayhash/state_helpers.h"
+
+namespace tensorflow {
+
+inline uint64 StrongKeyedHash(const uint64 (&key)[2], const string& s) {
+  return highwayhash::StringHasher<highwayhash::SipHashState>()(key, s);
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_STRONG_HASH_H_
diff --git a/tensorflow/core/platform/strong_hash.h b/tensorflow/core/platform/strong_hash.h
new file mode 100644
index 00000000000..7bd3eed6106
--- /dev/null
+++ b/tensorflow/core/platform/strong_hash.h
@@ -0,0 +1,45 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
+#define TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
+
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// This is a strong keyed hash function interface for strings.
+// The hash function is deterministic on the content of the string within the
+// process. The key of the hash is an array of 2 uint64 elements.
+// A strong hash makes it difficult, if not infeasible, to compute inputs that
+// hash to the same bucket.
+//
+// Usage:
+//   uint64 key[2] = {123, 456};
+//   string input = "input string";
+//   uint64 hash_value = StrongKeyedHash(key, input);
+//
+uint64 StrongKeyedHash(const uint64 (&)[2], const string&);
+
+}  // namespace tensorflow
+
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/strong_hash.h"
+#else
+#include "tensorflow/core/platform/default/strong_hash.h"
+#endif
+
+#endif  // TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
index 379edbfbb04..8a018573d1f 100644
--- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
 import tensorflow as tf
 
 
@@ -66,6 +67,30 @@ class StringToHashBucketOpTest(tf.test.TestCase):
       # Hash64('c') -> 14899841994519054197 -> mod 10 -> 7
       self.assertAllEqual([8, 0, 7], result)
 
+  def testStringToOneHashBucketStrongOneHashBucket(self):
+    with self.test_session():
+      input_string = tf.constant(['a', 'b', 'c'])
+      output = tf.string_to_hash_bucket_strong(input_string, 1, key=[123, 345])
+      self.assertAllEqual([0, 0, 0], output.eval())
+
+  def testStringToHashBucketsStrong(self):
+    with self.test_session():
+      input_string = tf.constant(['a', 'b', 'c'])
+      output = tf.string_to_hash_bucket_strong(input_string,
+                                               10,
+                                               key=[98765, 132])
+      # key = [98765, 132]
+      # StrongKeyedHash(key, 'a') -> 7157389809176466784 -> mod 10 -> 4
+      # StrongKeyedHash(key, 'b') -> 15805638358933211562 -> mod 10 -> 2
+      # StrongKeyedHash(key, 'c') -> 18100027895074076528 -> mod 10 -> 8
+      self.assertAllEqual([4, 2, 8], output.eval())
+
+  def testStringToHashBucketsStrongInvalidKey(self):
+    with self.test_session():
+      input_string = tf.constant(['a', 'b', 'c'])
+      with self.assertRaisesOpError('Key must have 2 elements'):
+        tf.string_to_hash_bucket_strong(input_string, 10, key=[98765]).eval()
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 1cd38af3b8f..e057ba64079 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -19,6 +19,7 @@ String hashing ops take a string input tensor and map each element to an
 integer.
 
 @@string_to_hash_bucket_fast
+@@string_to_hash_bucket_strong
 @@string_to_hash_bucket
 
 ## Joining
@@ -49,10 +50,12 @@ from tensorflow.python.ops.gen_string_ops import *
 
 ops.NoGradient("StringToHashBucket")
 ops.NoGradient("StringToHashBucketFast")
+ops.NoGradient("StringToHashBucketStrong")
 ops.NoGradient("ReduceJoin")
 
 ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
 ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
+ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape)
 
 
 @ops.RegisterShape("ReduceJoin")
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index d99cb5b5e3d..7c68fb763fa 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -52,6 +52,13 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
     actual = "@farmhash//:farmhash",
   )
 
+  native.git_repository(
+    name = "highwayhash",
+    remote = "https://github.com/google/highwayhash.git",
+    commit = "be5edafc2e1a455768e260ccd68ae7317b6690ee",
+    init_submodules = True,
+  )
+
   native.new_http_archive(
     name = "jpeg_archive",
     url = "http://www.ijg.org/files/jpegsrc.v9a.tar.gz",