Add DiscretizingCombiner to KPL.

PiperOrigin-RevId: 339041119
Change-Id: I7e972f40e9fcce76e93ec0cb6dbf0844a4150c1c
A. Unique TensorFlower 2020-10-26 08:16:50 -07:00 committed by TensorFlower Gardener
parent 65efe3b554
commit 8c78dae8fb
13 changed files with 607 additions and 13 deletions

View File

@@ -23,6 +23,9 @@
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* `tf.keras`:
* Improvements to Keras preprocessing layers:
* `Discretization` combiner implemented, so bucket boundaries can be computed via `adapt()`; adds an additional arg `epsilon` controlling the quantile approximation error.
## Thanks to our Contributors

View File

@@ -47,6 +47,7 @@ py_library(
name = "discretization",
srcs = [
"discretization.py",
"discretization_v1.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -54,6 +55,7 @@ py_library(
"//tensorflow/python:boosted_trees_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tf_export",
@@ -458,6 +460,7 @@ tf_py_test(
size = "small",
srcs = ["discretization_test.py"],
python_version = "PY3",
shard_count = 4,
tags = ["no_rocm"],
deps = [
":discretization",

View File

@@ -103,6 +103,16 @@ tf_py_test(
],
)
tf_py_test(
name = "discretization_adapt_benchmark",
srcs = ["discretization_adapt_benchmark.py"],
python_version = "PY3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python/keras/layers/preprocessing:discretization",
],
)
cuda_py_test(
name = "image_preproc_benchmark",
srcs = ["image_preproc_benchmark.py"],

View File

@@ -0,0 +1,120 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Keras discretization preprocessing layer's adapt method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import flags
import numpy as np
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
FLAGS = flags.FLAGS
EPSILON = 0.1
v2_compat.enable_v2_behavior()
def reduce_fn(state, values, epsilon=EPSILON):
"""tf.data.Dataset-friendly implementation of mean and variance."""
state_, = state
summary = discretization.summarize(values, epsilon)
if np.sum(state_[:, 0]) == 0:
return (summary,)
return (discretization.merge_summaries(state_, summary, epsilon),)
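# Note (descriptive, not part of the original benchmark): `ds.reduce` below is
# seeded with a (1, 2) zeros array as a "no summary yet" sentinel; `reduce_fn`
# detects it via the all-zero values column and replaces it with the first
# real summary.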
class BenchmarkAdapt(benchmark.Benchmark):
"""Benchmark adapt."""
def run_dataset_implementation(self, num_elements, batch_size):
input_t = keras.Input(shape=(1,))
layer = discretization.Discretization()
_ = layer(input_t)
num_repeats = 5
starts = []
ends = []
for _ in range(num_repeats):
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.map(
lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
ds = ds.batch(batch_size)
starts.append(time.time())
# Benchmarked code begins here.
state = ds.reduce((np.zeros((1, 2)),), reduce_fn)
bins = discretization.get_bucket_boundaries(state, 100)
layer.set_weights([bins])
# Benchmarked code ends here.
ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts))
return avg_time
def bm_adapt_implementation(self, num_elements, batch_size):
"""Test the KPL adapt implementation."""
input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
layer = discretization.Discretization()
_ = layer(input_t)
num_repeats = 5
starts = []
ends = []
for _ in range(num_repeats):
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.map(
lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
ds = ds.batch(batch_size)
starts.append(time.time())
# Benchmarked code begins here.
layer.adapt(ds)
# Benchmarked code ends here.
ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts))
name = "discretization_adapt|%s_elements|batch_%s" % (num_elements,
batch_size)
baseline = self.run_dataset_implementation(num_elements, batch_size)
extras = {
"tf.data implementation baseline": baseline,
"delta seconds": (baseline - avg_time),
"delta percent": ((baseline - avg_time) / baseline) * 100
}
self.report_benchmark(
iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
  def benchmark_num_elements_by_batch(self):
    for num_elements in [100, 1000, 10000, 100000, 1000000]:
      for batch in [64 * 2048]:
        self.bm_adapt_implementation(num_elements, batch)
if __name__ == "__main__":
test.main()

View File

@@ -17,24 +17,124 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export
_BINS_NAME = "bins"
def summarize(values, epsilon):
"""Reduce a 1D sequence of values to a summary.
This algorithm is based on numpy.quantiles but modified to allow for
intermediate steps between multiple data sets. It first finds the target
number of bins as the reciprocal of epsilon and then takes the individual
values spaced at appropriate intervals to arrive at that target.
The final step is to return the corresponding counts between those values
If the target num_bins is larger than the size of values, the whole array is
returned (with weights of 1).
Arguments:
values: 1-D `np.ndarray` to be summarized.
epsilon: A `'float32'` that determines the approxmiate desired precision.
Returns:
A 2-D `np.ndarray` that is a summary of the inputs. First column is the
interpolated partition values, the second is the weights (counts).
"""
num_bins = 1.0 / epsilon
value_shape = values.shape
  n = np.prod([1 if dim is None else dim for dim in value_shape])
if num_bins >= n:
return np.hstack((np.expand_dims(np.sort(values), 1), np.ones((n, 1))))
step_size = int(n / num_bins)
partition_indices = np.arange(step_size, n, step_size, np.int64)
part = np.partition(values, partition_indices)[partition_indices]
return np.hstack((np.expand_dims(part, 1),
step_size * np.ones((np.prod(part.shape), 1))))
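# A by-hand trace of `summarize` (illustrative only, not part of the module):
# with epsilon=0.5 the target is 1/0.5 = 2 bins, so for four values step_size
# is 2 and a single partition index [2] is taken.
#
#   summarize(np.array([1., 2., 3., 4.]), epsilon=0.5)
#   # => array([[3., 2.]]): one partition value (3.) carrying a weight of 2.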
def compress(summary, epsilon):
"""Compress a summary to within `epsilon` accuracy.
The compression step is needed to keep the summary sizes small after merging,
and also used to return the final target boundaries. It finds the new bins
based on interpolating cumulative weight percentages from the large summary.
Taking the difference of the cumulative weights from the previous bin's
cumulative weight will give the new weight for that bin.
Arguments:
summary: 2-D `np.ndarray` summary to be compressed.
epsilon: A `'float32'` that determines the approxmiate desired precision.
Returns:
A 2-D `np.ndarray` that is a compressed summary. First column is the
interpolated partition values, the second is the weights (counts).
"""
if np.prod(summary[:, 0].shape) * epsilon < 1:
return summary
percents = epsilon + np.arange(0.0, 1.0, epsilon)
cum_weights = summary[:, 1].cumsum()
cum_weight_percents = cum_weights / cum_weights[-1]
new_bins = np.interp(percents, cum_weight_percents, summary[:, 0])
cum_weights = np.interp(percents, cum_weight_percents, cum_weights)
new_weights = cum_weights - np.concatenate((np.array([0]), cum_weights[:-1]))
return np.hstack((np.expand_dims(new_bins, 1),
np.expand_dims(new_weights, 1)))
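# A worked example (illustrative): compressing a four-row, unit-weight summary
# to epsilon=0.5 interpolates at cumulative-weight percents [0.5, 1.0] and
# collapses it to two rows.
#
#   compress(np.array([[1., 1.], [2., 1.], [3., 1.], [4., 1.]]), epsilon=0.5)
#   # => array([[2., 2.],
#   #           [4., 2.]])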
def merge_summaries(prev_summary, next_summary, epsilon):
"""Weighted merge sort of summaries.
Given two summaries of distinct data, this function merges (and compresses)
them to stay within `epsilon` error tolerance.
Arguments:
prev_summary: 2-D `np.ndarray` summary to be merged with `next_summary`.
next_summary: 2-D `np.ndarray` summary to be merged with `prev_summary`.
epsilon: A `'float32'` that determines the approxmiate desired precision.
Returns:
A 2-D `np.ndarray` that is a merged summary. First column is the
interpolated partition values, the second is the weights (counts).
"""
merged = np.concatenate((prev_summary, next_summary))
merged = merged[merged[:, 0].argsort()]
if np.prod(merged.shape) * epsilon < 1:
return merged
return compress(merged, epsilon)
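# Merging is a weight-preserving concatenation plus sort on the partition
# values; compression only kicks in once `np.prod(merged.shape) * epsilon`
# reaches 1. An illustrative trace:
#
#   merge_summaries(np.array([[1., 2.], [3., 2.]]),
#                   np.array([[2., 2.], [4., 2.]]), epsilon=0.01)
#   # => array([[1., 2.], [2., 2.], [3., 2.], [4., 2.]])  (no compression)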
def get_bucket_boundaries(summary, num_bins):
  """Compress `summary` to `num_bins` bins and return the interior boundaries."""
  return compress(summary, 1.0 / num_bins)[:-1, 0]
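# End-to-end (an illustrative sketch of the adapt path): per-batch summaries
# from summarize() are folded together with merge_summaries(), and
# get_bucket_boundaries() compresses the result to num_bins, dropping the
# final row.
#
#   summary = np.array([[1., 2.], [2., 2.], [3., 2.], [4., 2.]])
#   get_bucket_boundaries(summary, num_bins=2)
#   # => array([2.]): the single interior boundary; DiscretizingCombiner
#   # appends np.Inf to close the last bucket.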
@keras_export("keras.layers.experimental.preprocessing.Discretization")
class Discretization(base_preprocessing_layer.PreprocessingLayer):
class Discretization(base_preprocessing_layer.CombinerPreprocessingLayer):
"""Buckets data into discrete ranges.
This layer will place each element of its input data into one of several
@@ -48,9 +148,15 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
Same as input shape.
Attributes:
bins: Optional boundary specification. Bins exclude the left boundary and
include the right boundary, so `bins=[0., 1., 2.]` generates bins
bins: Optional boundary specification or number of bins to compute if `int`.
Bins exclude the left boundary and include the right boundary,
so `bins=[0., 1., 2.]` generates bins
`(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`.
      The example generates four buckets in total, matching what `bins=4`
      would compute from data via `adapt()`.
    epsilon: Error tolerance, typically a small fraction close to zero (e.g.
      0.01). Higher values of epsilon increase the quantile approximation
      error, and hence can result in more unequal buckets, but improve
      performance and reduce resource consumption.
Examples:
@@ -62,15 +168,38 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[0, 1, 3, 1],
[0, 3, 2, 0]], dtype=int32)>
  Bucketize float values, computing the bucket boundaries from the data for a
  requested number of buckets.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = tf.keras.layers.experimental.preprocessing.Discretization(
... bins=4, epsilon=0.01)
>>> layer.adapt(input)
>>> layer(input)
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[0, 2, 3, 1],
[0, 3, 2, 0]], dtype=int32)>
"""
def __init__(self, bins, **kwargs):
super(Discretization, self).__init__(**kwargs)
base_preprocessing_layer._kpl_gauge.get_cell("V2").set("Discretization")
# The bucketization op requires a final rightmost boundary in order to
# correctly assign values higher than the largest left boundary.
# This should not impact intended buckets even if a max value is provided.
self.bins = np.append(bins, [np.Inf])
def __init__(self,
bins,
epsilon=0.01,
**kwargs):
super(Discretization, self).__init__(
combiner=Discretization.DiscretizingCombiner(
epsilon, bins if isinstance(bins, int) else 1),
**kwargs)
if bins is not None and not isinstance(bins, int):
self.bins = np.append(bins, [np.Inf])
else:
self.bins = np.zeros(bins)
def build(self, input_shape):
self.bins = self._add_state_variable(
name=_BINS_NAME,
shape=(self.bins.size,),
dtype=dtypes.float32,
initializer=init_ops.constant_initializer(self.bins))
super(Discretization, self).build(input_shape)
def get_config(self):
config = {
@@ -128,3 +257,82 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
control_flow_ops.vectorized_map(
_bucketize_op(array_ops.squeeze(self.bins)), reshaped),
array_ops.constant([-1] + input_shape.as_list()[1:]))
class DiscretizingCombiner(Combiner):
"""Combiner for the Discretization preprocessing layer.
    This class encapsulates the computations for finding the quantile
    boundaries of a set of data in a stable and numerically correct way. Its
    associated accumulator is a namedtuple with a single `summaries` field,
    holding summarizations of the data used to generate the boundaries.
Attributes:
epsilon: Error tolerance.
num_bins: The desired number of buckets.
"""
    def __init__(self, epsilon, num_bins):
self.epsilon = epsilon
self.num_bins = num_bins
# TODO(mwunder): Implement elementwise per-column discretization.
def compute(self, values, accumulator=None):
"""Compute a step in this computation, returning a new accumulator."""
if isinstance(values, sparse_tensor.SparseTensor):
values = values.values
if tf_utils.is_ragged(values):
values = values.flat_values
flattened_input = np.reshape(values, newshape=(-1, 1))
summaries = [summarize(v, self.epsilon) for v in flattened_input.T]
if accumulator is None:
return self._create_accumulator(summaries)
else:
return self._create_accumulator(
[merge_summaries(prev_summ, summ, self.epsilon)
for prev_summ, summ in zip(accumulator.summaries, summaries)])
def merge(self, accumulators):
"""Merge several accumulators to a single accumulator."""
# Combine accumulators and return the result.
merged = accumulators[0].summaries
for accumulator in accumulators[1:]:
merged = [merge_summaries(prev, summary, self.epsilon)
for prev, summary in zip(merged, accumulator.summaries)]
return self._create_accumulator(merged)
def extract(self, accumulator):
"""Convert an accumulator into a dict of output values."""
boundaries = [np.append(get_bucket_boundaries(summary, self.num_bins),
[np.Inf])
for summary in accumulator.summaries]
return {
_BINS_NAME: np.squeeze(np.vstack(boundaries))
}
def restore(self, output):
"""Create an accumulator based on 'output'."""
raise NotImplementedError(
"Discretization does not restore or support streaming updates.")
def serialize(self, accumulator):
"""Serialize an accumulator for a remote call."""
output_dict = {
_BINS_NAME: [summary.tolist() for summary in accumulator.summaries]
}
return compat.as_bytes(json.dumps(output_dict))
def deserialize(self, encoded_accumulator):
"""Deserialize an accumulator received from 'serialize()'."""
value_dict = json.loads(compat.as_text(encoded_accumulator))
return self._create_accumulator(np.array(value_dict[_BINS_NAME]))
def _create_accumulator(self, summaries):
"""Represent the accumulator as one or more summaries of the dataset."""
return collections.namedtuple("Accumulator", ["summaries"])(summaries)
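# A minimal sketch of the combiner lifecycle (illustrative; `adapt` drives
# these calls internally):
#
#   combiner = Discretization.DiscretizingCombiner(epsilon=0.01, num_bins=4)
#   acc = combiner.compute(np.array([[1.], [2.], [3.], [4.]]))
#   acc = combiner.compute(np.array([[5.], [6.], [7.], [8.]]), acc)
#   merged = combiner.merge([acc])
#   combiner.extract(merged)  # => {'bins': 3 interior boundaries + np.Inf}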

View File

@@ -18,19 +18,32 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import discretization_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
def get_layer_class():
if context.executing_eagerly():
return discretization.Discretization
else:
return discretization_v1.Discretization
@keras_parameterized.run_all_keras_modes
class DiscretizationTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
@@ -106,7 +119,6 @@ class DiscretizationTest(keras_parameterized.TestCase,
layer = discretization.Discretization(bins=[-.5, 0.5, 1.5])
bucket_data = layer(input_data)
self.assertAllEqual(expected_output_shape, bucket_data.shape.as_list())
model = keras.Model(inputs=input_data, outputs=bucket_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
@@ -125,6 +137,110 @@ class DiscretizationTest(keras_parameterized.TestCase,
self.assertAllEqual(indices, output_dataset.indices)
self.assertAllEqual(expected_output, output_dataset.values)
@parameterized.named_parameters([
{
"testcase_name": "2d_single_element",
"adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]]),
"test_data": np.array([[1.], [2.], [3.]]),
"use_dataset": True,
"expected": np.array([[0], [1], [2]]),
"num_bins": 5,
"epsilon": 0.01
}, {
"testcase_name": "2d_multi_element",
"adapt_data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.],
[5., 10.]]),
"test_data": np.array([[1., 10.], [2., 6.], [3., 8.]]),
"use_dataset": True,
"expected": np.array([[0, 4], [0, 2], [1, 3]]),
"num_bins": 5,
"epsilon": 0.01
}, {
"testcase_name": "1d_single_element",
"adapt_data": np.array([3., 2., 1., 5., 4.]),
"test_data": np.array([1., 2., 3.]),
"use_dataset": True,
"expected": np.array([0, 1, 2]),
"num_bins": 5,
"epsilon": 0.01
}, {
"testcase_name": "300_batch_1d_single_element_1",
"adapt_data": np.arange(300),
"test_data": np.arange(300),
"use_dataset": True,
"expected":
np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
"num_bins": 3,
"epsilon": 0.01
}, {
"testcase_name": "300_batch_1d_single_element_2",
"adapt_data": np.arange(300) ** 2,
"test_data": np.arange(300) ** 2,
"use_dataset": True,
"expected":
np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
"num_bins": 3,
"epsilon": 0.01
}, {
"testcase_name": "300_batch_1d_single_element_large_epsilon",
"adapt_data": np.arange(300),
"test_data": np.arange(300),
"use_dataset": True,
"expected": np.concatenate([np.zeros(137), np.ones(163)]),
"num_bins": 2,
"epsilon": 0.1
}])
def test_layer_computation(self, adapt_data, test_data, use_dataset,
expected, num_bins=5, epsilon=0.01):
input_shape = tuple(list(test_data.shape)[1:])
np.random.shuffle(adapt_data)
if use_dataset:
# Keras APIs expect batched datasets
adapt_data = dataset_ops.Dataset.from_tensor_slices(adapt_data).batch(
test_data.shape[0] // 2)
test_data = dataset_ops.Dataset.from_tensor_slices(test_data).batch(
test_data.shape[0] // 2)
cls = get_layer_class()
layer = cls(epsilon=epsilon, bins=num_bins)
layer.adapt(adapt_data)
input_data = keras.Input(shape=input_shape)
output = layer(input_data)
model = keras.Model(input_data, output)
model._run_eagerly = testing_utils.should_run_eagerly()
output_data = model.predict(test_data)
self.assertAllClose(expected, output_data)
@parameterized.named_parameters(
{
"num_bins": 5,
"data": np.array([[1.], [2.], [3.], [4.], [5.]]),
"expected": {
"bins": np.array([1., 2., 3., 4., np.Inf])
},
"testcase_name": "2d_single_element_all_bins"
}, {
"num_bins": 5,
"data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.], [5., 10.]]),
"expected": {
"bins": np.array([2., 4., 6., 8., np.Inf])
},
"testcase_name": "2d_multi_element_all_bins",
}, {
"num_bins": 3,
"data": np.array([[0.], [1.], [2.], [3.], [4.], [5.]]),
"expected": {
"bins": np.array([1., 3., np.Inf])
},
"testcase_name": "2d_single_element_3_bins"
})
def test_combiner_computation(self, num_bins, data, expected):
epsilon = 0.01
combiner = discretization.Discretization.DiscretizingCombiner(epsilon,
num_bins)
self.validate_accumulator_extract(combiner, data, expected)
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow V1 version of the Discretization preprocessing layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.util.tf_export import keras_export
@keras_export(v1=['keras.layers.experimental.preprocessing.Discretization'])
class Discretization(discretization.Discretization, CombinerPreprocessingLayer):
pass

View File

@@ -0,0 +1,14 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
tf_class {
is_instance: "<type \'type\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -0,0 +1,34 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute"
argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "deserialize"
argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "extract"
argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge"
argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "restore"
argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "serialize"
argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -8,6 +9,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "DiscretizingCombiner"
mtype: "<type \'type\'>"
}
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
@@ -130,7 +135,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
}
member_method {
name: "adapt"

View File

@@ -0,0 +1,14 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
tf_class {
is_instance: "<type \'type\'>"
member_method {
name: "__init__"
}
member_method {
name: "mro"
}
member_method {
name: "register"
argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -0,0 +1,34 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute"
argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "deserialize"
argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "extract"
argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge"
argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "restore"
argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "serialize"
argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -8,6 +9,10 @@ tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "DiscretizingCombiner"
mtype: "<type \'type\'>"
}
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
@@ -130,7 +135,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
}
member_method {
name: "adapt"