Export Hashing layer. Add separator
for CategoryCrossing and tf.sparse.cross. Add benchmarks for hashing.
PiperOrigin-RevId: 312579726 Change-Id: I0dc5bac26413ec114c57bd59e6810d6c641f600d
This commit is contained in:
parent
ae14cc6b1b
commit
26b2581519
@ -57,7 +57,8 @@ else:
|
||||
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
|
||||
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
|
||||
TextVectorizationV1 = TextVectorization
|
||||
from tensorflow.python.keras.layers.preprocessing.categorical_crossing import CategoryCrossing
|
||||
from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing
|
||||
from tensorflow.python.keras.layers.preprocessing.hashing import Hashing
|
||||
|
||||
# Advanced activations.
|
||||
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
|
||||
|
@ -25,7 +25,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":categorical_crossing",
|
||||
":category_crossing",
|
||||
":discretization",
|
||||
":hashing",
|
||||
":image_preprocessing",
|
||||
@ -52,9 +52,9 @@ py_library(
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "categorical_crossing",
|
||||
name = "category_crossing",
|
||||
srcs = [
|
||||
"categorical_crossing.py",
|
||||
"category_crossing.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
@ -291,16 +291,16 @@ py_library(
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "categorical_crossing_test",
|
||||
name = "category_crossing_test",
|
||||
size = "medium",
|
||||
srcs = ["categorical_crossing_test.py"],
|
||||
srcs = ["category_crossing_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
tags = [
|
||||
"no_windows", # b/149031156
|
||||
],
|
||||
deps = [
|
||||
":categorical_crossing",
|
||||
":category_crossing",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -343,9 +343,9 @@ distribute_py_test(
|
||||
)
|
||||
|
||||
distribute_py_test(
|
||||
name = "categorical_crossing_distribution_test",
|
||||
srcs = ["categorical_crossing_distribution_test.py"],
|
||||
main = "categorical_crossing_distribution_test.py",
|
||||
name = "category_crossing_distribution_test",
|
||||
srcs = ["category_crossing_distribution_test.py"],
|
||||
main = "category_crossing_distribution_test.py",
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
@ -354,7 +354,7 @@ distribute_py_test(
|
||||
"no_oss", # b/155502591
|
||||
],
|
||||
deps = [
|
||||
":categorical_crossing",
|
||||
":category_crossing",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
"//tensorflow/python/keras",
|
||||
|
@ -21,12 +21,22 @@ tf_py_test(
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "categorical_crossing_benchmark",
|
||||
srcs = ["categorical_crossing_benchmark.py"],
|
||||
name = "category_crossing_benchmark",
|
||||
srcs = ["category_crossing_benchmark.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/keras/layers/preprocessing:categorical_crossing",
|
||||
"//tensorflow/python/keras/layers/preprocessing:category_crossing",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "hashing_benchmark",
|
||||
srcs = ["hashing_benchmark.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/keras/layers/preprocessing:hashing",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -28,7 +28,7 @@ from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras.layers.preprocessing import categorical_crossing
|
||||
from tensorflow.python.keras.layers.preprocessing import category_crossing
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import test
|
||||
@ -74,7 +74,7 @@ class BenchmarkLayer(benchmark.Benchmark):
|
||||
def bm_layer_implementation(self, batch_size):
|
||||
input_1 = keras.Input(shape=(1,), dtype=dtypes.int64, name="word")
|
||||
input_2 = keras.Input(shape=(1,), dtype=dtypes.int64, name="int")
|
||||
layer = categorical_crossing.CategoryCrossing()
|
||||
layer = category_crossing.CategoryCrossing()
|
||||
_ = layer([input_1, input_2])
|
||||
|
||||
num_repeats = 5
|
||||
@ -97,7 +97,7 @@ class BenchmarkLayer(benchmark.Benchmark):
|
||||
ends.append(time.time())
|
||||
|
||||
avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches
|
||||
name = "categorical_crossing|batch_%s" % batch_size
|
||||
name = "category_crossing|batch_%s" % batch_size
|
||||
baseline = self.run_dataset_implementation(batch_size)
|
||||
extras = {
|
||||
"dataset implementation baseline": baseline,
|
@ -0,0 +1,115 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Benchmark for Keras hashing preprocessing layer."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras.layers.preprocessing import hashing
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
||||
|
||||
# word_gen creates random sequences of ASCII letters (both lowercase and upper).
|
||||
# The number of unique strings is ~2,700.
|
||||
def word_gen():
|
||||
for _ in itertools.count(1):
|
||||
yield "".join(random.choice(string.ascii_letters) for i in range(2))
|
||||
|
||||
|
||||
class BenchmarkLayer(benchmark.Benchmark):
|
||||
"""Benchmark the layer forward pass."""
|
||||
|
||||
def run_dataset_implementation(self, batch_size):
|
||||
num_repeats = 5
|
||||
starts = []
|
||||
ends = []
|
||||
for _ in range(num_repeats):
|
||||
ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string,
|
||||
tensor_shape.TensorShape([]))
|
||||
ds = ds.shuffle(batch_size * 100)
|
||||
ds = ds.batch(batch_size)
|
||||
num_batches = 5
|
||||
ds = ds.take(num_batches)
|
||||
ds = ds.prefetch(num_batches)
|
||||
starts.append(time.time())
|
||||
# Benchmarked code begins here.
|
||||
for i in ds:
|
||||
_ = string_ops.string_to_hash_bucket(i, num_buckets=2)
|
||||
# Benchmarked code ends here.
|
||||
ends.append(time.time())
|
||||
|
||||
avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches
|
||||
return avg_time
|
||||
|
||||
def bm_layer_implementation(self, batch_size):
|
||||
input_1 = keras.Input(shape=(None,), dtype=dtypes.string, name="word")
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
_ = layer(input_1)
|
||||
|
||||
num_repeats = 5
|
||||
starts = []
|
||||
ends = []
|
||||
for _ in range(num_repeats):
|
||||
ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string,
|
||||
tensor_shape.TensorShape([]))
|
||||
ds = ds.shuffle(batch_size * 100)
|
||||
ds = ds.batch(batch_size)
|
||||
num_batches = 5
|
||||
ds = ds.take(num_batches)
|
||||
ds = ds.prefetch(num_batches)
|
||||
starts.append(time.time())
|
||||
# Benchmarked code begins here.
|
||||
for i in ds:
|
||||
_ = layer(i)
|
||||
# Benchmarked code ends here.
|
||||
ends.append(time.time())
|
||||
|
||||
avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches
|
||||
name = "hashing|batch_%s" % batch_size
|
||||
baseline = self.run_dataset_implementation(batch_size)
|
||||
extras = {
|
||||
"dataset implementation baseline": baseline,
|
||||
"delta seconds": (baseline - avg_time),
|
||||
"delta percent": ((baseline - avg_time) / baseline) * 100
|
||||
}
|
||||
self.report_benchmark(
|
||||
iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
|
||||
|
||||
def benchmark_vocab_size_by_batch(self):
|
||||
for batch in [32, 64, 256]:
|
||||
self.bm_layer_implementation(batch_size=batch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -49,6 +49,17 @@ class CategoryCrossing(Layer):
|
||||
[b'b_X_e'],
|
||||
[b'c_X_f']], dtype=object)>
|
||||
|
||||
|
||||
>>> inp_1 = tf.constant([['a'], ['b'], ['c']])
|
||||
>>> inp_2 = tf.constant([['d'], ['e'], ['f']])
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
|
||||
... separator='-')
|
||||
>>> layer([inp_1, inp_2])
|
||||
<tf.Tensor: shape=(3, 1), dtype=string, numpy=
|
||||
array([[b'a-d'],
|
||||
[b'b-e'],
|
||||
[b'c-f']], dtype=object)>
|
||||
|
||||
Arguments:
|
||||
depth: depth of input crossing. By default None, all inputs are crossed into
|
||||
one output. It can also be an int or tuple/list of ints. Passing an
|
||||
@ -59,6 +70,8 @@ class CategoryCrossing(Layer):
|
||||
equal to N1 or N2. Passing `None` means a single crossed output with all
|
||||
inputs. For example, with inputs `a`, `b` and `c`, `depth=2` means the
|
||||
output will be [a;b;c;cross(a, b);cross(bc);cross(ca)].
|
||||
separator: A string added between each input being joined. Defaults to
|
||||
'_X_'.
|
||||
name: Name to give to the layer.
|
||||
**kwargs: Keyword arguments to construct a layer.
|
||||
|
||||
@ -98,13 +111,12 @@ class CategoryCrossing(Layer):
|
||||
`[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
depth=None,
|
||||
name=None,
|
||||
**kwargs):
|
||||
# TODO(tanzheny): Consider making seperator configurable.
|
||||
def __init__(self, depth=None, name=None, separator=None, **kwargs):
|
||||
super(CategoryCrossing, self).__init__(name=name, **kwargs)
|
||||
self.depth = depth
|
||||
if separator is None:
|
||||
separator = '_X_'
|
||||
self.separator = separator
|
||||
if isinstance(depth, (tuple, list)):
|
||||
self._depth_tuple = depth
|
||||
elif depth is not None:
|
||||
@ -114,12 +126,16 @@ class CategoryCrossing(Layer):
|
||||
"""Gets the crossed output from a partial list/tuple of inputs."""
|
||||
# If ragged_out=True, convert output from sparse to ragged.
|
||||
if ragged_out:
|
||||
# TODO(momernick): Support separator with ragged_cross.
|
||||
if self.separator != '_X_':
|
||||
raise ValueError('Non-default separator with ragged input is not '
|
||||
'supported yet, given {}'.format(self.separator))
|
||||
return ragged_array_ops.cross(partial_inputs)
|
||||
elif sparse_out:
|
||||
return sparse_ops.sparse_cross(partial_inputs)
|
||||
return sparse_ops.sparse_cross(partial_inputs, separator=self.separator)
|
||||
else:
|
||||
return sparse_ops.sparse_tensor_to_dense(
|
||||
sparse_ops.sparse_cross(partial_inputs))
|
||||
sparse_ops.sparse_cross(partial_inputs, separator=self.separator))
|
||||
|
||||
def call(self, inputs):
|
||||
depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
|
||||
@ -178,6 +194,7 @@ class CategoryCrossing(Layer):
|
||||
def get_config(self):
|
||||
config = {
|
||||
'depth': self.depth,
|
||||
'separator': self.separator,
|
||||
}
|
||||
base_config = super(CategoryCrossing, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
@ -28,7 +28,7 @@ from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.layers.preprocessing import categorical_crossing
|
||||
from tensorflow.python.keras.layers.preprocessing import category_crossing
|
||||
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -72,7 +72,7 @@ class CategoryCrossingDistributionTest(
|
||||
input_data_2 = keras.Input(shape=(2,), dtype=dtypes.string,
|
||||
name='input_2')
|
||||
input_data = [input_data_1, input_data_2]
|
||||
layer = categorical_crossing.CategoryCrossing()
|
||||
layer = category_crossing.CategoryCrossing()
|
||||
int_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
output_dataset = model.predict(inp_dataset)
|
@ -29,7 +29,7 @@ from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers.preprocessing import categorical_crossing
|
||||
from tensorflow.python.keras.layers.preprocessing import category_crossing
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
@ -41,7 +41,7 @@ from tensorflow.python.platform import test
|
||||
class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_crossing_sparse_inputs(self):
|
||||
layer = categorical_crossing.CategoryCrossing()
|
||||
layer = category_crossing.CategoryCrossing()
|
||||
inputs_0 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
values=['a', 'b', 'c'],
|
||||
@ -52,8 +52,32 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
|
||||
self.assertAllEqual([b'a_X_d', b'b_X_e', b'c_X_e'], output.values)
|
||||
|
||||
def test_crossing_sparse_inputs_custom_sep(self):
|
||||
layer = category_crossing.CategoryCrossing(separator='_Y_')
|
||||
inputs_0 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
values=['a', 'b', 'c'],
|
||||
dense_shape=[2, 2])
|
||||
inputs_1 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
|
||||
self.assertAllEqual([b'a_Y_d', b'b_Y_e', b'c_Y_e'], output.values)
|
||||
|
||||
def test_crossing_sparse_inputs_empty_sep(self):
|
||||
layer = category_crossing.CategoryCrossing(separator='')
|
||||
inputs_0 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
values=['a', 'b', 'c'],
|
||||
dense_shape=[2, 2])
|
||||
inputs_1 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3])
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices)
|
||||
self.assertAllEqual([b'ad', b'be', b'ce'], output.values)
|
||||
|
||||
def test_crossing_sparse_inputs_depth_int(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=1)
|
||||
layer = category_crossing.CategoryCrossing(depth=1)
|
||||
inputs_0 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 0], [2, 0]],
|
||||
values=['a', 'b', 'c'],
|
||||
@ -69,7 +93,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
self.assertAllEqual(expected_out, output)
|
||||
|
||||
def test_crossing_sparse_inputs_depth_tuple(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=(2, 3))
|
||||
layer = category_crossing.CategoryCrossing(depth=(2, 3))
|
||||
inputs_0 = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 0], [2, 0]],
|
||||
values=['a', 'b', 'c'],
|
||||
@ -107,14 +131,14 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
|
||||
inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
|
||||
|
||||
non_hashed_layer = categorical_crossing.CategoryCrossing()
|
||||
non_hashed_layer = category_crossing.CategoryCrossing()
|
||||
out_t = non_hashed_layer([inp_0_t, inp_1_t])
|
||||
model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t)
|
||||
expected_output = [[b'omar_X_a', b'skywalker_X_a'], [b'marlo_X_b']]
|
||||
self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1]))
|
||||
|
||||
def test_crossing_ragged_inputs_depth_int(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=1)
|
||||
layer = category_crossing.CategoryCrossing(depth=1)
|
||||
inputs_0 = ragged_factory_ops.constant([['a'], ['b'], ['c']])
|
||||
inputs_1 = ragged_factory_ops.constant([['d'], ['e'], ['f']])
|
||||
output = layer([inputs_0, inputs_1])
|
||||
@ -122,7 +146,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
self.assertIsInstance(output, ragged_tensor.RaggedTensor)
|
||||
self.assertAllEqual(expected_output, output)
|
||||
|
||||
layer = categorical_crossing.CategoryCrossing(depth=2)
|
||||
layer = category_crossing.CategoryCrossing(depth=2)
|
||||
inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
|
||||
inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string)
|
||||
out_t = layer([inp_0_t, inp_1_t])
|
||||
@ -132,7 +156,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1]))
|
||||
|
||||
def test_crossing_ragged_inputs_depth_tuple(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=[2, 3])
|
||||
layer = category_crossing.CategoryCrossing(depth=[2, 3])
|
||||
inputs_0 = ragged_factory_ops.constant([['a'], ['b'], ['c']])
|
||||
inputs_1 = ragged_factory_ops.constant([['d'], ['e'], ['f']])
|
||||
inputs_2 = ragged_factory_ops.constant([['g'], ['h'], ['i']])
|
||||
@ -149,21 +173,21 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
self.assertAllEqual(expected_output, output)
|
||||
|
||||
def test_crossing_with_dense_inputs(self):
|
||||
layer = categorical_crossing.CategoryCrossing()
|
||||
layer = category_crossing.CategoryCrossing()
|
||||
inputs_0 = np.asarray([[1, 2]])
|
||||
inputs_1 = np.asarray([[1, 3]])
|
||||
output = layer([inputs_0, inputs_1])
|
||||
self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output)
|
||||
|
||||
def test_crossing_dense_inputs_depth_int(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=1)
|
||||
layer = category_crossing.CategoryCrossing(depth=1)
|
||||
inputs_0 = constant_op.constant([['a'], ['b'], ['c']])
|
||||
inputs_1 = constant_op.constant([['d'], ['e'], ['f']])
|
||||
output = layer([inputs_0, inputs_1])
|
||||
expected_output = [[b'a', b'd'], [b'b', b'e'], [b'c', b'f']]
|
||||
self.assertAllEqual(expected_output, output)
|
||||
|
||||
layer = categorical_crossing.CategoryCrossing(depth=2)
|
||||
layer = category_crossing.CategoryCrossing(depth=2)
|
||||
inp_0_t = input_layer.Input(shape=(1,), dtype=dtypes.string)
|
||||
inp_1_t = input_layer.Input(shape=(1,), dtype=dtypes.string)
|
||||
out_t = layer([inp_0_t, inp_1_t])
|
||||
@ -174,7 +198,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1]))
|
||||
|
||||
def test_crossing_dense_inputs_depth_tuple(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=[2, 3])
|
||||
layer = category_crossing.CategoryCrossing(depth=[2, 3])
|
||||
inputs_0 = constant_op.constant([['a'], ['b'], ['c']])
|
||||
inputs_1 = constant_op.constant([['d'], ['e'], ['f']])
|
||||
inputs_2 = constant_op.constant([['g'], ['h'], ['i']])
|
||||
@ -200,21 +224,21 @@ class CategoryCrossingTest(keras_parameterized.TestCase):
|
||||
tensor_spec.TensorSpec(input_shape, dtypes.string)
|
||||
for input_shape in input_shapes
|
||||
]
|
||||
layer = categorical_crossing.CategoryCrossing()
|
||||
layer = category_crossing.CategoryCrossing()
|
||||
output_spec = layer.compute_output_signature(input_specs)
|
||||
self.assertEqual(output_spec.shape.dims[0], input_shapes[0].dims[0])
|
||||
self.assertEqual(output_spec.dtype, dtypes.string)
|
||||
|
||||
@tf_test_util.run_v2_only
|
||||
def test_config_with_custom_name(self):
|
||||
layer = categorical_crossing.CategoryCrossing(depth=2, name='hashing')
|
||||
layer = category_crossing.CategoryCrossing(depth=2, name='hashing')
|
||||
config = layer.get_config()
|
||||
layer_1 = categorical_crossing.CategoryCrossing.from_config(config)
|
||||
layer_1 = category_crossing.CategoryCrossing.from_config(config)
|
||||
self.assertEqual(layer_1.name, layer.name)
|
||||
|
||||
layer = categorical_crossing.CategoryCrossing(name='hashing')
|
||||
layer = category_crossing.CategoryCrossing(name='hashing')
|
||||
config = layer.get_config()
|
||||
layer_1 = categorical_crossing.CategoryCrossing.from_config(config)
|
||||
layer_1 = category_crossing.CategoryCrossing.from_config(config)
|
||||
self.assertEqual(layer_1.name, layer.name)
|
||||
|
||||
|
@ -22,20 +22,28 @@ import functools
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.ops import gen_sparse_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# Default key from tf.sparse.cross_hashed
|
||||
_DEFAULT_SALT_KEY = [0xDECAFCAFFE, 0xDECAFCAFFE]
|
||||
|
||||
|
||||
@keras_export('keras.layers.experimental.preprocessing.Hashing')
|
||||
class Hashing(Layer):
|
||||
"""Implements categorical feature hashing, also known as "hashing trick".
|
||||
|
||||
This layer transforms categorical inputs to hashed output. It converts a
|
||||
sequence of int or string to a sequence of int. The stable hash function uses
|
||||
tensorflow::ops::Fingerprint to produce universal output that is consistent
|
||||
across platforms.
|
||||
This layer transforms single or multiple categorical inputs to hashed output.
|
||||
It converts a sequence of int or string to a sequence of int. The stable hash
|
||||
function uses tensorflow::ops::Fingerprint to produce universal output that
|
||||
is consistent across platforms.
|
||||
|
||||
This layer uses [FarmHash64](https://github.com/google/farmhash) by default,
|
||||
which provides a consistent hashed output across different platforms and is
|
||||
@ -48,50 +56,91 @@ class Hashing(Layer):
|
||||
the `salt` value serving as additional input to the hash function.
|
||||
|
||||
Example (FarmHash64):
|
||||
```python
|
||||
layer = Hashing(num_bins=3)
|
||||
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
layer(inputs)
|
||||
[[1], [0], [1], [1], [2]]
|
||||
```
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3)
|
||||
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[1],
|
||||
[0],
|
||||
[1],
|
||||
[1],
|
||||
[2]])>
|
||||
|
||||
|
||||
Example (SipHash64):
|
||||
```python
|
||||
layer = Hashing(num_bins=3, salt=[133, 137])
|
||||
inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
layer(inputs)
|
||||
[[1], [2], [1], [0], [2]]
|
||||
```
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
|
||||
... salt=[133, 137])
|
||||
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[1],
|
||||
[2],
|
||||
[1],
|
||||
[0],
|
||||
[2]])>
|
||||
|
||||
Example (Siphash64 with a single integer, same as `salt=[133, 133]`
|
||||
|
||||
>>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3,
|
||||
... salt=133)
|
||||
>>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
>>> layer(inp)
|
||||
<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
|
||||
array([[0],
|
||||
[0],
|
||||
[2],
|
||||
[1],
|
||||
[0]])>
|
||||
|
||||
Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf)
|
||||
|
||||
Arguments:
|
||||
num_bins: Number of hash bins.
|
||||
salt: A tuple/list of 2 unsigned integer numbers. If passed, the hash
|
||||
function used will be SipHash64, with these values used as an additional
|
||||
input (known as a "salt" in cryptography).
|
||||
salt: A single unsigned integer or None.
|
||||
If passed, the hash function used will be SipHash64, with these values
|
||||
used as an additional input (known as a "salt" in cryptography).
|
||||
These should be non-zero. Defaults to `None` (in that
|
||||
case, the FarmHash64 hash function is used).
|
||||
case, the FarmHash64 hash function is used). It also supports
|
||||
tuple/list of 2 unsigned integer numbers, see reference paper for details.
|
||||
name: Name to give to the layer.
|
||||
**kwargs: Keyword arguments to construct a layer.
|
||||
|
||||
Input shape: A string, int32 or int64 tensor of shape
|
||||
`[batch_size, d1, ..., dm]`
|
||||
Input shape: A single or list of string, int32 or int64 `Tensor`,
|
||||
`SparseTensor` or `RaggedTensor` of shape `[batch_size, ...,]`
|
||||
|
||||
Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]`
|
||||
Output shape: An int64 `Tensor`, `SparseTensor` or `RaggedTensor` of shape
|
||||
`[batch_size, ...]`. If any input is `RaggedTensor` then output is
|
||||
`RaggedTensor`, otherwise if any input is `SparseTensor` then output is
|
||||
`SparseTensor`, otherwise the output is `Tensor`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_bins, salt=None, name=None, **kwargs):
|
||||
if num_bins is None or num_bins <= 0:
|
||||
raise ValueError('`num_bins` cannot be `None` or non-positive values.')
|
||||
if salt is not None:
|
||||
if not isinstance(salt, (tuple, list)) or len(salt) != 2:
|
||||
raise ValueError('`salt` must be a tuple or list of 2 unsigned '
|
||||
'integer numbers, got {}'.format(salt))
|
||||
super(Hashing, self).__init__(name=name, **kwargs)
|
||||
self.num_bins = num_bins
|
||||
self.salt = salt
|
||||
self.strong_hash = True if salt is not None else False
|
||||
if salt is not None:
|
||||
if isinstance(salt, (tuple, list)) and len(salt) == 2:
|
||||
self.salt = salt
|
||||
elif isinstance(salt, int):
|
||||
self.salt = [salt, salt]
|
||||
else:
|
||||
raise ValueError('`salt can only be a tuple of size 2 integers, or a '
|
||||
'single integer, given {}'.format(salt))
|
||||
else:
|
||||
self.salt = _DEFAULT_SALT_KEY
|
||||
|
||||
def call(self, inputs):
|
||||
if isinstance(inputs, (tuple, list)):
|
||||
return self._process_input_list(inputs)
|
||||
else:
|
||||
return self._process_single_input(inputs)
|
||||
|
||||
def _process_single_input(self, inputs):
|
||||
# Converts integer inputs to string.
|
||||
if inputs.dtype.is_integer:
|
||||
if isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
@ -116,10 +165,38 @@ class Hashing(Layer):
|
||||
else:
|
||||
return str_to_hash_bucket(inputs, self.num_bins, name='hash')
|
||||
|
||||
def _process_input_list(self, inputs):
|
||||
# TODO(momernick): support ragged_cross_hashed with corrected fingerprint
|
||||
# and siphash.
|
||||
if any([isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs]):
|
||||
raise ValueError('Hashing with ragged input is not supported yet.')
|
||||
sparse_inputs = [
|
||||
inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor)
|
||||
]
|
||||
dense_inputs = [
|
||||
inp for inp in inputs if not isinstance(inp, sparse_tensor.SparseTensor)
|
||||
]
|
||||
all_dense = True if not sparse_inputs else False
|
||||
indices = [sp_inp.indices for sp_inp in sparse_inputs]
|
||||
values = [sp_inp.values for sp_inp in sparse_inputs]
|
||||
shapes = [sp_inp.dense_shape for sp_inp in sparse_inputs]
|
||||
indices_out, values_out, shapes_out = gen_sparse_ops.sparse_cross_hashed(
|
||||
indices=indices,
|
||||
values=values,
|
||||
shapes=shapes,
|
||||
dense_inputs=dense_inputs,
|
||||
num_buckets=self.num_bins,
|
||||
strong_hash=self.strong_hash,
|
||||
salt=self.salt)
|
||||
sparse_out = sparse_tensor.SparseTensor(indices_out, values_out, shapes_out)
|
||||
if all_dense:
|
||||
return sparse_ops.sparse_tensor_to_dense(sparse_out)
|
||||
return sparse_out
|
||||
|
||||
def _get_string_to_hash_bucket_fn(self):
|
||||
"""Returns the string_to_hash_bucket op to use based on `hasher_key`."""
|
||||
# string_to_hash_bucket_fast uses FarmHash64 as hash function.
|
||||
if self.salt is None:
|
||||
if not self.strong_hash:
|
||||
return string_ops.string_to_hash_bucket_fast
|
||||
# string_to_hash_bucket_strong uses SipHash64 as hash function.
|
||||
else:
|
||||
@ -127,16 +204,43 @@ class Hashing(Layer):
|
||||
string_ops.string_to_hash_bucket_strong, key=self.salt)
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
if not isinstance(input_shape, (tuple, list)):
|
||||
return input_shape
|
||||
input_shapes = input_shape
|
||||
batch_size = None
|
||||
for inp_shape in input_shapes:
|
||||
inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list()
|
||||
if len(inp_tensor_shape) != 2:
|
||||
raise ValueError('Inputs must be rank 2, get {}'.format(input_shapes))
|
||||
if batch_size is None:
|
||||
batch_size = inp_tensor_shape[0]
|
||||
# The second dimension is dynamic based on inputs.
|
||||
output_shape = [batch_size, None]
|
||||
return tensor_shape.TensorShape(output_shape)
|
||||
|
||||
def compute_output_signature(self, input_spec):
|
||||
output_shape = self.compute_output_shape(input_spec.shape.as_list())
|
||||
output_dtype = dtypes.int64
|
||||
if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
|
||||
if not isinstance(input_spec, (tuple, list)):
|
||||
output_shape = self.compute_output_shape(input_spec.shape)
|
||||
output_dtype = dtypes.int64
|
||||
if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
|
||||
return sparse_tensor.SparseTensorSpec(
|
||||
shape=output_shape, dtype=output_dtype)
|
||||
else:
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
||||
input_shapes = [x.shape for x in input_spec]
|
||||
output_shape = self.compute_output_shape(input_shapes)
|
||||
if any([
|
||||
isinstance(inp_spec, ragged_tensor.RaggedTensorSpec)
|
||||
for inp_spec in input_spec
|
||||
]):
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
|
||||
elif any([
|
||||
isinstance(inp_spec, sparse_tensor.SparseTensorSpec)
|
||||
for inp_spec in input_spec
|
||||
]):
|
||||
return sparse_tensor.SparseTensorSpec(
|
||||
shape=output_shape, dtype=output_dtype)
|
||||
else:
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
|
||||
shape=output_shape, dtype=dtypes.int64)
|
||||
return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64)
|
||||
|
||||
def get_config(self):
|
||||
config = {'num_bins': self.num_bins, 'salt': self.salt}
|
||||
|
@ -51,6 +51,15 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [0], [0]], output)
|
||||
|
||||
def test_hash_dense_multi_inputs_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
['skywalker']])
|
||||
inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
output = layer([inp_1, inp_2])
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
self.assertAllClose([[0], [0], [1], [1], [0]], output)
|
||||
|
||||
def test_hash_dense_int_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=3)
|
||||
inp = np.asarray([[0], [1], [2], [3], [4]])
|
||||
@ -72,6 +81,21 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# Note the result is different from (133, 137).
|
||||
self.assertAllClose([[1], [0], [1], [0], [1]], output_2)
|
||||
|
||||
def test_hash_dense_multi_inputs_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'],
|
||||
['skywalker']])
|
||||
inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']])
|
||||
output = layer([inp_1, inp_2])
|
||||
# Assert equal for hashed output that should be true on all platforms.
|
||||
# Note the result is different from FarmHash.
|
||||
self.assertAllClose([[0], [1], [0], [0], [1]], output)
|
||||
|
||||
layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
|
||||
output_2 = layer_2([inp_1, inp_2])
|
||||
# Note the result is different from (133, 137).
|
||||
self.assertAllClose([[1], [1], [1], [0], [1]], output_2)
|
||||
|
||||
def test_hash_dense_int_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=3, salt=[133, 137])
|
||||
inp = np.asarray([[0], [1], [2], [3], [4]])
|
||||
@ -90,6 +114,19 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(indices, output.indices)
|
||||
self.assertAllClose([0, 0, 1, 0, 0], output.values)
|
||||
|
||||
def test_hash_sparse_multi_inputs_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
indices = [[0, 0], [1, 0], [2, 0]]
|
||||
inp_1 = sparse_tensor.SparseTensor(
|
||||
indices=indices,
|
||||
values=['omar', 'stringer', 'marlo'],
|
||||
dense_shape=[3, 1])
|
||||
inp_2 = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1])
|
||||
output = layer([inp_1, inp_2])
|
||||
self.assertAllClose(indices, output.indices)
|
||||
self.assertAllClose([0, 0, 1], output.values)
|
||||
|
||||
def test_hash_sparse_int_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=3)
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
|
||||
@ -116,6 +153,25 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
# The result should be same with test_hash_dense_input_siphash.
|
||||
self.assertAllClose([1, 0, 1, 0, 1], output.values)
|
||||
|
||||
def test_hash_sparse_multi_inputs_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
indices = [[0, 0], [1, 0], [2, 0]]
|
||||
inp_1 = sparse_tensor.SparseTensor(
|
||||
indices=indices,
|
||||
values=['omar', 'stringer', 'marlo'],
|
||||
dense_shape=[3, 1])
|
||||
inp_2 = sparse_tensor.SparseTensor(
|
||||
indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1])
|
||||
output = layer([inp_1, inp_2])
|
||||
# The result should be same with test_hash_dense_input_siphash.
|
||||
self.assertAllClose(indices, output.indices)
|
||||
self.assertAllClose([0, 1, 0], output.values)
|
||||
|
||||
layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137])
|
||||
output = layer_2([inp_1, inp_2])
|
||||
# The result should be same with test_hash_dense_input_siphash.
|
||||
self.assertAllClose([1, 1, 1], output.values)
|
||||
|
||||
def test_hash_sparse_int_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=3, salt=[133, 137])
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]
|
||||
@ -140,6 +196,17 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_hash_ragged_string_multi_inputs_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=2)
|
||||
inp_data_1 = ragged_factory_ops.constant(
|
||||
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
|
||||
dtype=dtypes.string)
|
||||
inp_data_2 = ragged_factory_ops.constant(
|
||||
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
|
||||
dtype=dtypes.string)
|
||||
with self.assertRaisesRegexp(ValueError, 'not supported yet'):
|
||||
_ = layer([inp_data_1, inp_data_2])
|
||||
|
||||
def test_hash_ragged_int_input_farmhash(self):
|
||||
layer = hashing.Hashing(num_bins=3)
|
||||
inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]],
|
||||
@ -178,6 +245,17 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
model = training.Model(inputs=inp_t, outputs=out_t)
|
||||
self.assertAllClose(out_data, model.predict(inp_data))
|
||||
|
||||
def test_hash_ragged_string_multi_inputs_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=2, salt=[133, 137])
|
||||
inp_data_1 = ragged_factory_ops.constant(
|
||||
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
|
||||
dtype=dtypes.string)
|
||||
inp_data_2 = ragged_factory_ops.constant(
|
||||
[['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
|
||||
dtype=dtypes.string)
|
||||
with self.assertRaisesRegexp(ValueError, 'not supported yet'):
|
||||
_ = layer([inp_data_1, inp_data_2])
|
||||
|
||||
def test_hash_ragged_int_input_siphash(self):
|
||||
layer = hashing.Hashing(num_bins=3, salt=[133, 137])
|
||||
inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]],
|
||||
@ -197,11 +275,11 @@ class HashingTest(keras_parameterized.TestCase):
|
||||
_ = hashing.Hashing(num_bins=None)
|
||||
with self.assertRaisesRegexp(ValueError, 'cannot be `None`'):
|
||||
_ = hashing.Hashing(num_bins=-1)
|
||||
with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
|
||||
with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'):
|
||||
_ = hashing.Hashing(num_bins=2, salt='string')
|
||||
with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
|
||||
with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'):
|
||||
_ = hashing.Hashing(num_bins=2, salt=[1])
|
||||
with self.assertRaisesRegexp(ValueError, 'must be a tuple'):
|
||||
with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'):
|
||||
_ = hashing.Hashing(num_bins=1, salt=constant_op.constant([133, 137]))
|
||||
|
||||
def test_hash_compute_output_signature(self):
|
||||
|
@ -45,6 +45,8 @@ from tensorflow.python.keras.layers import recurrent
|
||||
from tensorflow.python.keras.layers import recurrent_v2
|
||||
from tensorflow.python.keras.layers import rnn_cell_wrapper_v2
|
||||
from tensorflow.python.keras.layers import wrappers
|
||||
from tensorflow.python.keras.layers.preprocessing import category_crossing
|
||||
from tensorflow.python.keras.layers.preprocessing import hashing
|
||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||
from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization
|
||||
from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1
|
||||
@ -60,7 +62,7 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
|
||||
embeddings, einsum_dense, local, merge, noise, normalization,
|
||||
pooling, image_preprocessing, preprocessing_normalization_v1,
|
||||
preprocessing_text_vectorization_v1,
|
||||
recurrent, wrappers)
|
||||
recurrent, wrappers, hashing, category_crossing)
|
||||
ALL_V2_MODULES = (
|
||||
rnn_cell_wrapper_v2,
|
||||
normalization_v2,
|
||||
|
@ -27,6 +27,7 @@ import numbers
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat as tf_compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -569,7 +570,7 @@ def sparse_add_v2(a, b, threshold=0):
|
||||
|
||||
|
||||
@tf_export("sparse.cross")
|
||||
def sparse_cross(inputs, name=None):
|
||||
def sparse_cross(inputs, name=None, separator=None):
|
||||
"""Generates sparse cross from a list of sparse and dense tensors.
|
||||
|
||||
For example, if the inputs are
|
||||
@ -590,14 +591,39 @@ def sparse_cross(inputs, name=None):
|
||||
[1, 0]: "b_X_e_X_g"
|
||||
[1, 1]: "c_X_e_X_g"
|
||||
|
||||
Customized separator "_Y_":
|
||||
|
||||
>>> inp_0 = tf.constant([['a'], ['b']])
|
||||
>>> inp_1 = tf.constant([['c'], ['d']])
|
||||
>>> output = tf.sparse.cross([inp_0, inp_1], separator='_Y_')
|
||||
>>> output.values
|
||||
<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a_Y_c', b'b_Y_d'],
|
||||
dtype=object)>
|
||||
|
||||
|
||||
Args:
|
||||
inputs: An iterable of `Tensor` or `SparseTensor`.
|
||||
name: Optional name for the op.
|
||||
separator: A string added between each string being joined. Defaults to
|
||||
'_X_'.
|
||||
|
||||
Returns:
|
||||
A `SparseTensor` of type `string`.
|
||||
"""
|
||||
return _sparse_cross_internal(inputs=inputs, hashed_output=False, name=name)
|
||||
if separator is None and not tf_compat.forward_compatible(2020, 6, 14):
|
||||
return _sparse_cross_internal(inputs=inputs, hashed_output=False, name=name)
|
||||
if separator is None:
|
||||
separator = "_X_"
|
||||
separator = ops.convert_to_tensor(separator, dtypes.string)
|
||||
indices, values, shapes, dense_inputs = _sparse_cross_internval_v2(inputs)
|
||||
indices_out, values_out, shape_out = gen_sparse_ops.sparse_cross_v2(
|
||||
indices=indices,
|
||||
values=values,
|
||||
shapes=shapes,
|
||||
dense_inputs=dense_inputs,
|
||||
sep=separator,
|
||||
name=name)
|
||||
return sparse_tensor.SparseTensor(indices_out, values_out, shape_out)
|
||||
|
||||
|
||||
_sparse_cross = sparse_cross
|
||||
@ -655,6 +681,32 @@ _sparse_cross_hashed = sparse_cross_hashed
|
||||
_DEFAULT_HASH_KEY = 0xDECAFCAFFE
|
||||
|
||||
|
||||
def _sparse_cross_internval_v2(inputs):
|
||||
"""See gen_sparse_ops.sparse_cross_v2."""
|
||||
if not isinstance(inputs, (tuple, list)):
|
||||
raise TypeError("Inputs must be a list")
|
||||
if not all(
|
||||
isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor)
|
||||
for i in inputs):
|
||||
raise TypeError("All inputs must be Tensor or SparseTensor.")
|
||||
sparse_inputs = [
|
||||
i for i in inputs if isinstance(i, sparse_tensor.SparseTensor)
|
||||
]
|
||||
dense_inputs = [
|
||||
i for i in inputs if not isinstance(i, sparse_tensor.SparseTensor)
|
||||
]
|
||||
indices = [sp_input.indices for sp_input in sparse_inputs]
|
||||
values = [sp_input.values for sp_input in sparse_inputs]
|
||||
shapes = [sp_input.dense_shape for sp_input in sparse_inputs]
|
||||
for i in range(len(values)):
|
||||
if values[i].dtype != dtypes.string:
|
||||
values[i] = math_ops.cast(values[i], dtypes.int64)
|
||||
for i in range(len(dense_inputs)):
|
||||
if dense_inputs[i].dtype != dtypes.string:
|
||||
dense_inputs[i] = math_ops.cast(dense_inputs[i], dtypes.int64)
|
||||
return indices, values, shapes, dense_inputs
|
||||
|
||||
|
||||
def _sparse_cross_internal(inputs,
|
||||
hashed_output=False,
|
||||
num_buckets=0,
|
||||
|
@ -1,6 +1,6 @@
|
||||
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.categorical_crossing.CategoryCrossing\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_crossing.CategoryCrossing\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -0,0 +1,218 @@
|
||||
path: "tensorflow.keras.layers.experimental.preprocessing.Hashing"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.hashing.Hashing\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "activity_regularizer"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dynamic"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "inbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "metrics"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "outbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "submodules"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_metric"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_shape"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_signature"
|
||||
argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_weights"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_name_scope"
|
||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "CenterCrop"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Hashing"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Normalization"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -22,7 +22,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "cross"
|
||||
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'inputs\', \'name\', \'separator\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cross_hashed"
|
||||
|
@ -1,6 +1,6 @@
|
||||
path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.categorical_crossing.CategoryCrossing\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_crossing.CategoryCrossing\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -0,0 +1,218 @@
|
||||
path: "tensorflow.keras.layers.experimental.preprocessing.Hashing"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.hashing.Hashing\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "activity_regularizer"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dynamic"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "inbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "metrics"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "outbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "submodules"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_metric"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_shape"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_signature"
|
||||
argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_weights"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_name_scope"
|
||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "CenterCrop"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Hashing"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Normalization"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -18,7 +18,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "cross"
|
||||
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'inputs\', \'name\', \'separator\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "cross_hashed"
|
||||
|
Loading…
x
Reference in New Issue
Block a user