Exporting CategoryEncoding layer.

PiperOrigin-RevId: 312727421
Change-Id: I62552e9b85398a27c5f584b2ea265d915c9661bb
This commit is contained in:
Zhenyu Tan 2020-05-21 13:13:26 -07:00 committed by TensorFlower Gardener
parent 8fdb54ea98
commit 17895acf34
17 changed files with 677 additions and 178 deletions

View File

@ -44,6 +44,9 @@ from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Res
# Preprocessing layers. # Preprocessing layers.
if tf2.enabled(): if tf2.enabled():
from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding as CategoryEncodingV1
CategoryEncodingV2 = CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1 from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
NormalizationV2 = Normalization NormalizationV2 = Normalization
@ -51,6 +54,9 @@ if tf2.enabled():
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1 from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
TextVectorizationV2 = TextVectorization TextVectorizationV2 = TextVectorization
else: else:
from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding as CategoryEncodingV2
CategoryEncodingV1 = CategoryEncoding
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2 from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
NormalizationV1 = Normalization NormalizationV1 = Normalization

View File

@ -196,7 +196,7 @@ py_library(
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":categorical_encoding", ":category_encoding",
":string_lookup", ":string_lookup",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
@ -216,10 +216,10 @@ py_library(
) )
py_library( py_library(
name = "categorical_encoding", name = "category_encoding",
srcs = [ srcs = [
"categorical_encoding.py", "category_encoding.py",
"categorical_encoding_v1.py", "category_encoding_v1.py",
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
@ -308,12 +308,12 @@ cuda_py_test(
) )
tf_py_test( tf_py_test(
name = "categorical_encoding_test", name = "category_encoding_test",
size = "medium", size = "medium",
srcs = ["categorical_encoding_test.py"], srcs = ["category_encoding_test.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":categorical_encoding", ":category_encoding",
":preprocessing_test_utils", ":preprocessing_test_utils",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python/keras", "//tensorflow/python/keras",
@ -324,9 +324,9 @@ tf_py_test(
) )
distribute_py_test( distribute_py_test(
name = "categorical_encoding_distribution_test", name = "category_encoding_distribution_test",
srcs = ["categorical_encoding_distribution_test.py"], srcs = ["category_encoding_distribution_test.py"],
main = "categorical_encoding_distribution_test.py", main = "category_encoding_distribution_test.py",
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"multi_and_single_gpu", "multi_and_single_gpu",
@ -335,7 +335,7 @@ distribute_py_test(
"no_oss", # b/155502591 "no_oss", # b/155502591
], ],
deps = [ deps = [
":categorical_encoding", ":category_encoding",
"//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras", "//tensorflow/python/keras",

View File

@ -11,12 +11,12 @@ package(
exports_files(["LICENSE"]) exports_files(["LICENSE"])
tf_py_test( tf_py_test(
name = "categorical_encoding_benchmark", name = "category_encoding_benchmark",
srcs = ["categorical_encoding_benchmark.py"], srcs = ["category_encoding_benchmark.py"],
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
"//tensorflow:tensorflow_py", "//tensorflow:tensorflow_py",
"//tensorflow/python/keras/layers/preprocessing:categorical_encoding", "//tensorflow/python/keras/layers/preprocessing:category_encoding",
], ],
) )

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Benchmark for Keras categorical_encoding preprocessing layer.""" """Benchmark for Keras category_encoding preprocessing layer."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -26,7 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import categorical_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.platform import benchmark from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -42,7 +42,7 @@ class BenchmarkLayer(benchmark.Benchmark):
def run_dataset_implementation(self, output_mode, batch_size, sequence_length, def run_dataset_implementation(self, output_mode, batch_size, sequence_length,
max_tokens): max_tokens):
input_t = keras.Input(shape=(sequence_length,), dtype=dtypes.int32) input_t = keras.Input(shape=(sequence_length,), dtype=dtypes.int32)
layer = categorical_encoding.CategoricalEncoding( layer = category_encoding.CategoryEncoding(
max_tokens=max_tokens, output_mode=output_mode) max_tokens=max_tokens, output_mode=output_mode)
_ = layer(input_t) _ = layer(input_t)
@ -68,7 +68,7 @@ class BenchmarkLayer(benchmark.Benchmark):
ends.append(time.time()) ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches
name = "categorical_encoding|batch_%s|seq_length_%s|%s_max_tokens" % ( name = "category_encoding|batch_%s|seq_length_%s|%s_max_tokens" % (
batch_size, sequence_length, max_tokens) batch_size, sequence_length, max_tokens)
self.report_benchmark(iters=num_repeats, wall_time=avg_time, name=name) self.report_benchmark(iters=num_repeats, wall_time=avg_time, name=name)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras text CategoricalEncoding preprocessing layer.""" """Keras text CategoryEncoding preprocessing layer."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -32,11 +32,13 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export
TFIDF = "tf-idf" TFIDF = "tf-idf"
INT = "int" INT = "int"
@ -49,14 +51,26 @@ _NUM_ELEMENTS_NAME = "num_elements"
_IDF_NAME = "idf" _IDF_NAME = "idf"
class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): @keras_export("keras.layers.experimental.preprocessing.CategoryEncoding", v1=[])
"""Categorical encoding layer. class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
"""Category encoding layer.
This layer provides options for condensing data into a categorical encoding. This layer provides options for condensing data into a categorical encoding.
It accepts integer values as inputs and outputs a dense representation It accepts integer values as inputs and outputs a dense representation
(one sample = 1-index tensor of float values representing data about the (one sample = 1-index tensor of float values representing data about the
sample's tokens) of those inputs. sample's tokens) of those inputs.
Examples:
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
... max_tokens=4)
>>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])
<tf.Tensor: shape=(4, 4), dtype=int64, numpy=
array([[1, 1, 0, 0],
[2, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 0, 1]])>
Attributes: Attributes:
max_tokens: The maximum size of the vocabulary for this layer. If None, max_tokens: The maximum size of the vocabulary for this layer. If None,
there is no cap on the size of the vocabulary. there is no cap on the size of the vocabulary.
@ -72,7 +86,6 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
sparse: Boolean. If true, returns a `SparseTensor` instead of a dense sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
`Tensor`. Defaults to `False`. `Tensor`. Defaults to `False`.
""" """
# TODO(momernick): Add an examples section to the docstring.
def __init__(self, def __init__(self,
max_tokens=None, max_tokens=None,
@ -83,7 +96,7 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
layer_utils.validate_string_arg( layer_utils.validate_string_arg(
output_mode, output_mode,
allowable_strings=(COUNT, BINARY, TFIDF), allowable_strings=(COUNT, BINARY, TFIDF),
layer_name="CategoricalEncoding", layer_name="CategoryEncoding",
arg_name="output_mode") arg_name="output_mode")
# If max_tokens is set, the value must be greater than 1 - otherwise we # If max_tokens is set, the value must be greater than 1 - otherwise we
@ -92,10 +105,10 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
raise ValueError("max_tokens must be > 1.") raise ValueError("max_tokens must be > 1.")
# We need to call super() before we call _add_state_variable(). # We need to call super() before we call _add_state_variable().
combiner = _CategoricalEncodingCombiner( combiner = _CategoryEncodingCombiner(
compute_max_element=max_tokens is None, compute_max_element=max_tokens is None,
compute_idf=output_mode == TFIDF) compute_idf=output_mode == TFIDF)
super(CategoricalEncoding, self).__init__(combiner=combiner, **kwargs) super(CategoryEncoding, self).__init__(combiner=combiner, **kwargs)
self._max_tokens = max_tokens self._max_tokens = max_tokens
self._output_mode = output_mode self._output_mode = output_mode
@ -158,13 +171,12 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
RuntimeError: if the layer cannot be adapted at this time. RuntimeError: if the layer cannot be adapted at this time.
""" """
if not reset_state: if not reset_state:
raise ValueError("CategoricalEncoding does not support streaming adapts.") raise ValueError("CategoryEncoding does not support streaming adapts.")
if self._called and self._max_tokens is None: if self._called and self._max_tokens is None:
raise RuntimeError( raise RuntimeError("CategoryEncoding can't be adapted after being called "
"CategoricalEncoding can't be adapted after being called " "if max_tokens is None.")
"if max_tokens is None.") super(CategoryEncoding, self).adapt(data, reset_state)
super(CategoricalEncoding, self).adapt(data, reset_state)
def _set_state_variables(self, updates): def _set_state_variables(self, updates):
if not self.built: if not self.built:
@ -180,7 +192,7 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
"output_mode": self._output_mode, "output_mode": self._output_mode,
"sparse": self._sparse, "sparse": self._sparse,
} }
base_config = super(CategoricalEncoding, self).get_config() base_config = super(CategoryEncoding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def _convert_to_ndarray(self, x): def _convert_to_ndarray(self, x):
@ -237,65 +249,40 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
else: else:
out_depth = self._max_tokens out_depth = self._max_tokens
if self._sparse:
if self._output_mode != COUNT:
raise ValueError("Only supports `sparse=True` when `output_mode` "
' is \"count\", got {}'.format(self._output_mode))
inputs = self._convert_to_sparse_inputs(inputs)
# Consider having sparse.one_hot
# Append values to indices, and reduce sum to get the counts.
tokens = array_ops.expand_dims(
math_ops.cast(inputs.values, dtypes.int64), axis=1)
count_tokens = array_ops.concat([inputs.indices, tokens], axis=1)
count_values = array_ops.ones_like(inputs.values, dtype=dtypes.int64)
unreduced_count_shape = array_ops.concat(
[inputs.dense_shape, [out_depth]], axis=0)
counts = sparse_tensor.SparseTensor(
indices=count_tokens,
values=count_values,
dense_shape=unreduced_count_shape)
count_data = sparse_ops.sparse_reduce_sum_v2(
counts, axis=1, output_is_sparse=True)
return count_data
# If the input is a sparse tensor, we densify it with the default value of
# -1. Because -1 is ignored by one_hot, this effectively drops the non-set
# positions from the output encoding.
if isinstance(inputs, sparse_tensor.SparseTensor):
inputs = sparse_ops.sparse_tensor_to_dense(inputs, default_value=-1)
if self._output_mode == BINARY:
bool_one_hot_data = array_ops.one_hot(
inputs, depth=out_depth, on_value=True, off_value=False)
reduced_bool_data = math_ops.reduce_any(bool_one_hot_data, axis=1)
binary_data = math_ops.cast(reduced_bool_data, dtypes.int64)
binary_data.set_shape(tensor_shape.TensorShape((None, out_depth)))
return binary_data
one_hot_data = array_ops.one_hot(inputs, depth=out_depth)
counts = math_ops.reduce_sum(one_hot_data, axis=1)
if self._output_mode == COUNT:
count_data = math_ops.cast(counts, dtypes.int64)
count_data.set_shape(tensor_shape.TensorShape((None, out_depth)))
return count_data
tf_idf_data = math_ops.multiply(counts, self.tf_idf_weights)
tf_idf_data.set_shape(tensor_shape.TensorShape((None, out_depth)))
if self._output_mode == TFIDF: if self._output_mode == TFIDF:
# If the input is a sparse tensor, we densify it with the default value of
# -1. Because -1 is ignored by one_hot, this effectively drops the non-set
# positions from the output encoding.
if isinstance(inputs, sparse_tensor.SparseTensor):
inputs = sparse_ops.sparse_tensor_to_dense(inputs, default_value=-1)
one_hot_data = array_ops.one_hot(inputs, depth=out_depth)
counts = math_ops.reduce_sum(one_hot_data, axis=1)
tf_idf_data = math_ops.multiply(counts, self.tf_idf_weights)
tf_idf_data.set_shape(tensor_shape.TensorShape((None, out_depth)))
return tf_idf_data return tf_idf_data
# We can only get here if we didn't recognize the passed mode. binary_output = (self._output_mode == BINARY)
raise ValueError("Unknown output mode %s" % self._output_mode) if self._sparse:
return bincount_ops.sparse_bincount(
inputs, minlength=out_depth, axis=-1, binary_output=binary_output)
else:
result = bincount_ops.bincount(
inputs,
minlength=out_depth,
dtype=dtypes.int64,
axis=-1,
binary_output=binary_output)
result.set_shape(tensor_shape.TensorShape((None, out_depth)))
return result
class _CategoricalEncodingAccumulator( class _CategoryEncodingAccumulator(
collections.namedtuple("Accumulator", ["data", "per_doc_count_dict"])): collections.namedtuple("Accumulator", ["data", "per_doc_count_dict"])):
pass pass
class _CategoricalEncodingCombiner(base_preprocessing_layer.Combiner): class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner):
"""Combiner for the CategoricalEncoding preprocessing layer. """Combiner for the CategoryEncoding preprocessing layer.
This class encapsulates the logic for computing the number of elements in the This class encapsulates the logic for computing the number of elements in the
input dataset and the document frequency for each element. input dataset and the document frequency for each element.
@ -411,7 +398,7 @@ class _CategoricalEncodingCombiner(base_preprocessing_layer.Combiner):
def restore(self, output): def restore(self, output):
"""Creates an accumulator based on 'output'.""" """Creates an accumulator based on 'output'."""
raise NotImplementedError( raise NotImplementedError(
"CategoricalEncoding does not restore or support streaming updates.") "CategoryEncoding does not restore or support streaming updates.")
def serialize(self, accumulator): def serialize(self, accumulator):
"""Serializes an accumulator for a remote call.""" """Serializes an accumulator for a remote call."""
@ -452,4 +439,4 @@ class _CategoricalEncodingCombiner(base_preprocessing_layer.Combiner):
else: else:
per_doc_count_dict = None per_doc_count_dict = None
data = [0, 0] data = [0, 0]
return _CategoricalEncodingAccumulator(data, per_doc_count_dict) return _CategoryEncodingAccumulator(data, per_doc_count_dict)

View File

@ -21,39 +21,58 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers.preprocessing import categorical_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test from tensorflow.python.platform import test
def batch_wrapper(dataset, batch_size, distribution, repeat=None):
if repeat:
dataset = dataset.repeat(repeat)
# TPUs currently require fully defined input shapes, drop_remainder ensures
# the input will have fully defined shapes.
if isinstance(distribution,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
return dataset.batch(batch_size, drop_remainder=True)
else:
return dataset.batch(batch_size)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, # (b/156783625): Outside compilation failed for eager mode only.
distribution=strategy_combinations.strategies_minus_tpu,
mode=["eager", "graph"])) mode=["eager", "graph"]))
class CategoricalEncodingDistributionTest( class CategoryEncodingDistributionTest(
keras_parameterized.TestCase, keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest): preprocessing_test_utils.PreprocessingLayerTest):
def test_distribution(self, distribution): def test_distribution(self, distribution):
input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
inp_dataset = dataset_ops.DatasetV2.from_tensor_slices(input_array)
inp_dataset = batch_wrapper(inp_dataset, 2, distribution)
# pyformat: disable # pyformat: disable
expected_output = [[0, 1, 1, 1, 0, 0], expected_output = [[0, 1, 1, 1, 0, 0],
[1, 1, 0, 1, 0, 0]] [1, 1, 0, 1, 0, 0]]
# pyformat: enable # pyformat: enable
max_tokens = 6 max_tokens = 6
config.set_soft_device_placement(True)
with distribution.scope(): with distribution.scope():
input_data = keras.Input(shape=(4,), dtype=dtypes.int32) input_data = keras.Input(shape=(4,), dtype=dtypes.int32)
layer = categorical_encoding.CategoricalEncoding( layer = category_encoding.CategoryEncoding(
max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) max_tokens=max_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
output_dataset = model.predict(input_array) output_dataset = model.predict(inp_dataset)
self.assertAllEqual(expected_output, output_dataset) self.assertAllEqual(expected_output, output_dataset)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for Keras text categorical_encoding preprocessing layer.""" """Tests for Keras text category_encoding preprocessing layer."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -32,8 +32,8 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers.preprocessing import categorical_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import categorical_encoding_v1 from tensorflow.python.keras.layers.preprocessing import category_encoding_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
@ -44,15 +44,15 @@ from tensorflow.python.platform import test
def get_layer_class(): def get_layer_class():
if context.executing_eagerly(): if context.executing_eagerly():
return categorical_encoding.CategoricalEncoding return category_encoding.CategoryEncoding
else: else:
return categorical_encoding_v1.CategoricalEncoding return category_encoding_v1.CategoryEncoding
@keras_parameterized.run_all_keras_modes(always_skip_v1=True) @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class CategoricalEncodingInputTest( class CategoryEncodingInputTest(keras_parameterized.TestCase,
keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest
preprocessing_test_utils.PreprocessingLayerTest): ):
def test_dense_input_sparse_output(self): def test_dense_input_sparse_output(self):
input_array = constant_op.constant([[1, 2, 3], [3, 3, 0]]) input_array = constant_op.constant([[1, 2, 3], [3, 3, 0]])
@ -67,9 +67,7 @@ class CategoricalEncodingInputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
output_mode=categorical_encoding.COUNT,
sparse=True)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
@ -80,7 +78,7 @@ class CategoricalEncodingInputTest(
# Assert sparse output is same as dense output. # Assert sparse output is same as dense output.
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens,
output_mode=categorical_encoding.COUNT, output_mode=category_encoding.COUNT,
sparse=False) sparse=False)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
@ -103,7 +101,7 @@ class CategoricalEncodingInputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True) input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) max_tokens=max_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -128,9 +126,7 @@ class CategoricalEncodingInputTest(
max_tokens = 6 max_tokens = 6
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
output_mode=categorical_encoding.COUNT,
sparse=True)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
@ -141,7 +137,7 @@ class CategoricalEncodingInputTest(
# Assert sparse output is same as dense output. # Assert sparse output is same as dense output.
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens,
output_mode=categorical_encoding.COUNT, output_mode=category_encoding.COUNT,
sparse=False) sparse=False)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
@ -163,7 +159,7 @@ class CategoricalEncodingInputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) max_tokens=max_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -184,9 +180,7 @@ class CategoricalEncodingInputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
output_mode=categorical_encoding.COUNT,
sparse=True)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
@ -197,7 +191,7 @@ class CategoricalEncodingInputTest(
# Assert sparse output is same as dense output. # Assert sparse output is same as dense output.
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens,
output_mode=categorical_encoding.COUNT, output_mode=category_encoding.COUNT,
sparse=False) sparse=False)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
@ -214,9 +208,7 @@ class CategoricalEncodingInputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
encoding_layer = get_layer_class()( encoding_layer = get_layer_class()(
max_tokens=max_tokens, max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True)
output_mode=categorical_encoding.COUNT,
sparse=True)
int_data = encoding_layer(input_data) int_data = encoding_layer(input_data)
output_data = math_ops.cast(int_data, dtypes.float32) output_data = math_ops.cast(int_data, dtypes.float32)
weights = variables.Variable([[.1], [.2], [.3], [.4]], dtype=dtypes.float32) weights = variables.Variable([[.1], [.2], [.3], [.4]], dtype=dtypes.float32)
@ -228,9 +220,9 @@ class CategoricalEncodingInputTest(
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class CategoricalEncodingAdaptTest( class CategoryEncodingAdaptTest(keras_parameterized.TestCase,
keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest
preprocessing_test_utils.PreprocessingLayerTest): ):
def test_sparse_adapt(self): def test_sparse_adapt(self):
vocab_data = sparse_ops.from_dense( vocab_data = sparse_ops.from_dense(
@ -248,7 +240,7 @@ class CategoricalEncodingAdaptTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True) input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.adapt(vocab_dataset) layer.adapt(vocab_dataset)
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -273,7 +265,7 @@ class CategoricalEncodingAdaptTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.adapt(vocab_dataset) layer.adapt(vocab_dataset)
int_data = layer(input_data) int_data = layer(input_data)
@ -296,7 +288,7 @@ class CategoricalEncodingAdaptTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) max_tokens=max_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data) int_data = layer(input_data)
layer.adapt(vocab_data) layer.adapt(vocab_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -306,7 +298,7 @@ class CategoricalEncodingAdaptTest(
self.assertAllEqual(expected_output, output_dataset) self.assertAllEqual(expected_output, output_dataset)
def test_hard_maximum_set_state_variables_after_build(self): def test_hard_maximum_set_state_variables_after_build(self):
state_variables = {categorical_encoding._NUM_ELEMENTS_NAME: 5} state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5}
input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
# pyformat: disable # pyformat: disable
@ -318,7 +310,7 @@ class CategoricalEncodingAdaptTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) max_tokens=max_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data) int_data = layer(input_data)
layer._set_state_variables(state_variables) layer._set_state_variables(state_variables)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -339,7 +331,7 @@ class CategoricalEncodingAdaptTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.build(input_data.shape) layer.build(input_data.shape)
layer.set_num_elements(max_tokens) layer.set_num_elements(max_tokens)
int_data = layer(input_data) int_data = layer(input_data)
@ -351,8 +343,7 @@ class CategoricalEncodingAdaptTest(
def test_set_weights_fails_on_wrong_size_weights(self): def test_set_weights_fails_on_wrong_size_weights(self):
tfidf_data = [.05, .5, .25, .2, .125] tfidf_data = [.05, .5, .25, .2, .125]
layer = get_layer_class()( layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.TFIDF)
max_tokens=6, output_mode=categorical_encoding.TFIDF)
with self.assertRaisesRegex(ValueError, ".*Layer weight shape.*"): with self.assertRaisesRegex(ValueError, ".*Layer weight shape.*"):
layer.set_weights([np.array(tfidf_data)]) layer.set_weights([np.array(tfidf_data)])
@ -360,7 +351,7 @@ class CategoricalEncodingAdaptTest(
def test_set_num_elements_after_call_fails(self): def test_set_num_elements_after_call_fails(self):
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
layer.set_num_elements(5) layer.set_num_elements(5)
@ -370,17 +361,17 @@ class CategoricalEncodingAdaptTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "can't be adapted"): with self.assertRaisesRegex(RuntimeError, "can't be adapted"):
layer.adapt(vocab_data) layer.adapt(vocab_data)
def test_set_state_variables_after_call_fails(self): def test_set_state_variables_after_call_fails(self):
state_variables = {categorical_encoding._NUM_ELEMENTS_NAME: 5} state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5}
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
_ = layer(input_data) _ = layer(input_data)
with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"):
layer._set_state_variables(state_variables) layer._set_state_variables(state_variables)
@ -388,9 +379,9 @@ class CategoricalEncodingAdaptTest(
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class CategoricalEncodingOutputTest( class CategoryEncodingOutputTest(keras_parameterized.TestCase,
keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest
preprocessing_test_utils.PreprocessingLayerTest): ):
def test_binary_output_hard_maximum(self): def test_binary_output_hard_maximum(self):
input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
@ -404,7 +395,7 @@ class CategoricalEncodingOutputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) max_tokens=max_tokens, output_mode=category_encoding.BINARY)
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -424,7 +415,7 @@ class CategoricalEncodingOutputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.BINARY) max_tokens=None, output_mode=category_encoding.BINARY)
layer.set_weights([np.array(max_tokens)]) layer.set_weights([np.array(max_tokens)])
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -444,8 +435,7 @@ class CategoricalEncodingOutputTest(
expected_output_shape = [None, max_tokens] expected_output_shape = [None, max_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.COUNT)
max_tokens=6, output_mode=categorical_encoding.COUNT)
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -465,7 +455,7 @@ class CategoricalEncodingOutputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.COUNT) max_tokens=None, output_mode=category_encoding.COUNT)
layer.set_weights([np.array(max_tokens)]) layer.set_weights([np.array(max_tokens)])
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -488,8 +478,7 @@ class CategoricalEncodingOutputTest(
expected_output_shape = [None, max_tokens] expected_output_shape = [None, max_tokens]
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.TFIDF)
max_tokens=6, output_mode=categorical_encoding.TFIDF)
layer.set_tfidf_data(tfidf_data) layer.set_tfidf_data(tfidf_data)
int_data = layer(input_data) int_data = layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
@ -513,7 +502,7 @@ class CategoricalEncodingOutputTest(
input_data = keras.Input(shape=(None,), dtype=dtypes.int32) input_data = keras.Input(shape=(None,), dtype=dtypes.int32)
layer = get_layer_class()( layer = get_layer_class()(
max_tokens=None, output_mode=categorical_encoding.TFIDF) max_tokens=None, output_mode=category_encoding.TFIDF)
layer.set_num_elements(max_tokens) layer.set_num_elements(max_tokens)
layer.set_tfidf_data(tfidf_data) layer.set_tfidf_data(tfidf_data)
int_data = layer(input_data) int_data = layer(input_data)
@ -524,7 +513,7 @@ class CategoricalEncodingOutputTest(
self.assertAllClose(expected_output, output_dataset) self.assertAllClose(expected_output, output_dataset)
class CategoricalEncodingModelBuildingTest( class CategoryEncodingModelBuildingTest(
keras_parameterized.TestCase, keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest): preprocessing_test_utils.PreprocessingLayerTest):
@ -532,27 +521,27 @@ class CategoricalEncodingModelBuildingTest(
{ {
"testcase_name": "count_hard_max", "testcase_name": "count_hard_max",
"max_tokens": 5, "max_tokens": 5,
"output_mode": categorical_encoding.COUNT "output_mode": category_encoding.COUNT
}, { }, {
"testcase_name": "count_soft_max", "testcase_name": "count_soft_max",
"max_tokens": None, "max_tokens": None,
"output_mode": categorical_encoding.COUNT "output_mode": category_encoding.COUNT
}, { }, {
"testcase_name": "binary_hard_max", "testcase_name": "binary_hard_max",
"max_tokens": 5, "max_tokens": 5,
"output_mode": categorical_encoding.BINARY "output_mode": category_encoding.BINARY
}, { }, {
"testcase_name": "binary_soft_max", "testcase_name": "binary_soft_max",
"max_tokens": None, "max_tokens": None,
"output_mode": categorical_encoding.BINARY "output_mode": category_encoding.BINARY
}, { }, {
"testcase_name": "tfidf_hard_max", "testcase_name": "tfidf_hard_max",
"max_tokens": 5, "max_tokens": 5,
"output_mode": categorical_encoding.TFIDF "output_mode": category_encoding.TFIDF
}, { }, {
"testcase_name": "tfidf_soft_max", "testcase_name": "tfidf_soft_max",
"max_tokens": None, "max_tokens": None,
"output_mode": categorical_encoding.TFIDF "output_mode": category_encoding.TFIDF
}) })
def test_end_to_end_bagged_modeling(self, output_mode, max_tokens): def test_end_to_end_bagged_modeling(self, output_mode, max_tokens):
tfidf_data = np.array([.03, .5, .25, .2, .125]) tfidf_data = np.array([.03, .5, .25, .2, .125])
@ -564,7 +553,7 @@ class CategoricalEncodingModelBuildingTest(
weights = [] weights = []
if max_tokens is None: if max_tokens is None:
weights.append(np.array(5)) weights.append(np.array(5))
if output_mode == categorical_encoding.TFIDF: if output_mode == category_encoding.TFIDF:
weights.append(tfidf_data) weights.append(tfidf_data)
layer.set_weights(weights) layer.set_weights(weights)
@ -577,7 +566,7 @@ class CategoricalEncodingModelBuildingTest(
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class CategoricalEncodingCombinerTest( class CategoryEncodingCombinerTest(
keras_parameterized.TestCase, keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest): preprocessing_test_utils.PreprocessingLayerTest):
@ -617,8 +606,7 @@ class CategoricalEncodingCombinerTest(
def test_combiner_api_compatibility_int_mode(self): def test_combiner_api_compatibility_int_mode(self):
data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]]) data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
combiner = categorical_encoding._CategoricalEncodingCombiner( combiner = category_encoding._CategoryEncodingCombiner(compute_idf=False)
compute_idf=False)
expected_accumulator_output = { expected_accumulator_output = {
"max_element": np.array(4), "max_element": np.array(4),
"num_documents": np.array(2), "num_documents": np.array(2),
@ -636,8 +624,7 @@ class CategoricalEncodingCombinerTest(
def test_combiner_api_compatibility_tfidf_mode(self): def test_combiner_api_compatibility_tfidf_mode(self):
data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]]) data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
combiner = categorical_encoding._CategoricalEncodingCombiner( combiner = category_encoding._CategoryEncodingCombiner(compute_idf=True)
compute_idf=True)
expected_accumulator_output = { expected_accumulator_output = {
"max_element": np.array(4), "max_element": np.array(4),
"document_counts": np.array([1, 2, 2, 2, 1]), "document_counts": np.array([1, 2, 2, 2, 1]),
@ -693,7 +680,7 @@ class CategoricalEncodingCombinerTest(
expected_accumulator_output, expected_accumulator_output,
expected_extract_output, expected_extract_output,
compute_idf=True): compute_idf=True):
combiner = categorical_encoding._CategoricalEncodingCombiner( combiner = category_encoding._CategoryEncodingCombiner(
compute_idf=compute_idf) compute_idf=compute_idf)
expected_accumulator = combiner._create_accumulator() expected_accumulator = combiner._create_accumulator()
expected_accumulator = self.update_accumulator(expected_accumulator, expected_accumulator = self.update_accumulator(expected_accumulator,
@ -702,6 +689,5 @@ class CategoricalEncodingCombinerTest(
self.validate_accumulator_extract(combiner, data, expected_extract_output) self.validate_accumulator_extract(combiner, data, expected_extract_output)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tensorflow V1 version of the text categorical_encoding preprocessing layer.""" """Tensorflow V1 version of the text category_encoding preprocessing layer."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.keras.engine import base_preprocessing_layer_v1 from tensorflow.python.keras.engine import base_preprocessing_layer_v1
from tensorflow.python.keras.layers.preprocessing import categorical_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.util.tf_export import keras_export
class CategoricalEncoding(categorical_encoding.CategoricalEncoding, @keras_export(v1=["keras.layers.experimental.preprocessing.CategoryEncoding"])
base_preprocessing_layer_v1.CombinerPreprocessingLayer class CategoryEncoding(category_encoding.CategoryEncoding,
): base_preprocessing_layer_v1.CombinerPreprocessingLayer):
"""CategoricalEncoding layer. """CategoryEncoding layer.
This layer provides options for condensing input data into denser This layer provides options for condensing input data into denser
representations. It accepts either integer values or strings as inputs, representations. It accepts either integer values or strings as inputs,

View File

@ -26,7 +26,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import categorical_encoding from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import string_lookup from tensorflow.python.keras.layers.preprocessing import string_lookup
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -42,10 +42,10 @@ LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
SPLIT_ON_WHITESPACE = "whitespace" SPLIT_ON_WHITESPACE = "whitespace"
TFIDF = categorical_encoding.TFIDF TFIDF = category_encoding.TFIDF
INT = categorical_encoding.INT INT = category_encoding.INT
BINARY = categorical_encoding.BINARY BINARY = category_encoding.BINARY
COUNT = categorical_encoding.COUNT COUNT = category_encoding.COUNT
# This is an explicit regex of all the tokens that will be stripped if # This is an explicit regex of all the tokens that will be stripped if
# LOWER_AND_STRIP_PUNCTUATION is set. If an application requires other # LOWER_AND_STRIP_PUNCTUATION is set. If an application requires other
@ -307,7 +307,7 @@ class TextVectorization(CombinerPreprocessingLayer):
# These are V1/V2 shim points. There are V1 implementations in the V1 class. # These are V1/V2 shim points. There are V1 implementations in the V1 class.
def _get_vectorization_class(self): def _get_vectorization_class(self):
return categorical_encoding.CategoricalEncoding return category_encoding.CategoryEncoding
def _get_index_lookup_class(self): def _get_index_lookup_class(self):
return string_lookup.StringLookup return string_lookup.StringLookup

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.keras.engine import base_preprocessing_layer_v1 from tensorflow.python.keras.engine import base_preprocessing_layer_v1
from tensorflow.python.keras.layers.preprocessing import categorical_encoding_v1 from tensorflow.python.keras.layers.preprocessing import category_encoding_v1
from tensorflow.python.keras.layers.preprocessing import string_lookup_v1 from tensorflow.python.keras.layers.preprocessing import string_lookup_v1
from tensorflow.python.keras.layers.preprocessing import text_vectorization from tensorflow.python.keras.layers.preprocessing import text_vectorization
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -77,7 +77,7 @@ class TextVectorization(text_vectorization.TextVectorization,
""" """
def _get_vectorization_class(self): def _get_vectorization_class(self):
return categorical_encoding_v1.CategoricalEncoding return category_encoding_v1.CategoryEncoding
def _get_index_lookup_class(self): def _get_index_lookup_class(self):
return string_lookup_v1.StringLookup return string_lookup_v1.StringLookup

View File

@ -46,6 +46,8 @@ from tensorflow.python.keras.layers import recurrent_v2
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
from tensorflow.python.keras.layers import wrappers from tensorflow.python.keras.layers import wrappers
from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import category_encoding_v1
from tensorflow.python.keras.layers.preprocessing import hashing from tensorflow.python.keras.layers.preprocessing import hashing
from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
@ -61,15 +63,11 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
convolutional_recurrent, core, cudnn_recurrent, dense_attention, convolutional_recurrent, core, cudnn_recurrent, dense_attention,
embeddings, einsum_dense, local, merge, noise, normalization, embeddings, einsum_dense, local, merge, noise, normalization,
pooling, image_preprocessing, preprocessing_normalization_v1, pooling, image_preprocessing, preprocessing_normalization_v1,
preprocessing_text_vectorization_v1, preprocessing_text_vectorization_v1, recurrent, wrappers,
recurrent, wrappers, hashing, category_crossing) hashing, category_crossing, category_encoding_v1)
ALL_V2_MODULES = ( ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2,
rnn_cell_wrapper_v2, preprocessing_normalization, preprocessing_text_vectorization,
normalization_v2, category_encoding)
recurrent_v2,
preprocessing_normalization,
preprocessing_text_vectorization
)
# ALL_OBJECTS is meant to be a global mutable. Hence we need to make it # ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations. # thread-local to avoid concurrent mutations.
LOCAL = threading.local() LOCAL = threading.local()

View File

@ -0,0 +1,14 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding.__metaclass__"
tf_class {
is_instance: "<type \'type\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,234 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding_v1.CategoryEncoding\'>"
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding.CategoryEncoding\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer_v1.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'max_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'count\', \'False\'], "
}
member_method {
name: "adapt"
argspec: "args=[\'self\', \'data\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_num_elements"
argspec: "args=[\'self\', \'num_elements\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_tfidf_data"
argspec: "args=[\'self\', \'tfidf_data\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "CategoryCrossing" name: "CategoryCrossing"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "CategoryEncoding"
mtype: "<type \'type\'>"
}
member { member {
name: "CenterCrop" name: "CenterCrop"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"

View File

@ -0,0 +1,14 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding.__metaclass__"
tf_class {
is_instance: "<type \'type\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,232 @@
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding.CategoryEncoding\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'max_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'count\', \'False\'], "
}
member_method {
name: "adapt"
argspec: "args=[\'self\', \'data\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_num_elements"
argspec: "args=[\'self\', \'num_elements\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_tfidf_data"
argspec: "args=[\'self\', \'tfidf_data\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -4,6 +4,10 @@ tf_module {
name: "CategoryCrossing" name: "CategoryCrossing"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "CategoryEncoding"
mtype: "<type \'type\'>"
}
member { member {
name: "CenterCrop" name: "CenterCrop"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"