diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index ede199a9169..67ac91cb9be 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -57,7 +57,8 @@ else: from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2 TextVectorizationV1 = TextVectorization -from tensorflow.python.keras.layers.preprocessing.categorical_crossing import CategoryCrossing +from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing +from tensorflow.python.keras.layers.preprocessing.hashing import Hashing # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD index b580382f9d8..b7fdc17b81d 100644 --- a/tensorflow/python/keras/layers/preprocessing/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/BUILD @@ -25,7 +25,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":categorical_crossing", + ":category_crossing", ":discretization", ":hashing", ":image_preprocessing", @@ -52,9 +52,9 @@ py_library( ) py_library( - name = "categorical_crossing", + name = "category_crossing", srcs = [ - "categorical_crossing.py", + "category_crossing.py", ], srcs_version = "PY2AND3", deps = [ @@ -291,16 +291,16 @@ py_library( ) cuda_py_test( - name = "categorical_crossing_test", + name = "category_crossing_test", size = "medium", - srcs = ["categorical_crossing_test.py"], + srcs = ["category_crossing_test.py"], python_version = "PY3", shard_count = 4, tags = [ "no_windows", # b/149031156 ], deps = [ - ":categorical_crossing", + ":category_crossing", "//tensorflow/python:client_testlib", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -343,9 +343,9 @@ distribute_py_test( ) distribute_py_test( - name = "categorical_crossing_distribution_test", - srcs = ["categorical_crossing_distribution_test.py"], - main = "categorical_crossing_distribution_test.py", + name = "category_crossing_distribution_test", + srcs = ["category_crossing_distribution_test.py"], + main = "category_crossing_distribution_test.py", python_version = "PY3", tags = [ "multi_and_single_gpu", @@ -354,7 +354,7 @@ distribute_py_test( "no_oss", # b/155502591 ], deps = [ - ":categorical_crossing", + ":category_crossing", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/keras", diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD index 653a81581b3..6d29126bc7e 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -21,12 +21,22 @@ tf_py_test( ) tf_py_test( - name = "categorical_crossing_benchmark", - srcs = ["categorical_crossing_benchmark.py"], + name = "category_crossing_benchmark", + srcs = ["category_crossing_benchmark.py"], python_version = "PY3", deps = [ "//tensorflow:tensorflow_py", - "//tensorflow/python/keras/layers/preprocessing:categorical_crossing", + "//tensorflow/python/keras/layers/preprocessing:category_crossing", + ], +) + +tf_py_test( + name = "hashing_benchmark", + srcs = ["hashing_benchmark.py"], + python_version = "PY3", + deps = [ + 
"//tensorflow:tensorflow_py", + "//tensorflow/python/keras/layers/preprocessing:hashing", ], ) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_crossing_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/category_crossing_benchmark.py similarity index 97% rename from tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_crossing_benchmark.py rename to tensorflow/python/keras/layers/preprocessing/benchmarks/category_crossing_benchmark.py index 80a7903f0b9..efc0ca3766f 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_crossing_benchmark.py +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/category_crossing_benchmark.py @@ -28,7 +28,7 @@ from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.layers.preprocessing import categorical_crossing +from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -74,7 +74,7 @@ class BenchmarkLayer(benchmark.Benchmark): def bm_layer_implementation(self, batch_size): input_1 = keras.Input(shape=(1,), dtype=dtypes.int64, name="word") input_2 = keras.Input(shape=(1,), dtype=dtypes.int64, name="int") - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() _ = layer([input_1, input_2]) num_repeats = 5 @@ -97,7 +97,7 @@ class BenchmarkLayer(benchmark.Benchmark): ends.append(time.time()) avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches - name = "categorical_crossing|batch_%s" % batch_size + name = "category_crossing|batch_%s" % batch_size baseline = self.run_dataset_implementation(batch_size) extras = { "dataset implementation baseline": baseline, diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/hashing_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/hashing_benchmark.py new file mode 100644 index 00000000000..68ab28c7f6c --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/hashing_benchmark.py @@ -0,0 +1,115 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Benchmark for Keras hashing preprocessing layer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import random +import string +import time + +from absl import flags +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.layers.preprocessing import hashing +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +v2_compat.enable_v2_behavior() + + +# word_gen creates random sequences of ASCII letters (both lowercase and upper). +# The number of unique strings is ~2,700. +def word_gen(): + for _ in itertools.count(1): + yield "".join(random.choice(string.ascii_letters) for i in range(2)) + + +class BenchmarkLayer(benchmark.Benchmark): + """Benchmark the layer forward pass.""" + + def run_dataset_implementation(self, batch_size): + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + ds = ds.shuffle(batch_size * 100) + ds = ds.batch(batch_size) + num_batches = 5 + ds = ds.take(num_batches) + ds = ds.prefetch(num_batches) + starts.append(time.time()) + # Benchmarked code begins here. + for i in ds: + _ = string_ops.string_to_hash_bucket(i, num_buckets=2) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches + return avg_time + + def bm_layer_implementation(self, batch_size): + input_1 = keras.Input(shape=(None,), dtype=dtypes.string, name="word") + layer = hashing.Hashing(num_bins=2) + _ = layer(input_1) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + ds = ds.shuffle(batch_size * 100) + ds = ds.batch(batch_size) + num_batches = 5 + ds = ds.take(num_batches) + ds = ds.prefetch(num_batches) + starts.append(time.time()) + # Benchmarked code begins here. + for i in ds: + _ = layer(i) + # Benchmarked code ends here. 
+ ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches + name = "hashing|batch_%s" % batch_size + baseline = self.run_dataset_implementation(batch_size) + extras = { + "dataset implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for batch in [32, 64, 256]: + self.bm_layer_implementation(batch_size=batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_crossing.py b/tensorflow/python/keras/layers/preprocessing/category_crossing.py similarity index 87% rename from tensorflow/python/keras/layers/preprocessing/categorical_crossing.py rename to tensorflow/python/keras/layers/preprocessing/category_crossing.py index 68848458bb2..79c27d9ec36 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_crossing.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing.py @@ -49,6 +49,17 @@ class CategoryCrossing(Layer): [b'b_X_e'], [b'c_X_f']], dtype=object)> + + >>> inp_1 = tf.constant([['a'], ['b'], ['c']]) + >>> inp_2 = tf.constant([['d'], ['e'], ['f']]) + >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing( + ... separator='-') + >>> layer([inp_1, inp_2]) + <tf.Tensor: shape=(3, 1), dtype=string, numpy= + array([[b'a-d'], + [b'b-e'], + [b'c-f']], dtype=object)> + + Arguments: depth: depth of input crossing. By default None, all inputs are crossed into one output. It can also be an int or tuple/list of ints. Passing an @@ -59,6 +70,8 @@ class CategoryCrossing(Layer): equal to N1 or N2. Passing `None` means a single crossed output with all inputs. For example, with inputs `a`, `b` and `c`, `depth=2` means the output will be [a;b;c;cross(a, b);cross(bc);cross(ca)]. + separator: A string added between each input being joined. Defaults to + '_X_'. name: Name to give to the layer. **kwargs: Keyword arguments to construct a layer. @@ -98,13 +111,12 @@ class CategoryCrossing(Layer): `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]` """ - def __init__(self, - depth=None, - name=None, - **kwargs): - # TODO(tanzheny): Consider making seperator configurable. + def __init__(self, depth=None, name=None, separator=None, **kwargs): super(CategoryCrossing, self).__init__(name=name, **kwargs) self.depth = depth + if separator is None: + separator = '_X_' + self.separator = separator if isinstance(depth, (tuple, list)): self._depth_tuple = depth elif depth is not None: @@ -114,12 +126,16 @@ class CategoryCrossing(Layer): """Gets the crossed output from a partial list/tuple of inputs.""" # If ragged_out=True, convert output from sparse to ragged. if ragged_out: + # TODO(momernick): Support separator with ragged_cross.
+ if self.separator != '_X_': + raise ValueError('Non-default separator with ragged input is not ' + 'supported yet, given {}'.format(self.separator)) return ragged_array_ops.cross(partial_inputs) elif sparse_out: - return sparse_ops.sparse_cross(partial_inputs) + return sparse_ops.sparse_cross(partial_inputs, separator=self.separator) else: return sparse_ops.sparse_tensor_to_dense( - sparse_ops.sparse_cross(partial_inputs)) + sparse_ops.sparse_cross(partial_inputs, separator=self.separator)) def call(self, inputs): depth_tuple = self._depth_tuple if self.depth else (len(inputs),) @@ -178,6 +194,7 @@ class CategoryCrossing(Layer): def get_config(self): config = { 'depth': self.depth, + 'separator': self.separator, } base_config = super(CategoryCrossing, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py similarity index 98% rename from tensorflow/python/keras/layers/preprocessing/categorical_crossing_distribution_test.py rename to tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py index 57dea6edf4a..1ccc7fe2296 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_distribution_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py @@ -28,7 +28,7 @@ from tensorflow.python.distribute import tpu_strategy from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.keras import keras_parameterized -from tensorflow.python.keras.layers.preprocessing import categorical_crossing +from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.platform import test @@ -72,7 +72,7 @@ class CategoryCrossingDistributionTest( input_data_2 = keras.Input(shape=(2,), dtype=dtypes.string, name='input_2') input_data = [input_data_1, input_data_2] - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) output_dataset = model.predict(inp_dataset) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py similarity index 82% rename from tensorflow/python/keras/layers/preprocessing/categorical_crossing_test.py rename to tensorflow/python/keras/layers/preprocessing/category_crossing_test.py index 5bbcf5ce022..f076c9ea865 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py @@ -29,7 +29,7 @@ from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import training -from tensorflow.python.keras.layers.preprocessing import categorical_crossing +from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.ops import array_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -41,7 +41,7 @@ from tensorflow.python.platform import test class 
CategoryCrossingTest(keras_parameterized.TestCase): def test_crossing_sparse_inputs(self): - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() inputs_0 = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 0], [1, 1]], values=['a', 'b', 'c'], @@ -52,8 +52,32 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices) self.assertAllEqual([b'a_X_d', b'b_X_e', b'c_X_e'], output.values) + def test_crossing_sparse_inputs_custom_sep(self): + layer = category_crossing.CategoryCrossing(separator='_Y_') + inputs_0 = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=['a', 'b', 'c'], + dense_shape=[2, 2]) + inputs_1 = sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3]) + output = layer([inputs_0, inputs_1]) + self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices) + self.assertAllEqual([b'a_Y_d', b'b_Y_e', b'c_Y_e'], output.values) + + def test_crossing_sparse_inputs_empty_sep(self): + layer = category_crossing.CategoryCrossing(separator='') + inputs_0 = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=['a', 'b', 'c'], + dense_shape=[2, 2]) + inputs_1 = sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3]) + output = layer([inputs_0, inputs_1]) + self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices) + self.assertAllEqual([b'ad', b'be', b'ce'], output.values) + def test_crossing_sparse_inputs_depth_int(self): - layer = categorical_crossing.CategoryCrossing(depth=1) + layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 0], [2, 0]], values=['a', 'b', 'c'], @@ -69,7 +93,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_out, output) def test_crossing_sparse_inputs_depth_tuple(self): - layer = categorical_crossing.CategoryCrossing(depth=(2, 3)) + layer = category_crossing.CategoryCrossing(depth=(2, 3)) inputs_0 = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 0], [2, 0]], values=['a', 'b', 'c'], @@ -107,14 +131,14 @@ class CategoryCrossingTest(keras_parameterized.TestCase): inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) - non_hashed_layer = categorical_crossing.CategoryCrossing() + non_hashed_layer = category_crossing.CategoryCrossing() out_t = non_hashed_layer([inp_0_t, inp_1_t]) model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t) expected_output = [[b'omar_X_a', b'skywalker_X_a'], [b'marlo_X_b']] self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1])) def test_crossing_ragged_inputs_depth_int(self): - layer = categorical_crossing.CategoryCrossing(depth=1) + layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = ragged_factory_ops.constant([['a'], ['b'], ['c']]) inputs_1 = ragged_factory_ops.constant([['d'], ['e'], ['f']]) output = layer([inputs_0, inputs_1]) @@ -122,7 +146,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertIsInstance(output, ragged_tensor.RaggedTensor) self.assertAllEqual(expected_output, output) - layer = categorical_crossing.CategoryCrossing(depth=2) + layer = category_crossing.CategoryCrossing(depth=2) inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) inp_1_t = input_layer.Input(shape=(None,), ragged=True, 
dtype=dtypes.string) out_t = layer([inp_0_t, inp_1_t]) @@ -132,7 +156,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1])) def test_crossing_ragged_inputs_depth_tuple(self): - layer = categorical_crossing.CategoryCrossing(depth=[2, 3]) + layer = category_crossing.CategoryCrossing(depth=[2, 3]) inputs_0 = ragged_factory_ops.constant([['a'], ['b'], ['c']]) inputs_1 = ragged_factory_ops.constant([['d'], ['e'], ['f']]) inputs_2 = ragged_factory_ops.constant([['g'], ['h'], ['i']]) @@ -149,21 +173,21 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_output, output) def test_crossing_with_dense_inputs(self): - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() inputs_0 = np.asarray([[1, 2]]) inputs_1 = np.asarray([[1, 3]]) output = layer([inputs_0, inputs_1]) self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output) def test_crossing_dense_inputs_depth_int(self): - layer = categorical_crossing.CategoryCrossing(depth=1) + layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = constant_op.constant([['a'], ['b'], ['c']]) inputs_1 = constant_op.constant([['d'], ['e'], ['f']]) output = layer([inputs_0, inputs_1]) expected_output = [[b'a', b'd'], [b'b', b'e'], [b'c', b'f']] self.assertAllEqual(expected_output, output) - layer = categorical_crossing.CategoryCrossing(depth=2) + layer = category_crossing.CategoryCrossing(depth=2) inp_0_t = input_layer.Input(shape=(1,), dtype=dtypes.string) inp_1_t = input_layer.Input(shape=(1,), dtype=dtypes.string) out_t = layer([inp_0_t, inp_1_t]) @@ -174,7 +198,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1])) def test_crossing_dense_inputs_depth_tuple(self): - layer = categorical_crossing.CategoryCrossing(depth=[2, 3]) + layer = category_crossing.CategoryCrossing(depth=[2, 3]) inputs_0 = constant_op.constant([['a'], ['b'], ['c']]) inputs_1 = constant_op.constant([['d'], ['e'], ['f']]) inputs_2 = constant_op.constant([['g'], ['h'], ['i']]) @@ -200,21 +224,21 @@ class CategoryCrossingTest(keras_parameterized.TestCase): tensor_spec.TensorSpec(input_shape, dtypes.string) for input_shape in input_shapes ] - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() output_spec = layer.compute_output_signature(input_specs) self.assertEqual(output_spec.shape.dims[0], input_shapes[0].dims[0]) self.assertEqual(output_spec.dtype, dtypes.string) @tf_test_util.run_v2_only def test_config_with_custom_name(self): - layer = categorical_crossing.CategoryCrossing(depth=2, name='hashing') + layer = category_crossing.CategoryCrossing(depth=2, name='hashing') config = layer.get_config() - layer_1 = categorical_crossing.CategoryCrossing.from_config(config) + layer_1 = category_crossing.CategoryCrossing.from_config(config) self.assertEqual(layer_1.name, layer.name) - layer = categorical_crossing.CategoryCrossing(name='hashing') + layer = category_crossing.CategoryCrossing(name='hashing') config = layer.get_config() - layer_1 = categorical_crossing.CategoryCrossing.from_config(config) + layer_1 = category_crossing.CategoryCrossing.from_config(config) self.assertEqual(layer_1.name, layer.name) diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py index dfd4761f193..05b4445829a 100644 --- 
a/tensorflow/python/keras/layers/preprocessing/hashing.py +++ b/tensorflow/python/keras/layers/preprocessing/hashing.py @@ -22,20 +22,28 @@ import functools from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util.tf_export import keras_export + +# Default key from tf.sparse.cross_hashed +_DEFAULT_SALT_KEY = [0xDECAFCAFFE, 0xDECAFCAFFE] +@keras_export('keras.layers.experimental.preprocessing.Hashing') class Hashing(Layer): """Implements categorical feature hashing, also known as "hashing trick". - This layer transforms categorical inputs to hashed output. It converts a - sequence of int or string to a sequence of int. The stable hash function uses - tensorflow::ops::Fingerprint to produce universal output that is consistent - across platforms. + This layer transforms single or multiple categorical inputs to hashed output. + It converts a sequence of int or string to a sequence of int. The stable hash + function uses tensorflow::ops::Fingerprint to produce universal output that + is consistent across platforms. This layer uses [FarmHash64](https://github.com/google/farmhash) by default, which provides a consistent hashed output across different platforms and is @@ -48,50 +56,91 @@ class Hashing(Layer): the `salt` value serving as additional input to the hash function. Example (FarmHash64): - ```python - layer = Hashing(num_bins=3) - inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) - layer(inputs) - [[1], [0], [1], [1], [2]] - ``` + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3) + >>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + >>> layer(inp) + <tf.Tensor: shape=(5, 1), dtype=int64, numpy= + array([[1], + [0], + [1], + [1], + [2]])> + Example (SipHash64): - ```python - layer = Hashing(num_bins=3, salt=[133, 137]) - inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) - layer(inputs) - [[1], [2], [1], [0], [2]] - ``` + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3, + ... salt=[133, 137]) + >>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + >>> layer(inp) + <tf.Tensor: shape=(5, 1), dtype=int64, numpy= + array([[1], + [2], + [1], + [0], + [2]])> + + Example (SipHash64, with a single integer `salt`, equivalent to `salt=[133, 133]`): + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3, + ... salt=133) + >>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + >>> layer(inp) + + + Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf) Arguments: num_bins: Number of hash bins. - salt: A tuple/list of 2 unsigned integer numbers. If passed, the hash - function used will be SipHash64, with these values used as an additional - input (known as a "salt" in cryptography). + salt: A single unsigned integer or None. + If passed, the hash function used will be SipHash64, with the salt + value used as an additional input (known as a "salt" in cryptography). These should be non-zero. Defaults to `None` (in that - case, the FarmHash64 hash function is used). + case, the FarmHash64 hash function is used). It also supports a + tuple/list of 2 unsigned integers; see the reference paper for details. name: Name to give to the layer. **kwargs: Keyword arguments to construct a layer.
- Input shape: A string, int32 or int64 tensor of shape - `[batch_size, d1, ..., dm]` + Input shape: A single or a list of string, int32 or int64 `Tensor`, + `SparseTensor` or `RaggedTensor` of shape `[batch_size, ...]` - Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` + Output shape: An int64 `Tensor`, `SparseTensor` or `RaggedTensor` of shape + `[batch_size, ...]`. If any input is a `RaggedTensor`, the output is a + `RaggedTensor`; otherwise, if any input is a `SparseTensor`, the output is a + `SparseTensor`; otherwise, the output is a `Tensor`. """ def __init__(self, num_bins, salt=None, name=None, **kwargs): if num_bins is None or num_bins <= 0: raise ValueError('`num_bins` cannot be `None` or non-positive values.') - if salt is not None: - if not isinstance(salt, (tuple, list)) or len(salt) != 2: - raise ValueError('`salt` must be a tuple or list of 2 unsigned ' - 'integer numbers, got {}'.format(salt)) super(Hashing, self).__init__(name=name, **kwargs) self.num_bins = num_bins - self.salt = salt + self.strong_hash = salt is not None + if salt is not None: + if isinstance(salt, (tuple, list)) and len(salt) == 2: + self.salt = salt + elif isinstance(salt, int): + self.salt = [salt, salt] + else: + raise ValueError('`salt` can only be a tuple of size 2 integers, or a ' + 'single integer, given {}'.format(salt)) + else: + self.salt = _DEFAULT_SALT_KEY def call(self, inputs): + if isinstance(inputs, (tuple, list)): + return self._process_input_list(inputs) + else: + return self._process_single_input(inputs) + + def _process_single_input(self, inputs): # Converts integer inputs to string. if inputs.dtype.is_integer: if isinstance(inputs, sparse_tensor.SparseTensor): @@ -116,10 +165,38 @@ class Hashing(Layer): else: return str_to_hash_bucket(inputs, self.num_bins, name='hash') + def _process_input_list(self, inputs): + # TODO(momernick): support ragged_cross_hashed with corrected fingerprint + # and siphash. + if any([isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs]): + raise ValueError('Hashing with ragged input is not supported yet.') + sparse_inputs = [ + inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor) + ] + dense_inputs = [ + inp for inp in inputs if not isinstance(inp, sparse_tensor.SparseTensor) + ] + all_dense = not sparse_inputs + indices = [sp_inp.indices for sp_inp in sparse_inputs] + values = [sp_inp.values for sp_inp in sparse_inputs] + shapes = [sp_inp.dense_shape for sp_inp in sparse_inputs] + indices_out, values_out, shapes_out = gen_sparse_ops.sparse_cross_hashed( + indices=indices, + values=values, + shapes=shapes, + dense_inputs=dense_inputs, + num_buckets=self.num_bins, + strong_hash=self.strong_hash, + salt=self.salt) + sparse_out = sparse_tensor.SparseTensor(indices_out, values_out, shapes_out) + if all_dense: + return sparse_ops.sparse_tensor_to_dense(sparse_out) + return sparse_out + def _get_string_to_hash_bucket_fn(self): """Returns the string_to_hash_bucket op to use based on `hasher_key`.""" # string_to_hash_bucket_fast uses FarmHash64 as hash function. - if self.salt is None: + if not self.strong_hash: return string_ops.string_to_hash_bucket_fast # string_to_hash_bucket_strong uses SipHash64 as hash function.
else: @@ -127,16 +204,43 @@ class Hashing(Layer): string_ops.string_to_hash_bucket_strong, key=self.salt) def compute_output_shape(self, input_shape): - return input_shape + if not isinstance(input_shape, (tuple, list)): + return input_shape + input_shapes = input_shape + batch_size = None + for inp_shape in input_shapes: + inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list() + if len(inp_tensor_shape) != 2: + raise ValueError('Inputs must be rank 2, got {}'.format(input_shapes)) + if batch_size is None: + batch_size = inp_tensor_shape[0] + # The second dimension is dynamic based on inputs. + output_shape = [batch_size, None] + return tensor_shape.TensorShape(output_shape) def compute_output_signature(self, input_spec): - output_shape = self.compute_output_shape(input_spec.shape.as_list()) - output_dtype = dtypes.int64 - if isinstance(input_spec, sparse_tensor.SparseTensorSpec): + if not isinstance(input_spec, (tuple, list)): + output_shape = self.compute_output_shape(input_spec.shape) + output_dtype = dtypes.int64 + if isinstance(input_spec, sparse_tensor.SparseTensorSpec): + return sparse_tensor.SparseTensorSpec( + shape=output_shape, dtype=output_dtype) + else: + return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype) + input_shapes = [x.shape for x in input_spec] + output_shape = self.compute_output_shape(input_shapes) + if any([ + isinstance(inp_spec, ragged_tensor.RaggedTensorSpec) + for inp_spec in input_spec + ]): + return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64) + elif any([ + isinstance(inp_spec, sparse_tensor.SparseTensorSpec) + for inp_spec in input_spec + ]): return sparse_tensor.SparseTensorSpec( - shape=output_shape, dtype=output_dtype) - else: - return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype) + shape=output_shape, dtype=dtypes.int64) + return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64) def get_config(self): config = {'num_bins': self.num_bins, 'salt': self.salt} diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_test.py index 147e4bc371b..4c3fd9c7501 100644 --- a/tensorflow/python/keras/layers/preprocessing/hashing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/hashing_test.py @@ -51,6 +51,15 @@ class HashingTest(keras_parameterized.TestCase): # Assert equal for hashed output that should be true on all platforms. self.assertAllClose([[0], [0], [1], [0], [0]], output) + def test_hash_dense_multi_inputs_farmhash(self): + layer = hashing.Hashing(num_bins=2) + inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'], + ['skywalker']]) + inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + output = layer([inp_1, inp_2]) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([[0], [0], [1], [1], [0]], output) + def test_hash_dense_int_input_farmhash(self): layer = hashing.Hashing(num_bins=3) inp = np.asarray([[0], [1], [2], [3], [4]]) @@ -72,6 +81,21 @@ class HashingTest(keras_parameterized.TestCase): # Note the result is different from (133, 137). self.assertAllClose([[1], [0], [1], [0], [1]], output_2) + def test_hash_dense_multi_inputs_siphash(self): + layer = hashing.Hashing(num_bins=2, salt=[133, 137]) + inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'], + ['skywalker']]) + inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + output = layer([inp_1, inp_2]) + # Assert equal for hashed output that should be true on all platforms.
+ # Note the result is different from FarmHash. + self.assertAllClose([[0], [1], [0], [0], [1]], output) + + layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137]) + output_2 = layer_2([inp_1, inp_2]) + # Note the result is different from (133, 137). + self.assertAllClose([[1], [1], [1], [0], [1]], output_2) + def test_hash_dense_int_input_siphash(self): layer = hashing.Hashing(num_bins=3, salt=[133, 137]) inp = np.asarray([[0], [1], [2], [3], [4]]) @@ -90,6 +114,19 @@ class HashingTest(keras_parameterized.TestCase): self.assertAllClose(indices, output.indices) self.assertAllClose([0, 0, 1, 0, 0], output.values) + def test_hash_sparse_multi_inputs_farmhash(self): + layer = hashing.Hashing(num_bins=2) + indices = [[0, 0], [1, 0], [2, 0]] + inp_1 = sparse_tensor.SparseTensor( + indices=indices, + values=['omar', 'stringer', 'marlo'], + dense_shape=[3, 1]) + inp_2 = sparse_tensor.SparseTensor( + indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1]) + output = layer([inp_1, inp_2]) + self.assertAllClose(indices, output.indices) + self.assertAllClose([0, 0, 1], output.values) + def test_hash_sparse_int_input_farmhash(self): layer = hashing.Hashing(num_bins=3) indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] @@ -116,6 +153,25 @@ class HashingTest(keras_parameterized.TestCase): # The result should be same with test_hash_dense_input_siphash. self.assertAllClose([1, 0, 1, 0, 1], output.values) + def test_hash_sparse_multi_inputs_siphash(self): + layer = hashing.Hashing(num_bins=2, salt=[133, 137]) + indices = [[0, 0], [1, 0], [2, 0]] + inp_1 = sparse_tensor.SparseTensor( + indices=indices, + values=['omar', 'stringer', 'marlo'], + dense_shape=[3, 1]) + inp_2 = sparse_tensor.SparseTensor( + indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1]) + output = layer([inp_1, inp_2]) + # The result should be same with test_hash_dense_input_siphash. + self.assertAllClose(indices, output.indices) + self.assertAllClose([0, 1, 0], output.values) + + layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137]) + output = layer_2([inp_1, inp_2]) + # The result should be same with test_hash_dense_input_siphash. 
+ self.assertAllClose([1, 1, 1], output.values) + def test_hash_sparse_int_input_siphash(self): layer = hashing.Hashing(num_bins=3, salt=[133, 137]) indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] @@ -140,6 +196,17 @@ class HashingTest(keras_parameterized.TestCase): model = training.Model(inputs=inp_t, outputs=out_t) self.assertAllClose(out_data, model.predict(inp_data)) + def test_hash_ragged_string_multi_inputs_farmhash(self): + layer = hashing.Hashing(num_bins=2) + inp_data_1 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + inp_data_2 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + with self.assertRaisesRegexp(ValueError, 'not supported yet'): + _ = layer([inp_data_1, inp_data_2]) + def test_hash_ragged_int_input_farmhash(self): layer = hashing.Hashing(num_bins=3) inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]], @@ -178,6 +245,17 @@ class HashingTest(keras_parameterized.TestCase): model = training.Model(inputs=inp_t, outputs=out_t) self.assertAllClose(out_data, model.predict(inp_data)) + def test_hash_ragged_string_multi_inputs_siphash(self): + layer = hashing.Hashing(num_bins=2, salt=[133, 137]) + inp_data_1 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + inp_data_2 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + with self.assertRaisesRegexp(ValueError, 'not supported yet'): + _ = layer([inp_data_1, inp_data_2]) + def test_hash_ragged_int_input_siphash(self): layer = hashing.Hashing(num_bins=3, salt=[133, 137]) inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]], @@ -197,11 +275,11 @@ class HashingTest(keras_parameterized.TestCase): _ = hashing.Hashing(num_bins=None) with self.assertRaisesRegexp(ValueError, 'cannot be `None`'): _ = hashing.Hashing(num_bins=-1) - with self.assertRaisesRegexp(ValueError, 'must be a tuple'): + with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'): _ = hashing.Hashing(num_bins=2, salt='string') - with self.assertRaisesRegexp(ValueError, 'must be a tuple'): + with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'): _ = hashing.Hashing(num_bins=2, salt=[1]) - with self.assertRaisesRegexp(ValueError, 'must be a tuple'): + with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'): _ = hashing.Hashing(num_bins=1, salt=constant_op.constant([133, 137])) def test_hash_compute_output_signature(self): diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index 9cafc0f08d8..2eb7cff75bb 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -45,6 +45,8 @@ from tensorflow.python.keras.layers import recurrent from tensorflow.python.keras.layers import recurrent_v2 from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 from tensorflow.python.keras.layers import wrappers +from tensorflow.python.keras.layers.preprocessing import category_crossing +from tensorflow.python.keras.layers.preprocessing import hashing from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization from tensorflow.python.keras.layers.preprocessing 
import normalization_v1 as preprocessing_normalization_v1 @@ -60,7 +62,7 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional, embeddings, einsum_dense, local, merge, noise, normalization, pooling, image_preprocessing, preprocessing_normalization_v1, preprocessing_text_vectorization_v1, - recurrent, wrappers) + recurrent, wrappers, hashing, category_crossing) ALL_V2_MODULES = ( rnn_cell_wrapper_v2, normalization_v2, diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c4c88ab86ef..cc4b1010021 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -27,6 +27,7 @@ import numbers import numpy as np +from tensorflow.python.compat import compat as tf_compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -569,7 +570,7 @@ def sparse_add_v2(a, b, threshold=0): @tf_export("sparse.cross") -def sparse_cross(inputs, name=None): +def sparse_cross(inputs, name=None, separator=None): """Generates sparse cross from a list of sparse and dense tensors. For example, if the inputs are @@ -590,14 +591,39 @@ [1, 0]: "b_X_e_X_g" [1, 1]: "c_X_e_X_g" + + Customized separator "_Y_": + + >>> inp_0 = tf.constant([['a'], ['b']]) + >>> inp_1 = tf.constant([['c'], ['d']]) + >>> output = tf.sparse.cross([inp_0, inp_1], separator='_Y_') + >>> output.values + <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a_Y_c', b'b_Y_d'], + dtype=object)> + + Args: inputs: An iterable of `Tensor` or `SparseTensor`. name: Optional name for the op. + separator: A string added between each string being joined. Defaults to + '_X_'. Returns: A `SparseTensor` of type `string`. """ - return _sparse_cross_internal(inputs=inputs, hashed_output=False, name=name) + if separator is None and not tf_compat.forward_compatible(2020, 6, 14): + return _sparse_cross_internal(inputs=inputs, hashed_output=False, name=name) + if separator is None: + separator = "_X_" + separator = ops.convert_to_tensor(separator, dtypes.string) + indices, values, shapes, dense_inputs = _sparse_cross_internal_v2(inputs) + indices_out, values_out, shape_out = gen_sparse_ops.sparse_cross_v2( + indices=indices, + values=values, + shapes=shapes, + dense_inputs=dense_inputs, + sep=separator, + name=name) + return sparse_tensor.SparseTensor(indices_out, values_out, shape_out) _sparse_cross = sparse_cross @@ -655,6 +681,32 @@ _sparse_cross_hashed = sparse_cross_hashed _DEFAULT_HASH_KEY = 0xDECAFCAFFE +def _sparse_cross_internal_v2(inputs): + """See gen_sparse_ops.sparse_cross_v2.""" + if not isinstance(inputs, (tuple, list)): + raise TypeError("Inputs must be a list or tuple") + if not all( + isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + for i in inputs): + raise TypeError("All inputs must be Tensor or SparseTensor.") + sparse_inputs = [ + i for i in inputs if isinstance(i, sparse_tensor.SparseTensor) + ] + dense_inputs = [ + i for i in inputs if not isinstance(i, sparse_tensor.SparseTensor) + ] + indices = [sp_input.indices for sp_input in sparse_inputs] + values = [sp_input.values for sp_input in sparse_inputs] + shapes = [sp_input.dense_shape for sp_input in sparse_inputs] + for i in range(len(values)): + if values[i].dtype != dtypes.string: + values[i] = math_ops.cast(values[i], dtypes.int64) + for i in range(len(dense_inputs)): + if dense_inputs[i].dtype != dtypes.string: + dense_inputs[i] = math_ops.cast(dense_inputs[i], dtypes.int64) + return indices, values, shapes, dense_inputs + + def 
_sparse_cross_internal(inputs, hashed_output=False, num_buckets=0, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt index 0407188ab6b..6cfcbf73e5d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt new file mode 100644 index 00000000000..e4a5619058d --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.Hashing" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, 
defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt index 0964922ea26..c93b8a89fb8 100644 --- 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "CenterCrop" mtype: "" } + member { + name: "Hashing" + mtype: "" + } member { name: "Normalization" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt index f8f8edb26a8..9550418c2a6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt @@ -22,7 +22,7 @@ tf_module { } member_method { name: "cross" - argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'inputs\', \'name\', \'separator\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "cross_hashed" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt index 0407188ab6b..6cfcbf73e5d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt new file mode 100644 index 00000000000..e4a5619058d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.Hashing" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: 
"trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, 
defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt index 0964922ea26..c93b8a89fb8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "CenterCrop" mtype: "" } + member { + name: "Hashing" + mtype: "" + } member { name: "Normalization" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt index 67235bb2cf2..0028b7d8953 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt @@ -18,7 +18,7 @@ tf_module { } member_method { name: "cross" - argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'inputs\', \'name\', \'separator\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "cross_hashed"