Add DiscretizingCombiner to KPL.
PiperOrigin-RevId: 339041119
Change-Id: I7e972f40e9fcce76e93ec0cb6dbf0844a4150c1c
parent 65efe3b554
commit 8c78dae8fb
@@ -23,6 +23,9 @@
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* `tf.keras`:
  * Improvements to Keras preprocessing layers:
    * Discretization combiner implemented, with additional arg `epsilon`.

## Thanks to our Contributors
@@ -47,6 +47,7 @@ py_library(
    name = "discretization",
    srcs = [
        "discretization.py",
        "discretization_v1.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
@@ -54,6 +55,7 @@ py_library(
        "//tensorflow/python:boosted_trees_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resources",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:tf_export",
@@ -458,6 +460,7 @@ tf_py_test(
    size = "small",
    srcs = ["discretization_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = ["no_rocm"],
    deps = [
        ":discretization",
@@ -103,6 +103,16 @@ tf_py_test(
    ],
)

tf_py_test(
    name = "discretization_adapt_benchmark",
    srcs = ["discretization_adapt_benchmark.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python/keras/layers/preprocessing:discretization",
    ],
)

cuda_py_test(
    name = "image_preproc_benchmark",
    srcs = ["image_preproc_benchmark.py"],
@@ -0,0 +1,120 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Keras discretization preprocessing layer's adapt method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from absl import flags
import numpy as np

from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test

FLAGS = flags.FLAGS
EPSILON = 0.1

v2_compat.enable_v2_behavior()


def reduce_fn(state, values, epsilon=EPSILON):
  """tf.data.Dataset-friendly implementation of quantile summarization."""
  state_, = state
  summary = discretization.summarize(values, epsilon)
  if np.sum(state_[:, 0]) == 0:
    return (summary,)
  return (discretization.merge_summaries(state_, summary, epsilon),)


class BenchmarkAdapt(benchmark.Benchmark):
  """Benchmark adapt."""

  def run_dataset_implementation(self, num_elements, batch_size):
    input_t = keras.Input(shape=(1,))
    layer = discretization.Discretization()
    _ = layer(input_t)

    num_repeats = 5
    starts = []
    ends = []
    for _ in range(num_repeats):
      ds = dataset_ops.Dataset.range(num_elements)
      ds = ds.map(
          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
      ds = ds.batch(batch_size)

      starts.append(time.time())
      # Benchmarked code begins here.
      state = ds.reduce((np.zeros((1, 2)),), reduce_fn)

      bins = discretization.get_bucket_boundaries(state, 100)
      layer.set_weights([bins])
      # Benchmarked code ends here.
      ends.append(time.time())

    avg_time = np.mean(np.array(ends) - np.array(starts))
    return avg_time

  def bm_adapt_implementation(self, num_elements, batch_size):
    """Test the KPL adapt implementation."""
    input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
    layer = discretization.Discretization()
    _ = layer(input_t)

    num_repeats = 5
    starts = []
    ends = []
    for _ in range(num_repeats):
      ds = dataset_ops.Dataset.range(num_elements)
      ds = ds.map(
          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
      ds = ds.batch(batch_size)

      starts.append(time.time())
      # Benchmarked code begins here.
      layer.adapt(ds)
      # Benchmarked code ends here.
      ends.append(time.time())

    avg_time = np.mean(np.array(ends) - np.array(starts))
    name = "discretization_adapt|%s_elements|batch_%s" % (num_elements,
                                                          batch_size)
    baseline = self.run_dataset_implementation(num_elements, batch_size)
    extras = {
        "tf.data implementation baseline": baseline,
        "delta seconds": (baseline - avg_time),
        "delta percent": ((baseline - avg_time) / baseline) * 100
    }
    self.report_benchmark(
        iters=num_repeats, wall_time=avg_time, extras=extras, name=name)

  def benchmark_vocab_size_by_batch(self):
    for vocab_size in [100, 1000, 10000, 100000, 1000000]:
      for batch in [64 * 2048]:
        self.bm_adapt_implementation(vocab_size, batch)


if __name__ == "__main__":
  test.main()
@@ -17,24 +17,124 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export


_BINS_NAME = "bins"


def summarize(values, epsilon):
  """Reduce a 1D sequence of values to a summary.

  This algorithm is based on numpy.quantile, but modified to allow for
  intermediate steps between multiple data sets. It first finds the target
  number of bins as the reciprocal of epsilon and then takes the individual
  values spaced at appropriate intervals to arrive at that target.
  The final step is to return the corresponding counts between those values.
  If the target num_bins is larger than the size of values, the whole array is
  returned (with weights of 1).

  Arguments:
    values: 1-D `np.ndarray` to be summarized.
    epsilon: A `'float32'` that determines the approximate desired precision.

  Returns:
    A 2-D `np.ndarray` that is a summary of the inputs. First column is the
    interpolated partition values, the second is the weights (counts).
  """

  num_bins = 1.0 / epsilon
  value_shape = values.shape
  n = np.prod([(1 if dim is None else dim) for dim in value_shape])
  if num_bins >= n:
    return np.hstack((np.expand_dims(np.sort(values), 1), np.ones((n, 1))))
  step_size = int(n / num_bins)
  partition_indices = np.arange(step_size, n, step_size, np.int64)

  part = np.partition(values, partition_indices)[partition_indices]

  return np.hstack((np.expand_dims(part, 1),
                    step_size * np.ones((np.prod(part.shape), 1))))
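
For intuition, a minimal sketch of what `summarize` returns, assuming the implementation above (the numbers follow directly from the code; epsilon = 0.25 targets four bins):

    import numpy as np  # summarize() as defined above is assumed in scope

    values = np.arange(12.0)        # n = 12, so step_size = int(12 / 4) = 3
    print(summarize(values, 0.25))
    # [[3. 3.]   <- each row is (partition value, weight), 3 values per step
    #  [6. 3.]
    #  [9. 3.]]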


def compress(summary, epsilon):
  """Compress a summary to within `epsilon` accuracy.

  The compression step is needed to keep the summary sizes small after merging,
  and is also used to return the final target boundaries. It finds the new bins
  based on interpolating cumulative weight percentages from the large summary.
  Taking the difference of the cumulative weights from the previous bin's
  cumulative weight will give the new weight for that bin.

  Arguments:
    summary: 2-D `np.ndarray` summary to be compressed.
    epsilon: A `'float32'` that determines the approximate desired precision.

  Returns:
    A 2-D `np.ndarray` that is a compressed summary. First column is the
    interpolated partition values, the second is the weights (counts).
  """
  if np.prod(summary[:, 0].shape) * epsilon < 1:
    return summary

  percents = epsilon + np.arange(0.0, 1.0, epsilon)
  cum_weights = summary[:, 1].cumsum()
  cum_weight_percents = cum_weights / cum_weights[-1]
  new_bins = np.interp(percents, cum_weight_percents, summary[:, 0])
  cum_weights = np.interp(percents, cum_weight_percents, cum_weights)
  new_weights = cum_weights - np.concatenate((np.array([0]), cum_weights[:-1]))

  return np.hstack((np.expand_dims(new_bins, 1),
                    np.expand_dims(new_weights, 1)))
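
Continuing the sketch above, compressing the three-row summary with a looser epsilon interpolates new bins from the cumulative weight percentages:

    summary = summarize(np.arange(12.0), 0.25)   # the 3-row summary above
    print(compress(summary, 0.5))                # target roughly 2 bins
    # [[4.5 4.5]   <- each surviving bin now carries half the total weight
    #  [9.  4.5]]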


def merge_summaries(prev_summary, next_summary, epsilon):
  """Weighted merge sort of summaries.

  Given two summaries of distinct data, this function merges (and compresses)
  them to stay within `epsilon` error tolerance.

  Arguments:
    prev_summary: 2-D `np.ndarray` summary to be merged with `next_summary`.
    next_summary: 2-D `np.ndarray` summary to be merged with `prev_summary`.
    epsilon: A `'float32'` that determines the approximate desired precision.

  Returns:
    A 2-D `np.ndarray` that is a merged summary. First column is the
    interpolated partition values, the second is the weights (counts).
  """
  merged = np.concatenate((prev_summary, next_summary))
  merged = merged[merged[:, 0].argsort()]
  if np.prod(merged.shape) * epsilon < 1:
    return merged
  return compress(merged, epsilon)


def get_bucket_boundaries(summary, num_bins):
  return compress(summary, 1.0 / num_bins)[:-1, 0]
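
Taken together, a hedged sketch of the two-pass flow these helpers support (the data and bin count are illustrative; for uniform data the boundaries land near the quartiles):

    batch1 = np.arange(0., 100.)
    batch2 = np.arange(100., 200.)
    merged = merge_summaries(summarize(batch1, 0.01),
                             summarize(batch2, 0.01), 0.01)
    print(get_bucket_boundaries(merged, 4))   # roughly [50., 100., 150.]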


@keras_export("keras.layers.experimental.preprocessing.Discretization")
-class Discretization(base_preprocessing_layer.PreprocessingLayer):
+class Discretization(base_preprocessing_layer.CombinerPreprocessingLayer):
  """Buckets data into discrete ranges.

  This layer will place each element of its input data into one of several
@@ -48,9 +148,15 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
    Same as input shape.

  Attributes:
-    bins: Optional boundary specification. Bins exclude the left boundary and
-      include the right boundary, so `bins=[0., 1., 2.]` generates bins
-      `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`.
+    bins: Optional boundary specification, or the number of bins to compute if
+      an `int`. Bins exclude the left boundary and include the right boundary,
+      so `bins=[0., 1., 2.]` generates bins `(-inf, 0.)`, `[0., 1.)`,
+      `[1., 2.)`, and `[2., +inf)`, which would correspond to `bins=4`.
+    epsilon: Error tolerance, typically a small fraction close to zero
+      (e.g. 0.01). Higher values of epsilon loosen the quantile approximation,
+      and hence can produce more unequal buckets, but reduce computation and
+      resource consumption.

  Examples:

@@ -62,15 +168,38 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
  array([[0, 1, 3, 1],
         [0, 3, 2, 0]], dtype=int32)>

+  Bucketize float values based on a number of buckets to compute.
+  >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
+  >>> layer = tf.keras.layers.experimental.preprocessing.Discretization(
+  ...     bins=4, epsilon=0.01)
+  >>> layer.adapt(input)
+  >>> layer(input)
+  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
+  array([[0, 2, 3, 1],
+         [0, 3, 2, 0]], dtype=int32)>
  """

-  def __init__(self, bins, **kwargs):
-    super(Discretization, self).__init__(**kwargs)
+  def __init__(self,
+               bins,
+               epsilon=0.01,
+               **kwargs):
+    super(Discretization, self).__init__(
+        combiner=Discretization.DiscretizingCombiner(
+            epsilon, bins if isinstance(bins, int) else 1),
+        **kwargs)
    base_preprocessing_layer._kpl_gauge.get_cell("V2").set("Discretization")
    # The bucketization op requires a final rightmost boundary in order to
    # correctly assign values higher than the largest left boundary.
    # This should not impact intended buckets even if a max value is provided.
-    self.bins = np.append(bins, [np.Inf])
+    if bins is not None and not isinstance(bins, int):
+      self.bins = np.append(bins, [np.Inf])
+    else:
+      self.bins = np.zeros(bins)

  def build(self, input_shape):
    self.bins = self._add_state_variable(
        name=_BINS_NAME,
        shape=(self.bins.size,),
        dtype=dtypes.float32,
        initializer=init_ops.constant_initializer(self.bins))
    super(Discretization, self).build(input_shape)

  def get_config(self):
    config = {
@@ -128,3 +257,82 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
        control_flow_ops.vectorized_map(
            _bucketize_op(array_ops.squeeze(self.bins)), reshaped),
        array_ops.constant([-1] + input_shape.as_list()[1:]))

  class DiscretizingCombiner(Combiner):
    """Combiner for the Discretization preprocessing layer.

    This class encapsulates the computations for finding the quantile
    boundaries of a set of data in a stable and numerically correct way. Its
    associated accumulator is a namedtuple('summaries'), representing
    summarizations of the data used to generate the boundaries.

    Attributes:
      epsilon: Error tolerance.
      num_bins: The desired number of buckets.
    """

    def __init__(self, epsilon, num_bins):
      self.epsilon = epsilon
      self.num_bins = num_bins

      # TODO(mwunder): Implement elementwise per-column discretization.

    def compute(self, values, accumulator=None):
      """Compute a step in this computation, returning a new accumulator."""
      if isinstance(values, sparse_tensor.SparseTensor):
        values = values.values
      if tf_utils.is_ragged(values):
        values = values.flat_values
      flattened_input = np.reshape(values, newshape=(-1, 1))

      summaries = [summarize(v, self.epsilon) for v in flattened_input.T]

      if accumulator is None:
        return self._create_accumulator(summaries)
      else:
        return self._create_accumulator(
            [merge_summaries(prev_summ, summ, self.epsilon)
             for prev_summ, summ in zip(accumulator.summaries, summaries)])

    def merge(self, accumulators):
      """Merge several accumulators into a single accumulator."""
      # Combine accumulators and return the result.
      merged = accumulators[0].summaries
      for accumulator in accumulators[1:]:
        merged = [merge_summaries(prev, summary, self.epsilon)
                  for prev, summary in zip(merged, accumulator.summaries)]

      return self._create_accumulator(merged)

    def extract(self, accumulator):
      """Convert an accumulator into a dict of output values."""
      boundaries = [np.append(get_bucket_boundaries(summary, self.num_bins),
                              [np.Inf])
                    for summary in accumulator.summaries]
      return {
          _BINS_NAME: np.squeeze(np.vstack(boundaries))
      }

    def restore(self, output):
      """Create an accumulator based on 'output'."""
      raise NotImplementedError(
          "Discretization does not restore or support streaming updates.")

    def serialize(self, accumulator):
      """Serialize an accumulator for a remote call."""
      output_dict = {
          _BINS_NAME: [summary.tolist() for summary in accumulator.summaries]
      }
      return compat.as_bytes(json.dumps(output_dict))

    def deserialize(self, encoded_accumulator):
      """Deserialize an accumulator received from 'serialize()'."""
      value_dict = json.loads(compat.as_text(encoded_accumulator))
      return self._create_accumulator(np.array(value_dict[_BINS_NAME]))

    def _create_accumulator(self, summaries):
      """Represent the accumulator as one or more summaries of the dataset."""
      return collections.namedtuple("Accumulator", ["summaries"])(summaries)
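
As a rough usage sketch (illustrative only; `adapt` normally drives these calls through `CombinerPreprocessingLayer`), the combiner's lifecycle is compute per batch, merge across accumulators, then extract the boundary weights:

    combiner = Discretization.DiscretizingCombiner(epsilon=0.01, num_bins=4)
    acc1 = combiner.compute(np.arange(0., 100.))     # one batch of data
    acc2 = combiner.compute(np.arange(100., 200.))   # a second batch
    merged = combiner.merge([acc1, acc2])
    print(combiner.extract(merged)["bins"])   # roughly [50., 100., 150., inf]

    # Accumulators also survive a serialize/deserialize round trip, which is
    # what the remote-call path uses.
    restored = combiner.deserialize(combiner.serialize(merged))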

@@ -18,19 +18,32 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np

from tensorflow.python import keras

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import discretization_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


def get_layer_class():
  if context.executing_eagerly():
    return discretization.Discretization
  else:
    return discretization_v1.Discretization


@keras_parameterized.run_all_keras_modes
class DiscretizationTest(keras_parameterized.TestCase,
                         preprocessing_test_utils.PreprocessingLayerTest):
@@ -106,7 +119,6 @@ class DiscretizationTest(keras_parameterized.TestCase,
    layer = discretization.Discretization(bins=[-.5, 0.5, 1.5])
    bucket_data = layer(input_data)
    self.assertAllEqual(expected_output_shape, bucket_data.shape.as_list())
-
    model = keras.Model(inputs=input_data, outputs=bucket_data)
    output_dataset = model.predict(input_array)
    self.assertAllEqual(expected_output, output_dataset)
@@ -125,6 +137,110 @@ class DiscretizationTest(keras_parameterized.TestCase,
    self.assertAllEqual(indices, output_dataset.indices)
    self.assertAllEqual(expected_output, output_dataset.values)
  @parameterized.named_parameters([
      {
          "testcase_name": "2d_single_element",
          "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]]),
          "test_data": np.array([[1.], [2.], [3.]]),
          "use_dataset": True,
          "expected": np.array([[0], [1], [2]]),
          "num_bins": 5,
          "epsilon": 0.01
      }, {
          "testcase_name": "2d_multi_element",
          "adapt_data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.],
                                  [5., 10.]]),
          "test_data": np.array([[1., 10.], [2., 6.], [3., 8.]]),
          "use_dataset": True,
          "expected": np.array([[0, 4], [0, 2], [1, 3]]),
          "num_bins": 5,
          "epsilon": 0.01
      }, {
          "testcase_name": "1d_single_element",
          "adapt_data": np.array([3., 2., 1., 5., 4.]),
          "test_data": np.array([1., 2., 3.]),
          "use_dataset": True,
          "expected": np.array([0, 1, 2]),
          "num_bins": 5,
          "epsilon": 0.01
      }, {
          "testcase_name": "300_batch_1d_single_element_1",
          "adapt_data": np.arange(300),
          "test_data": np.arange(300),
          "use_dataset": True,
          "expected":
              np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
          "num_bins": 3,
          "epsilon": 0.01
      }, {
          "testcase_name": "300_batch_1d_single_element_2",
          "adapt_data": np.arange(300) ** 2,
          "test_data": np.arange(300) ** 2,
          "use_dataset": True,
          "expected":
              np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
          "num_bins": 3,
          "epsilon": 0.01
      }, {
          "testcase_name": "300_batch_1d_single_element_large_epsilon",
          "adapt_data": np.arange(300),
          "test_data": np.arange(300),
          "use_dataset": True,
          "expected": np.concatenate([np.zeros(137), np.ones(163)]),
          "num_bins": 2,
          "epsilon": 0.1
      }])
  def test_layer_computation(self, adapt_data, test_data, use_dataset,
                             expected, num_bins=5, epsilon=0.01):

    input_shape = tuple(list(test_data.shape)[1:])
    np.random.shuffle(adapt_data)
    if use_dataset:
      # Keras APIs expect batched datasets.
      adapt_data = dataset_ops.Dataset.from_tensor_slices(adapt_data).batch(
          test_data.shape[0] // 2)
      test_data = dataset_ops.Dataset.from_tensor_slices(test_data).batch(
          test_data.shape[0] // 2)

    cls = get_layer_class()
    layer = cls(epsilon=epsilon, bins=num_bins)
    layer.adapt(adapt_data)

    input_data = keras.Input(shape=input_shape)
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()
    output_data = model.predict(test_data)
    self.assertAllClose(expected, output_data)

  @parameterized.named_parameters(
      {
          "num_bins": 5,
          "data": np.array([[1.], [2.], [3.], [4.], [5.]]),
          "expected": {
              "bins": np.array([1., 2., 3., 4., np.Inf])
          },
          "testcase_name": "2d_single_element_all_bins"
      }, {
          "num_bins": 5,
          "data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.], [5., 10.]]),
          "expected": {
              "bins": np.array([2., 4., 6., 8., np.Inf])
          },
          "testcase_name": "2d_multi_element_all_bins",
      }, {
          "num_bins": 3,
          "data": np.array([[0.], [1.], [2.], [3.], [4.], [5.]]),
          "expected": {
              "bins": np.array([1., 3., np.Inf])
          },
          "testcase_name": "2d_single_element_3_bins"
      })
  def test_combiner_computation(self, num_bins, data, expected):
    epsilon = 0.01
    combiner = discretization.Discretization.DiscretizingCombiner(epsilon,
                                                                  num_bins)
    self.validate_accumulator_extract(combiner, data, expected)


if __name__ == "__main__":
  test.main()
@@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow V1 version of the Discretization preprocessing layer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.layers.experimental.preprocessing.Discretization'])
class Discretization(discretization.Discretization, CombinerPreprocessingLayer):
  pass
@@ -0,0 +1,14 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
tf_class {
  is_instance: "<type \'type\'>"
  member_method {
    name: "__init__"
  }
  member_method {
    name: "mro"
  }
  member_method {
    name: "register"
    argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -0,0 +1,34 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "compute"
    argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "deserialize"
    argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "extract"
    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge"
    argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "restore"
    argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "serialize"
    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -8,6 +9,10 @@ tf_class {
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "DiscretizingCombiner"
    mtype: "<type \'type\'>"
  }
  member {
    name: "activity_regularizer"
    mtype: "<type \'property\'>"
@@ -130,7 +135,7 @@ tf_class {
  }
  member_method {
    name: "__init__"
-    argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
  }
  member_method {
    name: "adapt"
@@ -0,0 +1,14 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
tf_class {
  is_instance: "<type \'type\'>"
  member_method {
    name: "__init__"
  }
  member_method {
    name: "mro"
  }
  member_method {
    name: "register"
    argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -0,0 +1,34 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "compute"
    argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "deserialize"
    argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "extract"
    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge"
    argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "restore"
    argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "serialize"
    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -8,6 +9,10 @@ tf_class {
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "DiscretizingCombiner"
    mtype: "<type \'type\'>"
  }
  member {
    name: "activity_regularizer"
    mtype: "<type \'property\'>"
@@ -130,7 +135,7 @@ tf_class {
  }
  member_method {
    name: "__init__"
-    argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
  }
  member_method {
    name: "adapt"