Add PreprocessingStage (sequential) class.
PiperOrigin-RevId: 295230713
Change-Id: Id2f50b7f3cb0c11e6a59e311d768a9edf9cdc057
parent 79ed5077ce
commit 6fc343905a
tensorflow/python/keras/engine/base_preprocessing_layer.py
@@ -22,7 +22,6 @@ import collections
 
 import numpy as np
 
-from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -32,7 +31,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.ops import math_ops
+from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util.tf_export import keras_export
@@ -123,11 +122,6 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
       data_dict[name] = var.numpy()
     return data_dict
 
-  def _dataset_is_infinite(self, dataset):
-    """True if the passed dataset is infinite."""
-    return math_ops.equal(
-        cardinality.cardinality(dataset), cardinality.INFINITE)
-
   def _get_dataset_iterator(self, dataset):
     """Gets an iterator from a tf.data.Dataset."""
     return dataset_ops.make_one_shot_iterator(dataset).get_next
@@ -148,18 +142,20 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
     else:
       accumulator = self._combiner.restore(self._restore_updates())
 
-    if not isinstance(data, (dataset_ops.DatasetV2, np.ndarray)):
+    if not isinstance(data,
+                      (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
       raise ValueError(
-          'adapt() requires a Dataset or a Numpy array as input, got {}'.format(
-              type(data)))
+          '`adapt()` requires a batched Dataset, an EagerTensor, '
+          'or a Numpy array as input, '
+          'got {}'.format(type(data)))
 
     if isinstance(data, dataset_ops.DatasetV2):
       # Validate the datasets to try and ensure we haven't been passed one with
       # infinite size. That would cause an infinite loop here.
-      if self._dataset_is_infinite(data):
+      if tf_utils.dataset_is_infinite(data):
         raise ValueError(
-            'The dataset passed to "adapt()" has an infinite number of '
-            'elements. Please use dataset.take(...) to make the number '
+            'The dataset passed to `adapt()` has an infinite number of '
+            'elements. Please use `dataset.take(...)` to make the number '
             'of elements finite.')
       next_data = self._get_dataset_iterator(data)
     else:
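For illustration (not part of this commit), a minimal sketch of what the broadened check accepts, using `Normalization` as a stand-in for any `CombinerPreprocessingLayer` subclass; assumes a TensorFlow build that includes this change:

    import numpy as np
    import tensorflow as tf

    # Normalization is a CombinerPreprocessingLayer, so its adapt() runs the
    # validation shown in the hunk above.
    layer = tf.keras.layers.experimental.preprocessing.Normalization()

    layer.adapt(np.ones((4, 2)))                    # Numpy array: accepted.
    layer.adapt(tf.constant([[1., 2.], [3., 4.]]))  # EagerTensor: newly accepted.

    ds = tf.data.Dataset.from_tensor_slices(np.ones((4, 2))).batch(2)
    layer.adapt(ds)                                 # Batched, finite Dataset: accepted.

    try:
      layer.adapt([1., 2., 3.])                     # Plain list: still rejected.
    except ValueError as e:
      print(e)  # `adapt()` requires a batched Dataset, an EagerTensor, ...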
tensorflow/python/keras/engine/base_preprocessing_layer_test.py
@@ -135,7 +135,7 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
     input_dataset = [1, 2, 3, 4, 5]
 
     layer = get_layer()
-    with self.assertRaisesRegex(ValueError, ".*a Dataset or a Numpy.*"):
+    with self.assertRaisesRegex(ValueError, "requires a"):
       layer.adapt(input_dataset)
 
   def test_adapt_infinite_dataset_fails(self):
tensorflow/python/keras/engine/base_preprocessing_layer_v1.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import base_preprocessing_layer
@@ -54,11 +53,6 @@ class CombinerPreprocessingLayer(
       data_dict[name] = K.get_session().run(var)
     return data_dict
 
-  def _dataset_is_infinite(self, dataset):
-    """True if the passed dataset is infinite."""
-    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
-    return dataset_size == cardinality.INFINITE
-
   def _get_dataset_iterator(self, dataset):
     """Gets an iterator from a tf.data.Dataset."""
     iterator = dataset_ops.make_one_shot_iterator(dataset)
tensorflow/python/keras/layers/preprocessing/BUILD
@@ -23,6 +23,7 @@ py_library(
         ":hashing",
         ":image_preprocessing",
         ":normalization",
+        ":preprocessing_stage",
        ":preprocessing_test_utils",
        ":reduction",
        ":text_vectorization",
@@ -204,6 +205,20 @@ py_library(
     ],
 )
 
+py_library(
+    name = "preprocessing_stage",
+    srcs = [
+        "preprocessing_stage.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/keras:engine",
+        "//tensorflow/python/keras/engine:base_preprocessing_layer",
+    ],
+)
+
 py_library(
     name = "preprocessing_test_utils",
     srcs = ["preprocessing_test_utils.py"],
@@ -344,3 +359,16 @@ tf_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+tf_py_test(
+    name = "preprocessing_stage_test",
+    srcs = ["preprocessing_stage_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":preprocessing_stage",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/keras",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
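For reference (not part of this commit), the new test target can be exercised from a TensorFlow source checkout with Bazel; exact flags depend on the local build configuration:

    bazel test //tensorflow/python/keras/layers/preprocessing:preprocessing_stage_test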
tensorflow/python/keras/layers/preprocessing/preprocessing_stage.py (new file)
@@ -0,0 +1,96 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocessing stage."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import base_preprocessing_layer
+from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.utils import tf_utils
+
+
+class PreprocessingStage(base_preprocessing_layer.PreprocessingLayer,
+                         sequential.Sequential):
+  """A sequential preprocessing stage.
+
+  This preprocessing stage wraps a list of preprocessing layers into a
+  Sequential-like object that enables you to `adapt()` the whole list via
+  a single `adapt()` call on the preprocessing stage.
+
+  Arguments:
+    layers: List of layers. Can include layers that aren't preprocessing layers.
+    name: String. Optional name for the preprocessing stage object.
+  """
+
+  def adapt(self, data, reset_state=True):
+    """Adapt the state of the layers of the preprocessing stage to the data.
+
+    Arguments:
+      data: A batched Dataset object, or a NumPy array, or an EagerTensor.
+        Data to be iterated over to adapt the state of the layers in this
+        preprocessing stage.
+      reset_state: Whether this call to `adapt` should reset the state of
+        the layers in this preprocessing stage.
+    """
+    if not isinstance(data,
+                      (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
+      raise ValueError(
+          '`adapt()` requires a batched Dataset, an EagerTensor, '
+          'or a Numpy array as input, '
+          'got {}'.format(type(data)))
+    if isinstance(data, dataset_ops.DatasetV2):
+      # Validate the datasets to try and ensure we haven't been passed one with
+      # infinite size. That would cause an infinite loop here.
+      if tf_utils.dataset_is_infinite(data):
+        raise ValueError(
+            'The dataset passed to `adapt()` has an infinite number of '
+            'elements. Please use dataset.take(...) to make the number '
+            'of elements finite.')
+
+    for current_layer_index in range(0, len(self.layers)):
+      if not hasattr(self.layers[current_layer_index], 'adapt'):
+        # Skip any layer that does not need adapting.
+        continue
+
+      def map_fn(x):
+        """Maps `PreprocessingStage` inputs to inputs at `current_layer_index`.
+
+        Args:
+          x: Batch of inputs seen at the entry of the `PreprocessingStage`
+            instance.
+
+        Returns:
+          Batch of inputs to be processed by layer
+          `self.layers[current_layer_index]`.
+        """
+        if current_layer_index == 0:  # pylint: disable=cell-var-from-loop
+          return x
+        for i in range(current_layer_index):  # pylint: disable=cell-var-from-loop
+          x = self.layers[i](x)
+        return x
+
+      if isinstance(data, dataset_ops.DatasetV2):
+        current_layer_data = data.map(map_fn)
+      else:
+        current_layer_data = map_fn(data)
+      self.layers[current_layer_index].adapt(current_layer_data,
+                                             reset_state=reset_state)
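A usage sketch (not part of this commit; assumes a checkout containing this change). Note the chaining implemented by `map_fn` above: each layer is adapted on the stage's input after it has passed through all preceding layers, so the second `Normalization` below is adapted on already-normalized data. Since those intermediate `map_fn` outputs are tensors rather than Numpy arrays, this is presumably also why `adapt()` now accepts EagerTensors:

    import numpy as np

    from tensorflow.python.keras.layers.preprocessing import normalization
    from tensorflow.python.keras.layers.preprocessing import preprocessing_stage

    stage = preprocessing_stage.PreprocessingStage([
        normalization.Normalization(),  # adapted on the raw input
        normalization.Normalization(),  # adapted on the first layer's output
    ])
    stage.adapt(np.random.random((8, 3)).astype('float32'))
    out = stage(np.ones((2, 3), dtype='float32'))  # runs both layers in sequence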
tensorflow/python/keras/layers/preprocessing/preprocessing_stage_test.py (new file)
@@ -0,0 +1,105 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preprocessing stage tests."""
+# pylint: disable=g-classes-have-attributes
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras.engine import base_preprocessing_layer
+from tensorflow.python.keras.layers import convolutional
+from tensorflow.python.keras.layers.preprocessing import image_preprocessing
+from tensorflow.python.keras.layers.preprocessing import normalization
+from tensorflow.python.keras.layers.preprocessing import preprocessing_stage
+from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+class PreprocessingStageTest(
+    keras_parameterized.TestCase,
+    preprocessing_test_utils.PreprocessingLayerTest):
+
+  def test_adapt(self):
+
+    class PL(base_preprocessing_layer.PreprocessingLayer):
+
+      def __init__(self, **kwargs):
+        self.adapt_time = None
+        self.adapt_count = 0
+        super(PL, self).__init__(**kwargs)
+
+      def adapt(self, data, reset_state=True):
+        self.adapt_time = time.time()
+        self.adapt_count += 1
+
+      def call(self, inputs):
+        return inputs + 1.
+
+    # Test with NumPy array
+    stage = preprocessing_stage.PreprocessingStage([
+        PL(),
+        PL(),
+        PL(),
+    ])
+    stage.adapt(np.ones((3, 4)))
+    self.assertEqual(stage.layers[0].adapt_count, 1)
+    self.assertEqual(stage.layers[1].adapt_count, 1)
+    self.assertEqual(stage.layers[2].adapt_count, 1)
+    self.assertLess(stage.layers[0].adapt_time, stage.layers[1].adapt_time)
+    self.assertLess(stage.layers[1].adapt_time, stage.layers[2].adapt_time)
+
+    # Check call
+    y = stage(array_ops.ones((3, 4)))
+    self.assertAllClose(y, np.ones((3, 4)) + 3.)
+
+    # Test with dataset
+    adapt_data = dataset_ops.Dataset.from_tensor_slices(np.ones((3, 10)))
+    adapt_data = adapt_data.batch(2)  # 2 batches (of 2 and 1 samples)
+
+    stage.adapt(adapt_data)
+    self.assertEqual(stage.layers[0].adapt_count, 2)
+    self.assertEqual(stage.layers[1].adapt_count, 2)
+    self.assertEqual(stage.layers[2].adapt_count, 2)
+    self.assertLess(stage.layers[0].adapt_time, stage.layers[1].adapt_time)
+    self.assertLess(stage.layers[1].adapt_time, stage.layers[2].adapt_time)
+
+    # Test error with bad data
+    with self.assertRaisesRegex(ValueError, 'requires a '):
+      stage.adapt(None)
+
+  def test_mixing_preprocessing_and_regular_layers(self):
+    stage = preprocessing_stage.PreprocessingStage([
+        image_preprocessing.CenterCrop(16, 16),
+        normalization.Normalization(),
+        convolutional.Conv2D(4, 3)
+    ])
+    data = np.ones((16, 20, 20, 3), dtype='float32')
+    stage.adapt(data)
+    _ = stage(data)
+    stage.compile('rmsprop', 'mse')
+    stage.fit(data, np.ones((16, 14, 14, 4)))
+    _ = stage.evaluate(data, np.ones((16, 14, 14, 4)))
+    _ = stage.predict(data)
+
+
+if __name__ == '__main__':
+  test.main()
tensorflow/python/keras/utils/tf_utils.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import six
 
+from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.eager import context
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import ops
@@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.keras import backend as K
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.util import nest
 from tensorflow.python.util import object_identity
@@ -452,3 +454,13 @@ def graph_context_for_symbolic_tensors(*args, **kwargs):
     yield
   else:
     yield
+
+
+def dataset_is_infinite(dataset):
+  """True if the passed dataset is infinite."""
+  if ops.executing_eagerly_outside_functions():
+    return math_ops.equal(
+        cardinality.cardinality(dataset), cardinality.INFINITE)
+  else:
+    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
+    return dataset_size == cardinality.INFINITE
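A short illustration of the consolidated helper (not part of this commit). In eager mode it returns a boolean tensor; in graph mode it evaluates the cardinality op in the backend session and returns a Python bool:

    import numpy as np

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.keras.utils import tf_utils

    finite = dataset_ops.Dataset.from_tensor_slices(np.ones((4, 2))).batch(2)
    infinite = finite.repeat()  # repeat() without a count never terminates

    print(tf_utils.dataset_is_infinite(finite))    # tf.Tensor(False, ...)
    print(tf_utils.dataset_is_infinite(infinite))  # tf.Tensor(True, ...)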