From 8c78dae8fb843fa5e92a771eb559db9df2c0ad74 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 26 Oct 2020 08:16:50 -0700
Subject: [PATCH] Add DiscretizingCombiner to KPL.

PiperOrigin-RevId: 339041119
Change-Id: I7e972f40e9fcce76e93ec0cb6dbf0844a4150c1c
---
 RELEASE.md | 3 +
 .../python/keras/layers/preprocessing/BUILD | 3 +
 .../layers/preprocessing/benchmarks/BUILD | 10 +
 .../discretization_adapt_benchmark.py | 120 +++++++++
 .../layers/preprocessing/discretization.py | 228 +++++++++++++++++-
 .../preprocessing/discretization_test.py | 118 ++++++++-
 .../layers/preprocessing/discretization_v1.py | 28 +++
 ...-discretizing-combiner.__metaclass__.pbtxt | 14 ++
 ...iscretization.-discretizing-combiner.pbtxt | 34 +++
 ...mental.preprocessing.-discretization.pbtxt | 7 +-
 ...-discretizing-combiner.__metaclass__.pbtxt | 14 ++
 ...iscretization.-discretizing-combiner.pbtxt | 34 +++
 ...mental.preprocessing.-discretization.pbtxt | 7 +-
 13 files changed, 607 insertions(+), 13 deletions(-)
 create mode 100644 tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
 create mode 100644 tensorflow/python/keras/layers/preprocessing/discretization_v1.py
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt

diff --git a/RELEASE.md b/RELEASE.md
index d5654424afd..131f8ef4479 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -23,6 +23,9 @@
 *
 *
 *
+* `tf.keras`:
+  * Improvements to Keras preprocessing layers:
+    * Discretization combiner implemented, with additional arg `epsilon`.
 
 ## Thanks to our Contributors
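For context on the note above, a minimal sketch of the new adapt-driven path (mirroring the docstring example added later in this patch; exact bucket indices depend on the adapted quantiles):

    import numpy as np
    import tensorflow as tf

    data = np.array([[-1.5, 1.0, 3.4, 0.5], [0.0, 3.0, 1.3, 0.0]])
    layer = tf.keras.layers.experimental.preprocessing.Discretization(
        bins=4, epsilon=0.01)
    layer.adapt(data)      # learn 4 quantile-based bucket boundaries
    buckets = layer(data)  # integer bucket indices, same shape as `data`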
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index 6b828aea24d..30579e76725 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -47,6 +47,7 @@ py_library(
     name = "discretization",
     srcs = [
         "discretization.py",
+        "discretization_v1.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -54,6 +55,7 @@ py_library(
         "//tensorflow/python:boosted_trees_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:resources",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tf_export",
@@ -458,6 +460,7 @@ tf_py_test(
     size = "small",
     srcs = ["discretization_test.py"],
     python_version = "PY3",
+    shard_count = 4,
     tags = ["no_rocm"],
     deps = [
         ":discretization",
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
index 0935d86cc5f..7a965bfe2c2 100644
--- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
@@ -103,6 +103,16 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "discretization_adapt_benchmark",
+    srcs = ["discretization_adapt_benchmark.py"],
+    python_version = "PY3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/layers/preprocessing:discretization",
+    ],
+)
+
 cuda_py_test(
     name = "image_preproc_benchmark",
     srcs = ["image_preproc_benchmark.py"],
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
new file mode 100644
index 00000000000..d0fb194791a
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
@@ -0,0 +1,120 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Keras discretization preprocessing layer's adapt method."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import flags
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+EPSILON = 0.1
+
+v2_compat.enable_v2_behavior()
+
+
+def reduce_fn(state, values, epsilon=EPSILON):
+  """tf.data.Dataset-friendly reduction that accumulates a quantile summary."""
+
+  state_, = state
+  summary = discretization.summarize(values, epsilon)
+  if np.sum(state_[:, 0]) == 0:
+    return (summary,)
+  return (discretization.merge_summaries(state_, summary, epsilon),)
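For intuition, the reduction above can be exercised directly on NumPy arrays (a toy walk-through, not part of the benchmark; it uses only the names defined in this file, and the outputs were checked by hand):

    # Feed two "batches" through reduce_fn, then compress the final
    # summary down to 4 buckets (3 interior boundaries).
    state = (np.zeros((1, 2)),)
    for batch in [np.arange(0., 50.), np.arange(50., 100.)]:
      state = reduce_fn(state, batch)
    summary, = state
    print(discretization.get_bucket_boundaries(summary, 4))
    # => array([22.5, 45. , 72.5]) -- rough quartile estimates of 0..99;
    # the epsilon=0.1 summaries are deliberately low-resolution.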
+
+
+class BenchmarkAdapt(benchmark.Benchmark):
+  """Benchmark adapt."""
+
+  def run_dataset_implementation(self, num_elements, batch_size):
+    input_t = keras.Input(shape=(1,))
+    # 100 buckets, to match the baseline extraction below.
+    layer = discretization.Discretization(bins=100)
+    _ = layer(input_t)
+
+    num_repeats = 5
+    starts = []
+    ends = []
+    for _ in range(num_repeats):
+      ds = dataset_ops.Dataset.range(num_elements)
+      ds = ds.map(
+          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+      ds = ds.batch(batch_size)
+
+      starts.append(time.time())
+      # Benchmarked code begins here.
+      # Run the reduction eagerly: reduce_fn and the summary helpers are
+      # NumPy-based, so they cannot be traced inside a tf.function.
+      state = (np.zeros((1, 2)),)
+      for batch in ds:
+        state = reduce_fn(state, batch.numpy().flatten())
+
+      bins = np.append(
+          discretization.get_bucket_boundaries(state[0], 100), [np.Inf])
+      layer.set_weights([bins])
+      # Benchmarked code ends here.
+      ends.append(time.time())
+
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    return avg_time
+
+  def bm_adapt_implementation(self, num_elements, batch_size):
+    """Test the KPL adapt implementation."""
+    input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
+    layer = discretization.Discretization(bins=100)
+    _ = layer(input_t)
+
+    num_repeats = 5
+    starts = []
+    ends = []
+    for _ in range(num_repeats):
+      ds = dataset_ops.Dataset.range(num_elements)
+      ds = ds.map(
+          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+      ds = ds.batch(batch_size)
+
+      starts.append(time.time())
+      # Benchmarked code begins here.
+      layer.adapt(ds)
+      # Benchmarked code ends here.
+      ends.append(time.time())
+
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    name = "discretization_adapt|%s_elements|batch_%s" % (num_elements,
+                                                          batch_size)
+    baseline = self.run_dataset_implementation(num_elements, batch_size)
+    extras = {
+        "tf.data implementation baseline": baseline,
+        "delta seconds": (baseline - avg_time),
+        "delta percent": ((baseline - avg_time) / baseline) * 100
+    }
+    self.report_benchmark(
+        iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
+
+  def benchmark_vocab_size_by_batch(self):
+    for vocab_size in [100, 1000, 10000, 100000, 1000000]:
+      for batch in [64 * 2048]:
+        self.bm_adapt_implementation(vocab_size, batch)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization.py b/tensorflow/python/keras/layers/preprocessing/discretization.py
index 425d0a207bb..317a940c889 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization.py
@@ -17,24 +17,124 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import json
+
 import numpy as np
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras.engine import base_preprocessing_layer
+from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_boosted_trees_ops
 from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.parallel_for import control_flow_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
+from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import keras_export
 
+_BINS_NAME = "bins"
+
+
+def summarize(values, epsilon):
+  """Reduce a 1D sequence of values to a summary.
+
+  This algorithm is based on numpy.quantile but modified to allow for
+  intermediate steps between multiple data sets. It first finds the target
+  number of bins as the reciprocal of epsilon and then takes the individual
+  values spaced at appropriate intervals to arrive at that target.
+  The final step is to return the corresponding counts between those values.
+  If the target num_bins is larger than the size of values, the whole array is
+  returned (with weights of 1).
+
+  Arguments:
+    values: 1-D `np.ndarray` to be summarized.
+    epsilon: A `'float32'` that determines the approximate desired precision.
+
+  Returns:
+    A 2-D `np.ndarray` that is a summary of the inputs. First column is the
+    interpolated partition values, the second is the weights (counts).
+  """
+
+  num_bins = 1.0 / epsilon
+  value_shape = values.shape
+  n = np.prod([[(1 if dim is None else dim) for dim in value_shape]])
+  if num_bins >= n:
+    return np.hstack((np.expand_dims(np.sort(values), 1), np.ones((n, 1))))
+  step_size = int(n / num_bins)
+  partition_indices = np.arange(step_size, n, step_size, np.int64)
+
+  part = np.partition(values, partition_indices)[partition_indices]
+
+  return np.hstack((np.expand_dims(part, 1),
+                    step_size * np.ones((np.prod(part.shape), 1))))
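As a concrete illustration of `summarize` (toy values, not part of the patch): with ten values and `epsilon=0.25` the target is four bins, so every `step_size=2`-nd order statistic is kept, each carrying a weight of 2:

    vals = np.array([7., 2., 9., 0., 4., 1., 8., 3., 6., 5.])
    summarize(vals, 0.25)
    # => array([[2., 2.],
    #           [4., 2.],
    #           [6., 2.],
    #           [8., 2.]])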
+
+
+def compress(summary, epsilon):
+  """Compress a summary to within `epsilon` accuracy.
+
+  The compression step is needed to keep the summary sizes small after merging,
+  and is also used to return the final target boundaries. It finds the new bins
+  based on interpolating cumulative weight percentages from the large summary.
+  Taking the difference of the cumulative weights from the previous bin's
+  cumulative weight gives the new weight for that bin.
+
+  Arguments:
+    summary: 2-D `np.ndarray` summary to be compressed.
+    epsilon: A `'float32'` that determines the approximate desired precision.
+
+  Returns:
+    A 2-D `np.ndarray` that is a compressed summary. First column is the
+    interpolated partition values, the second is the weights (counts).
+  """
+  if np.prod(summary[:, 0].shape) * epsilon < 1:
+    return summary
+
+  percents = epsilon + np.arange(0.0, 1.0, epsilon)
+  cum_weights = summary[:, 1].cumsum()
+  cum_weight_percents = cum_weights / cum_weights[-1]
+  new_bins = np.interp(percents, cum_weight_percents, summary[:, 0])
+  cum_weights = np.interp(percents, cum_weight_percents, cum_weights)
+  new_weights = cum_weights - np.concatenate((np.array([0]), cum_weights[:-1]))
+
+  return np.hstack((np.expand_dims(new_bins, 1),
+                    np.expand_dims(new_weights, 1)))
+
+
+def merge_summaries(prev_summary, next_summary, epsilon):
+  """Weighted merge sort of summaries.
+
+  Given two summaries of distinct data, this function merges (and compresses)
+  them to stay within `epsilon` error tolerance.
+
+  Arguments:
+    prev_summary: 2-D `np.ndarray` summary to be merged with `next_summary`.
+    next_summary: 2-D `np.ndarray` summary to be merged with `prev_summary`.
+    epsilon: A `'float32'` that determines the approximate desired precision.
+
+  Returns:
+    A 2-D `np.ndarray` that is a merged summary. First column is the
+    interpolated partition values, the second is the weights (counts).
+  """
+  merged = np.concatenate((prev_summary, next_summary))
+  merged = merged[merged[:, 0].argsort()]
+  if np.prod(merged.shape) * epsilon < 1:
+    return merged
+  return compress(merged, epsilon)
+
+
+def get_bucket_boundaries(summary, num_bins):
+  """Compress a summary to `num_bins` bins and return the bin boundaries."""
+  return compress(summary, 1.0 / num_bins)[:-1, 0]
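A small end-to-end check of the three helpers above (illustrative values; outputs verified by hand):

    a = summarize(np.arange(0., 10.), 0.1)   # batch 1: each value, weight 1
    b = summarize(np.arange(10., 20.), 0.1)  # batch 2, disjoint range
    merged = merge_summaries(a, b, 0.1)      # compressed to ~10 bins
    get_bucket_boundaries(merged, 2)
    # => array([9.]) -- one interior boundary near the pooled median (9.5)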
+
+
 @keras_export("keras.layers.experimental.preprocessing.Discretization")
-class Discretization(base_preprocessing_layer.PreprocessingLayer):
+class Discretization(base_preprocessing_layer.CombinerPreprocessingLayer):
   """Buckets data into discrete ranges.
 
   This layer will place each element of its input data into one of several
@@ -48,9 +148,15 @@ class Discretization(base_preprocessing_layer.PreprocessingLayer):
     Same as input shape.
 
   Attributes:
-    bins: Optional boundary specification. Bins exclude the left boundary and
-      include the right boundary, so `bins=[0., 1., 2.]` generates bins
+    bins: Optional boundary specification, or the number of bins to compute
+      if an `int` is passed. Bins exclude the left boundary and include the
+      right boundary, so `bins=[0., 1., 2.]` generates bins
       `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`.
+      (A three-boundary list therefore corresponds to `bins=4`.)
+    epsilon: Error tolerance, typically a small fraction close to zero (e.g.
+      0.01). Higher values of epsilon increase the quantile approximation
+      error, and hence result in more unequal buckets, but can improve
+      performance and reduce resource consumption.
 
   Examples:
 
@@ -62,15 +168,38 @@
+
+  Bucketize float values based on a number of buckets to compute.
+  >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
+  >>> layer = tf.keras.layers.experimental.preprocessing.Discretization(
+  ...     bins=4, epsilon=0.01)
+  >>> layer.adapt(input)
+  >>> layer(input)
+
   """
 
-  def __init__(self, bins, **kwargs):
-    super(Discretization, self).__init__(**kwargs)
-    base_preprocessing_layer._kpl_gauge.get_cell("V2").set("Discretization")
-    # The bucketization op requires a final rightmost boundary in order to
-    # correctly assign values higher than the largest left boundary.
-    # This should not impact intended buckets even if a max value is provided.
-    self.bins = np.append(bins, [np.Inf])
+  def __init__(self,
+               bins,
+               epsilon=0.01,
+               **kwargs):
+    super(Discretization, self).__init__(
+        combiner=Discretization.DiscretizingCombiner(
+            epsilon, bins if isinstance(bins, int) else 1),
+        **kwargs)
+    if bins is not None and not isinstance(bins, int):
+      self.bins = np.append(bins, [np.Inf])
+    else:
+      self.bins = np.zeros(bins)
+
+  def build(self, input_shape):
+    self.bins = self._add_state_variable(
+        name=_BINS_NAME,
+        shape=(self.bins.size,),
+        dtype=dtypes.float32,
+        initializer=init_ops.constant_initializer(self.bins))
+    super(Discretization, self).build(input_shape)
 
   def get_config(self):
     config = {
@@ -128,3 +257,82 @@
         control_flow_ops.vectorized_map(
             _bucketize_op(array_ops.squeeze(self.bins)), reshaped),
         array_ops.constant([-1] + input_shape.as_list()[1:]))
+
+  class DiscretizingCombiner(Combiner):
+    """Combiner for the Discretization preprocessing layer.
+
+    This class encapsulates the computations for finding the quantile
+    boundaries of a set of data in a stable and numerically correct way.
+    Its associated accumulator is a namedtuple with a single field,
+    `summaries`, holding summarizations of the data used to generate the
+    boundaries.
+
+    Attributes:
+      epsilon: Error tolerance.
+      num_bins: The desired number of buckets.
+    """
+
+    def __init__(self, epsilon, num_bins):
+      self.epsilon = epsilon
+      self.num_bins = num_bins
+
+      # TODO(mwunder): Implement elementwise per-column discretization.
+
+    def compute(self, values, accumulator=None):
+      """Compute a step in this computation, returning a new accumulator."""
+
+      if isinstance(values, sparse_tensor.SparseTensor):
+        values = values.values
+      if tf_utils.is_ragged(values):
+        values = values.flat_values
+      flattened_input = np.reshape(values, newshape=(-1, 1))
+
+      summaries = [summarize(v, self.epsilon) for v in flattened_input.T]
+
+      if accumulator is None:
+        return self._create_accumulator(summaries)
+      else:
+        return self._create_accumulator(
+            [merge_summaries(prev_summ, summ, self.epsilon)
+             for prev_summ, summ in zip(accumulator.summaries, summaries)])
+
+    def merge(self, accumulators):
+      """Merge several accumulators to a single accumulator."""
+      # Combine the summaries of each accumulator, one column at a time.
+
+      merged = accumulators[0].summaries
+      for accumulator in accumulators[1:]:
+        merged = [merge_summaries(prev, summary, self.epsilon)
+                  for prev, summary in zip(merged, accumulator.summaries)]
+
+      return self._create_accumulator(merged)
+
+    def extract(self, accumulator):
+      """Convert an accumulator into a dict of output values."""
+
+      boundaries = [np.append(get_bucket_boundaries(summary, self.num_bins),
+                              [np.Inf])
+                    for summary in accumulator.summaries]
+      return {
+          _BINS_NAME: np.squeeze(np.vstack(boundaries))
+      }
+
+    def restore(self, output):
+      """Create an accumulator based on 'output'."""
+      raise NotImplementedError(
+          "Discretization does not restore or support streaming updates.")
+
+    def serialize(self, accumulator):
+      """Serialize an accumulator for a remote call."""
+      output_dict = {
+          _BINS_NAME: [summary.tolist() for summary in accumulator.summaries]
+      }
+      return compat.as_bytes(json.dumps(output_dict))
+
+    def deserialize(self, encoded_accumulator):
+      """Deserialize an accumulator received from 'serialize()'."""
+      value_dict = json.loads(compat.as_text(encoded_accumulator))
+      return self._create_accumulator(np.array(value_dict[_BINS_NAME]))
+
+    def _create_accumulator(self, summaries):
+      """Represent the accumulator as one or more summaries of the dataset."""
+      return collections.namedtuple("Accumulator", ["summaries"])(summaries)
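The combiner above follows the usual compute/merge/extract contract of `Combiner`; a toy walk-through of the full lifecycle (illustrative, mirroring the unit tests that follow):

    combiner = Discretization.DiscretizingCombiner(epsilon=0.01, num_bins=5)
    acc1 = combiner.compute(np.array([[1.], [2.], [3.]]))  # worker 1
    acc2 = combiner.compute(np.array([[4.], [5.]]))        # worker 2
    merged = combiner.merge([acc1, acc2])
    combiner.extract(merged)
    # => {'bins': array([1., 2., 3., 4., inf])}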
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_test.py
index 9d04ccc26a5..0226355cc76 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_test.py
@@ -18,19 +18,32 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.keras.layers.preprocessing import discretization_v1
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 
 
+def get_layer_class():
+  if context.executing_eagerly():
+    return discretization.Discretization
+  else:
+    return discretization_v1.Discretization
+
+
 @keras_parameterized.run_all_keras_modes
 class DiscretizationTest(keras_parameterized.TestCase,
                          preprocessing_test_utils.PreprocessingLayerTest):
@@ -106,7 +119,6 @@ class DiscretizationTest(keras_parameterized.TestCase,
     layer = discretization.Discretization(bins=[-.5, 0.5, 1.5])
     bucket_data = layer(input_data)
     self.assertAllEqual(expected_output_shape, bucket_data.shape.as_list())
-
     model = keras.Model(inputs=input_data, outputs=bucket_data)
     output_dataset = model.predict(input_array)
     self.assertAllEqual(expected_output, output_dataset)
@@ -125,6 +137,110 @@ class DiscretizationTest(keras_parameterized.TestCase,
     self.assertAllEqual(indices, output_dataset.indices)
     self.assertAllEqual(expected_output, output_dataset.values)
 
+  @parameterized.named_parameters([
+      {
+          "testcase_name": "2d_single_element",
+          "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]]),
+          "test_data": np.array([[1.], [2.], [3.]]),
+          "use_dataset": True,
+          "expected": np.array([[0], [1], [2]]),
+          "num_bins": 5,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "2d_multi_element",
+          "adapt_data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.],
+                                  [5., 10.]]),
+          "test_data": np.array([[1., 10.], [2., 6.], [3., 8.]]),
+          "use_dataset": True,
+          "expected": np.array([[0, 4], [0, 2], [1, 3]]),
+          "num_bins": 5,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "1d_single_element",
+          "adapt_data": np.array([3., 2., 1., 5., 4.]),
+          "test_data": np.array([1., 2., 3.]),
+          "use_dataset": True,
+          "expected": np.array([0, 1, 2]),
+          "num_bins": 5,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "300_batch_1d_single_element_1",
+          "adapt_data": np.arange(300),
+          "test_data": np.arange(300),
+          "use_dataset": True,
+          "expected":
+              np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
+          "num_bins": 3,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "300_batch_1d_single_element_2",
+          "adapt_data": np.arange(300) ** 2,
+          "test_data": np.arange(300) ** 2,
+          "use_dataset": True,
+          "expected":
+              np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
+          "num_bins": 3,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "300_batch_1d_single_element_large_epsilon",
+          "adapt_data": np.arange(300),
+          "test_data": np.arange(300),
+          "use_dataset": True,
+          "expected": np.concatenate([np.zeros(137), np.ones(163)]),
+          "num_bins": 2,
+          "epsilon": 0.1
+      }])
+  def test_layer_computation(self, adapt_data, test_data, use_dataset,
+                             expected, num_bins=5, epsilon=0.01):
+
+    input_shape = tuple(list(test_data.shape)[1:])
+    np.random.shuffle(adapt_data)
+    if use_dataset:
+      # Keras APIs expect batched datasets
+      adapt_data = dataset_ops.Dataset.from_tensor_slices(adapt_data).batch(
+          test_data.shape[0] // 2)
+      test_data = dataset_ops.Dataset.from_tensor_slices(test_data).batch(
+          test_data.shape[0] // 2)
+
+    cls = get_layer_class()
+    layer = cls(epsilon=epsilon, bins=num_bins)
+    layer.adapt(adapt_data)
+
+    input_data = keras.Input(shape=input_shape)
+    output = layer(input_data)
+    model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    output_data = model.predict(test_data)
+    self.assertAllClose(expected, output_data)
+
+  @parameterized.named_parameters(
+      {
+          "num_bins": 5,
+          "data": np.array([[1.], [2.], [3.], [4.], [5.]]),
+          "expected": {
+              "bins": np.array([1., 2., 3., 4., np.Inf])
+          },
+          "testcase_name": "2d_single_element_all_bins"
+      }, {
+          "num_bins": 5,
+          "data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.],
+                            [5., 10.]]),
+          "expected": {
+              "bins": np.array([2., 4., 6., 8., np.Inf])
+          },
+          "testcase_name": "2d_multi_element_all_bins",
+      }, {
+          "num_bins": 3,
+          "data": np.array([[0.], [1.], [2.], [3.], [4.], [5.]]),
+          "expected": {
+              "bins": np.array([1., 3., np.Inf])
+          },
+          "testcase_name": "2d_single_element_3_bins"
+      })
+  def test_combiner_computation(self, num_bins, data, expected):
+    epsilon = 0.01
+    combiner = discretization.Discretization.DiscretizingCombiner(epsilon,
+                                                                  num_bins)
+    self.validate_accumulator_extract(combiner, data, expected)
 
 if __name__ == "__main__":
   test.main()
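The `300_batch` expectations above follow directly from the quantile math: 300 uniformly spaced values with `num_bins=3` put the learned boundaries near the 1/3 and 2/3 quantiles, e.g. (an illustrative check, not additional test code):

    layer = discretization.Discretization(bins=3, epsilon=0.01)
    layer.adapt(np.arange(300, dtype=np.float32))
    # Values below ~100 map to bucket 0, ~[100, 200) to bucket 1, and the
    # rest to bucket 2; the expected arrays allow the boundaries to be off
    # by a value or two, since epsilon permits approximation error.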
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_v1.py b/tensorflow/python/keras/layers/preprocessing/discretization_v1.py
new file mode 100644
index 00000000000..6daea9b21e6
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_v1.py
@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow V1 version of the Discretization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
+from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export(v1=['keras.layers.experimental.preprocessing.Discretization'])
+class Discretization(discretization.Discretization, CombinerPreprocessingLayer):
+  pass
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
new file mode 100644
index 00000000000..088d4507b12
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
+tf_class {
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "mro"
+  }
+  member_method {
+    name: "register"
+    argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
new file mode 100644
index 00000000000..2f75c151d66
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute"
+    argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "extract"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "merge"
+    argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "restore"
+    argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
index 628f76c84a3..87c0e792cfd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
 tf_class {
   is_instance: ""
+  is_instance: ""
   is_instance: ""
   is_instance: ""
   is_instance: ""
@@ -8,6 +9,10 @@ tf_class {
   is_instance: ""
   is_instance: ""
   is_instance: ""
+  member {
+    name: "DiscretizingCombiner"
+    mtype: ""
+  }
   member {
     name: "activity_regularizer"
     mtype: ""
   }
@@ -130,7 +135,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
   }
   member_method {
     name: "adapt"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
new file mode 100644
index 00000000000..088d4507b12
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
+tf_class {
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "mro"
+  }
+  member_method {
+    name: "register"
+    argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
new file mode 100644
index 00000000000..2f75c151d66
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute"
+    argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "extract"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "merge"
+    argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "restore"
+    argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
index 628f76c84a3..87c0e792cfd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
 tf_class {
   is_instance: ""
+  is_instance: ""
   is_instance: ""
   is_instance: ""
   is_instance: ""
@@ -8,6 +9,10 @@ tf_class {
   is_instance: ""
   is_instance: ""
   is_instance: ""
+  member {
+    name: "DiscretizingCombiner"
+    mtype: ""
+  }
   member {
     name: "activity_regularizer"
     mtype: ""
   }
@@ -130,7 +135,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
  }
  member_method {
    name: "adapt"