Export tf.keras.layers.experimental.preprocessing.CategoryCrossing layer.

PiperOrigin-RevId: 311398537
Change-Id: I394c7dd5ae7fe168f3238dbd8a7ab064ff6ad2c1
This commit is contained in:
Zhenyu Tan 2020-05-13 13:57:07 -07:00 committed by TensorFlower Gardener
parent 7d704e3246
commit 8d1e8b350c
11 changed files with 707 additions and 177 deletions

View File

@ -57,6 +57,7 @@ else:
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
TextVectorizationV1 = TextVectorization
from tensorflow.python.keras.layers.preprocessing.categorical_crossing import CategoryCrossing
# Advanced activations.
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU

View File

@ -310,6 +310,25 @@ distribute_py_test(
],
)
# Distribution-strategy test target for the CategoryCrossing preprocessing
# layer. It is tagged to run on both single- and multi-GPU configurations; the
# TPU variant is additionally tagged "no_oss" (see the bug referenced below).
distribute_py_test(
name = "categorical_crossing_distribution_test",
srcs = ["categorical_crossing_distribution_test.py"],
main = "categorical_crossing_distribution_test.py",
python_version = "PY3",
tags = [
"multi_and_single_gpu",
],
tpu_tags = [
"no_oss", # b/155502591
],
deps = [
":categorical_crossing",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
],
)
tf_py_test(
name = "discretization_test",
size = "small",

View File

@ -17,6 +17,16 @@ tf_py_test(
],
)
# Benchmark target for the CategoryCrossing preprocessing layer; built from
# categorical_crossing_benchmark.py and depending on the layer's own library.
tf_py_test(
name = "categorical_crossing_benchmark",
srcs = ["categorical_crossing_benchmark.py"],
python_version = "PY3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python/keras/layers/preprocessing:categorical_crossing",
],
)
tf_py_test(
name = "index_lookup_adapt_benchmark",
srcs = ["index_lookup_adapt_benchmark.py"],

View File

@ -0,0 +1,116 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Keras categorical_crossing preprocessing layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import time
from absl import flags
import numpy as np
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.layers.preprocessing import categorical_crossing
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
# Module-level handle to the absl flags object.
FLAGS = flags.FLAGS
# Enable v2 behavior at import time, before the benchmark code below runs.
v2_compat.enable_v2_behavior()
# int_gen produces an endless stream of random integer feature pairs. The first
# feature takes values in [0, 5) and the second in [0, 7), so there are at most
# 35 distinct crossed combinations.
def int_gen():
  """Yields `(a, b)` tuples of random integer arrays, each of shape `(1,)`.

  The generator never terminates on its own; callers are expected to bound it
  (for example with `Dataset.take`).

  Yields:
    A 2-tuple of NumPy integer arrays of shape `(1,)`: the first drawn
    uniformly from `[0, 5)`, the second drawn uniformly from `[0, 7)`.
  """
  for _ in itertools.count():
    yield (np.random.randint(0, 5, (1,)), np.random.randint(0, 7, (1,)))
class BenchmarkLayer(benchmark.Benchmark):
  """Benchmark the layer forward pass."""

  # Number of timed repetitions per measurement.
  _NUM_REPEATS = 5
  # Number of batches consumed (and therefore crossed) in each repetition.
  _NUM_BATCHES = 5

  def _make_dataset(self, batch_size):
    """Builds the finite input pipeline shared by both measurements.

    Args:
      batch_size: Number of `(a, b)` pairs per batch.

    Returns:
      A dataset of `_NUM_BATCHES` shuffled batches of int64 pairs drawn from
      `int_gen`, with prefetching enabled.
    """
    ds = dataset_ops.Dataset.from_generator(
        int_gen, (dtypes.int64, dtypes.int64),
        (tensor_shape.TensorShape([1]), tensor_shape.TensorShape([1])))
    ds = ds.shuffle(batch_size * 100)
    ds = ds.batch(batch_size)
    ds = ds.take(self._NUM_BATCHES)
    return ds.prefetch(self._NUM_BATCHES)

  def _average_batch_time(self, batch_size, cross_fn):
    """Times `cross_fn` over fresh datasets and returns seconds per batch.

    A new dataset is built for every repetition; dataset construction is
    excluded from the timed region, while iterating it is included.

    Args:
      batch_size: Number of `(a, b)` pairs per batch.
      cross_fn: Callable taking a two-element list of tensors and returning
        their cross; its result is discarded.

    Returns:
      Mean wall time of one repetition divided by the number of batches.
    """
    starts = []
    ends = []
    for _ in range(self._NUM_REPEATS):
      ds = self._make_dataset(batch_size)
      starts.append(time.time())
      # Benchmarked code begins here.
      for batch in ds:
        _ = cross_fn([batch[0], batch[1]])
      # Benchmarked code ends here.
      ends.append(time.time())
    return np.mean(np.array(ends) - np.array(starts)) / self._NUM_BATCHES

  def run_dataset_implementation(self, batch_size):
    """Measures the baseline: calling `sparse_ops.sparse_cross` directly.

    Args:
      batch_size: Number of `(a, b)` pairs per batch.

    Returns:
      Average wall time in seconds per batch.
    """
    return self._average_batch_time(batch_size, sparse_ops.sparse_cross)

  def bm_layer_implementation(self, batch_size):
    """Measures the `CategoryCrossing` layer and reports it against baseline.

    Args:
      batch_size: Number of `(a, b)` pairs per batch.
    """
    input_1 = keras.Input(shape=(1,), dtype=dtypes.int64, name="word")
    input_2 = keras.Input(shape=(1,), dtype=dtypes.int64, name="int")
    layer = categorical_crossing.CategoryCrossing()
    # Call the layer once on symbolic inputs before the timed runs.
    _ = layer([input_1, input_2])

    avg_time = self._average_batch_time(batch_size, layer)

    name = "categorical_crossing|batch_%s" % batch_size
    baseline = self.run_dataset_implementation(batch_size)
    extras = {
        "dataset implementation baseline": baseline,
        "delta seconds": (baseline - avg_time),
        "delta percent": ((baseline - avg_time) / baseline) * 100
    }
    self.report_benchmark(
        iters=self._NUM_REPEATS, wall_time=avg_time, extras=extras, name=name)

  def benchmark_vocab_size_by_batch(self):
    """Runs the layer benchmark for each batch size under test."""
    for batch in [32, 64, 256]:
      self.bm_layer_implementation(batch_size=batch)
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()

View File

@ -20,49 +20,35 @@ from __future__ import print_function
import itertools
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.layers.experimental.preprocessing.CategoryCrossing')
class CategoryCrossing(Layer):
"""Category crossing layer.
This layer transforms multiple categorical inputs to categorical outputs
by Cartesian product, and hash the output if necessary. Without hashing
(`num_bins=None`) the output dtype is string, with hashing the output dtype
is int64.
For each input, the hash function uses a specific fingerprint method, i.e.,
[FarmHash64](https://github.com/google/farmhash) to compute the hashed output,
that provides a consistent hashed output across different platforms.
For multiple inputs, the final output is calculated by first computing the
fingerprint of `hash_key`, and concatenate it with the fingerprints of
each input. The user can also obfuscate the output with customized `hash_key`.
If [SipHash64](https://github.com/google/highwayhash) is desired instead, the
user can set `num_bins=None` to get string outputs, and use Hashing layer to
get hashed output with SipHash64.
This layer concatenates multiple categorical inputs into a single categorical
output (similar to Cartesian product). The output dtype is string.
Usage:
Use with string output.
>>> inp_1 = tf.constant([['a'], ['b'], ['c']])
>>> inp_2 = tf.constant([['d'], ['e'], ['f']])
>>> layer = categorical_crossing.CategoryCrossing()
>>> output = layer([inp_1, inp_2])
Use with hashed output.
>>> layer = categorical_crossing.CategoryCrossing(num_bins=2)
>>> output = layer([inp_1, inp_2])
Use with customized hashed output.
>>> layer = categorical_crossing.CategoryCrossing(num_bins=2, hash_key=133)
>>> output = layer([inp_1, inp_2])
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
>>> layer([inp_1, inp_2])
<tf.Tensor: shape=(3, 1), dtype=string, numpy=
array([[b'a_X_d'],
[b'b_X_e'],
[b'c_X_f']], dtype=object)>
Arguments:
depth: depth of input crossing. By default None, all inputs are crossed into
@ -74,10 +60,6 @@ class CategoryCrossing(Layer):
equal to N1 or N2. Passing `None` means a single crossed output with all
inputs. For example, with inputs `a`, `b` and `c`, `depth=2` means the
output will be [a;b;c;cross(a, b);cross(b, c);cross(c, a)].
num_bins: Number of hash bins. By default None, no hashing is performed.
hash_key: Integer hash_key that will be used by the concatenate
fingerprints. If not given, will use a default key from
`tf.sparse.cross_hashed`. This is only valid when `num_bins` is not None.
name: Name to give to the layer.
**kwargs: Keyword arguments to construct a layer.
@ -87,114 +69,69 @@ class CategoryCrossing(Layer):
Output shape: a single string or int tensor or sparse tensor of shape
`[batch_size, d1, ..., dm]`
Below 'hash' stands for tf.fingerprint, and cat stands for 'FingerprintCat'.
Returns:
If any input is `RaggedTensor`, the output is `RaggedTensor`.
Else, if any input is `SparseTensor`, the output is `SparseTensor`.
Otherwise, the output is `Tensor`.
Example: (`depth`=None)
If the layer receives three inputs:
`a=[[1], [4]]`, `b=[[2], [5]]`, `c=[[3], [6]]`
the output will be a string tensor if not hashed:
the output will be a string tensor:
`[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`
the output will be an int64 tensor if hashed:
`[[cat(hash(3), cat(hash(2), cat(hash(1), hash(hash_key))))],
[[cat(hash(6), cat(hash(5), cat(hash(4), hash(hash_key))))]`
Example: (`depth` is an integer)
With the same input above, and if `depth`=2,
the output will be a list of 6 string tensors if not hashed:
the output will be a list of 6 string tensors:
`[[b'1'], [b'4']]`
`[[b'2'], [b'5']]`
`[[b'3'], [b'6']]`
`[[b'1_X_2'], [b'4_X_5']]`,
`[[b'2_X_3'], [b'5_X_6']]`,
`[[b'3_X_1'], [b'6_X_4']]`
the output will be a list of 6 int64 tensors if hashed:
`[[hash(b'1')], [hash(b'4')]]`
`[[hash(b'2')], [hash(b'5')]]`
`[[hash(b'3')], [hash(b'6')]]`
`[[cat(hash(2), cat(hash(1), hash(hash_key)))],
[cat(hash(5), cat(hash(4), hash(hash_key)))]`,
`[[cat(hash(3), cat(hash(1), hash(hash_key)))],
[cat(hash(6), cat(hash(4), hash(hash_key)))]`,
`[[cat(hash(3), cat(hash(2), hash(hash_key)))],
[cat(hash(6), cat(hash(5), hash(hash_key)))]`,
Example: (`depth` is a tuple/list of integers)
With the same input above, and if `depth`=(2, 3)
the output will be a list of 4 string tensors if not hashed:
the output will be a list of 4 string tensors:
`[[b'1_X_2'], [b'4_X_5']]`,
`[[b'2_X_3'], [b'5_X_6']]`,
`[[b'3_X_1'], [b'6_X_4']]`,
`[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`
the output will be a list of 4 int64 tensors if hashed:
`[
[cat(hash(2), cat(hash(1), hash(hash_key)))],
[cat(hash(5), cat(hash(4), hash(hash_key)))]
]`,
`[
[cat(hash(3), cat(hash(1), hash(hash_key)))],
[cat(hash(6), cat(hash(4), hash(hash_key)))]
]`,
`[
[cat(hash(3), cat(hash(2), hash(hash_key)))],
[cat(hash(6), cat(hash(5), hash(hash_key)))]
]`,
`[
[cat(hash(3), cat(hash(2), cat(hash(1), hash(hash_key))))],
[cat(hash(6), cat(hash(5), cat(hash(4), hash(hash_key))))]
]`
"""
def __init__(self,
depth=None,
num_bins=None,
hash_key=None,
name=None,
**kwargs):
# TODO(tanzheny): Consider making separator configurable.
if num_bins is None and hash_key is not None:
raise ValueError('`hash_key` is only valid when `num_bins` is not None')
super(CategoryCrossing, self).__init__(name=name, **kwargs)
self.depth = depth
self.num_bins = num_bins
self.hash_key = hash_key
if isinstance(depth, (tuple, list)):
self._depth_tuple = depth
elif depth is not None:
self._depth_tuple = tuple([i for i in range(1, depth + 1)])
strategy = ds_context.get_strategy()
if strategy.__class__.__name__.startswith('TPUStrategy'):
raise ValueError('TPU strategy is not support for this layer yet.')
def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
"""Gets the crossed output from a partial list/tuple of inputs."""
if self.num_bins is not None:
partial_output = sparse_ops.sparse_cross_hashed(
partial_inputs, num_buckets=self.num_bins, hash_key=self.hash_key)
else:
partial_output = sparse_ops.sparse_cross(partial_inputs)
# If ragged_out=True, convert output from sparse to ragged.
if ragged_out:
return ragged_tensor.RaggedTensor.from_sparse(partial_output)
return ragged_array_ops.cross(partial_inputs)
elif sparse_out:
return partial_output
return sparse_ops.sparse_cross(partial_inputs)
else:
return sparse_ops.sparse_tensor_to_dense(partial_output)
return sparse_ops.sparse_tensor_to_dense(
sparse_ops.sparse_cross(partial_inputs))
def call(self, inputs):
depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
ragged_out = sparse_out = False
if all([ragged_tensor.is_ragged(inp) for inp in inputs]):
# (b/144500510) ragged.map_flat_values(sparse_cross_hashed, inputs) will
# cause kernel failure. Investigate and find a more efficient
# implementation
inputs = [inp.to_sparse() for inp in inputs]
if any([ragged_tensor.is_ragged(inp) for inp in inputs]):
ragged_out = True
else:
if any([ragged_tensor.is_ragged(inp) for inp in inputs]):
raise ValueError(
'Inputs must be either all `RaggedTensor`, or none of them should '
'be `RaggedTensor`, got {}'.format(inputs))
if any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]):
sparse_out = True
elif any([isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs]):
sparse_out = True
outputs = []
for depth in depth_tuple:
@ -229,15 +166,22 @@ class CategoryCrossing(Layer):
def compute_output_signature(self, input_spec):
input_shapes = [x.shape for x in input_spec]
output_shape = self.compute_output_shape(input_shapes)
output_dtype = dtypes.int64 if self.num_bins else dtypes.string
return sparse_tensor.SparseTensorSpec(
shape=output_shape, dtype=output_dtype)
if any([
isinstance(inp_spec, ragged_tensor.RaggedTensorSpec)
for inp_spec in input_spec
]):
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string)
elif any([
isinstance(inp_spec, sparse_tensor.SparseTensorSpec)
for inp_spec in input_spec
]):
return sparse_tensor.SparseTensorSpec(
shape=output_shape, dtype=dtypes.string)
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string)
def get_config(self):
config = {
'depth': self.depth,
'num_bins': self.num_bins,
'hash_key': self.hash_key
}
base_config = super(CategoryCrossing, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

View File

@ -0,0 +1,64 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution tests for keras.layers.preprocessing.categorical_crossing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import categorical_crossing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@combinations.generate(
    combinations.combine(
        # TPU strategies are excluded here; why crossing is not supported
        # with TPU still needs to be investigated.
        distribution=strategy_combinations.strategies_minus_tpu,
        mode=['eager', 'graph']))
class CategoryCrossingDistributionTest(
    keras_parameterized.TestCase,
    preprocessing_test_utils.PreprocessingLayerTest):
  """Checks CategoryCrossing under each non-TPU distribution strategy."""

  def test_distribution(self, distribution):
    """Crosses two dense string inputs inside a strategy scope.

    Args:
      distribution: The distribution strategy supplied by the test
        combinations.
    """
    config.set_soft_device_placement(True)

    first_features = np.array([['a', 'b'], ['c', 'd']])
    second_features = np.array([['e', 'f'], ['g', 'h']])
    # Each row is the full cross of the matching rows of the two inputs.
    # pyformat: disable
    expected_crosses = [[b'a_X_e', b'a_X_f', b'b_X_e', b'b_X_f'],
                        [b'c_X_g', b'c_X_h', b'd_X_g', b'd_X_h']]

    with distribution.scope():
      symbolic_inputs = [
          keras.Input(shape=(2,), dtype=dtypes.string),
          keras.Input(shape=(2,), dtype=dtypes.string),
      ]
      crossing_layer = categorical_crossing.CategoryCrossing()
      crossed = crossing_layer(symbolic_inputs)
      model = keras.Model(inputs=symbolic_inputs, outputs=crossed)
    predictions = model.predict([first_features, second_features])
    self.assertAllEqual(expected_crosses, predictions)
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  test.main()

View File

@ -40,7 +40,7 @@ from tensorflow.python.platform import test
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class CategoryCrossingTest(keras_parameterized.TestCase):
def test_crossing_basic(self):
def test_crossing_sparse_inputs(self):
layer = categorical_crossing.CategoryCrossing()
inputs_0 = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
@ -52,36 +52,6 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
self.assertAllEqual([b'a_X_d', b'b_X_e', b'c_X_e'], output.values)
def test_crossing_sparse_inputs(self):
layer = categorical_crossing.CategoryCrossing(num_bins=1)
inputs_0 = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=['a', 'b', 'c'],
dense_shape=[2, 2])
inputs_1 = sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
output = layer([inputs_0, inputs_1])
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
self.assertAllClose([0, 0, 0], output.values)
def test_crossing_sparse_inputs_with_hash_key(self):
layer = categorical_crossing.CategoryCrossing(num_bins=2, hash_key=133)
inputs_0 = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=['a', 'b', 'c'],
dense_shape=[2, 2])
inputs_1 = sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
output = layer([inputs_0, inputs_1])
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
self.assertAllClose([1, 0, 1], output.values)
layer_2 = categorical_crossing.CategoryCrossing(num_bins=2, hash_key=137)
output = layer_2([inputs_0, inputs_1])
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
# Note the output is different with above.
self.assertAllClose([0, 1, 0], output.values)
def test_crossing_sparse_inputs_depth_int(self):
layer = categorical_crossing.CategoryCrossing(depth=1)
inputs_0 = sparse_tensor.SparseTensor(
@ -127,35 +97,15 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
[expected_outputs_0, expected_outputs_1, expected_outputs_2], axis=0)
self.assertAllEqual(expected_out, output)
def test_crossing_hashed_two_bins(self):
layer = categorical_crossing.CategoryCrossing(num_bins=2)
inputs_0 = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=['a', 'b', 'c'],
dense_shape=[2, 2])
inputs_1 = sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
output = layer([inputs_0, inputs_1])
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
self.assertEqual(output.values.numpy().max(), 1)
self.assertEqual(output.values.numpy().min(), 0)
def test_crossing_hashed_ragged_inputs(self):
layer = categorical_crossing.CategoryCrossing(num_bins=2)
def test_crossing_ragged_inputs(self):
inputs_0 = ragged_factory_ops.constant(
[['omar', 'skywalker'], ['marlo']],
dtype=dtypes.string)
inputs_1 = ragged_factory_ops.constant(
[['a'], ['b']],
dtype=dtypes.string)
out_data = layer([inputs_0, inputs_1])
expected_output = [[0, 0], [0]]
self.assertAllClose(expected_output, out_data)
inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
out_t = layer([inp_0_t, inp_1_t])
model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t)
self.assertAllClose(expected_output, model.predict([inputs_0, inputs_1]))
non_hashed_layer = categorical_crossing.CategoryCrossing()
out_t = non_hashed_layer([inp_0_t, inp_1_t])
@ -198,16 +148,6 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
self.assertIsInstance(output, ragged_tensor.RaggedTensor)
self.assertAllEqual(expected_output, output)
def test_invalid_mixed_sparse_and_ragged_input(self):
with self.assertRaises(ValueError):
layer = categorical_crossing.CategoryCrossing(num_bins=2)
inputs_0 = ragged_factory_ops.constant(
[['omar'], ['marlo']],
dtype=dtypes.string)
inputs_1 = sparse_tensor.SparseTensor(
indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
layer([inputs_0, inputs_1])
def test_crossing_with_dense_inputs(self):
layer = categorical_crossing.CategoryCrossing()
inputs_0 = np.asarray([[1, 2]])
@ -251,13 +191,6 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
self.assertAllEqual(expected_output,
model.predict([inputs_0, inputs_1, inputs_2]))
def test_crossing_hashed_with_dense_inputs(self):
layer = categorical_crossing.CategoryCrossing(num_bins=2)
inputs_0 = np.asarray([[1, 2]])
inputs_1 = np.asarray([[1, 3]])
output = layer([inputs_0, inputs_1])
self.assertAllClose([[1, 1, 0, 0]], output)
def test_crossing_compute_output_signature(self):
input_shapes = [
tensor_shape.TensorShape([2, 2]),
@ -272,18 +205,9 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
self.assertEqual(output_spec.shape.dims[0], input_shapes[0].dims[0])
self.assertEqual(output_spec.dtype, dtypes.string)
layer = categorical_crossing.CategoryCrossing(num_bins=2)
output_spec = layer.compute_output_signature(input_specs)
self.assertEqual(output_spec.shape.dims[0], input_shapes[0].dims[0])
self.assertEqual(output_spec.dtype, dtypes.int64)
def test_crossing_with_invalid_hash_key(self):
with self.assertRaises(ValueError):
_ = categorical_crossing.CategoryCrossing(hash_key=133)
@tf_test_util.run_v2_only
def test_config_with_custom_name(self):
layer = categorical_crossing.CategoryCrossing(num_bins=2, name='hashing')
layer = categorical_crossing.CategoryCrossing(depth=2, name='hashing')
config = layer.get_config()
layer_1 = categorical_crossing.CategoryCrossing.from_config(config)
self.assertEqual(layer_1.name, layer.name)

View File

@ -0,0 +1,222 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.categorical_crossing.CategoryCrossing\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "partial_crossing"
argspec: "args=[\'self\', \'partial_inputs\', \'ragged_out\', \'sparse_out\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.keras.layers.experimental.preprocessing"
tf_module {
member {
name: "CategoryCrossing"
mtype: "<type \'type\'>"
}
member {
name: "CenterCrop"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,222 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.categorical_crossing.CategoryCrossing\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "partial_crossing"
argspec: "args=[\'self\', \'partial_inputs\', \'ragged_out\', \'sparse_out\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.keras.layers.experimental.preprocessing"
tf_module {
member {
name: "CategoryCrossing"
mtype: "<type \'type\'>"
}
member {
name: "CenterCrop"
mtype: "<type \'type\'>"