Add PreprocessingStage (sequential) class.
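Minimal usage sketch (illustrative only, not part of this change; assumes eager execution and uses the Normalization layer exercised in the tests below):

    import numpy as np
    from tensorflow.python.keras.layers.preprocessing import normalization
    from tensorflow.python.keras.layers.preprocessing import preprocessing_stage

    # Chain preprocessing layers; a single adapt() call fits all of them in order.
    stage = preprocessing_stage.PreprocessingStage([
        normalization.Normalization(),
    ])
    stage.adapt(np.random.random((16, 4)))      # NumPy array, EagerTensor, or batched Dataset
    outputs = stage(np.random.random((16, 4)))  # apply the adapted stage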

PiperOrigin-RevId: 295230713
Change-Id: Id2f50b7f3cb0c11e6a59e311d768a9edf9cdc057
Francois Chollet 2020-02-14 14:42:43 -08:00 committed by TensorFlower Gardener
parent 79ed5077ce
commit 6fc343905a
7 changed files with 251 additions and 20 deletions

tensorflow/python/keras/engine/base_preprocessing_layer.py

@@ -22,7 +22,6 @@ import collections
 import numpy as np
 
-from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -32,7 +31,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.ops import math_ops
+from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util.tf_export import keras_export
@@ -123,11 +122,6 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
       data_dict[name] = var.numpy()
     return data_dict
 
-  def _dataset_is_infinite(self, dataset):
-    """True if the passed dataset is infinite."""
-    return math_ops.equal(
-        cardinality.cardinality(dataset), cardinality.INFINITE)
-
   def _get_dataset_iterator(self, dataset):
     """Gets an iterator from a tf.data.Dataset."""
     return dataset_ops.make_one_shot_iterator(dataset).get_next
@@ -148,18 +142,20 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
     else:
       accumulator = self._combiner.restore(self._restore_updates())
-    if not isinstance(data, (dataset_ops.DatasetV2, np.ndarray)):
+    if not isinstance(data,
+                      (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
       raise ValueError(
-          'adapt() requires a Dataset or a Numpy array as input, got {}'.format(
-              type(data)))
+          '`adapt()` requires a batched Dataset, an EagerTensor, '
+          'or a Numpy array as input, '
+          'got {}'.format(type(data)))
 
     if isinstance(data, dataset_ops.DatasetV2):
       # Validate the datasets to try and ensure we haven't been passed one with
       # infinite size. That would cause an infinite loop here.
-      if self._dataset_is_infinite(data):
+      if tf_utils.dataset_is_infinite(data):
         raise ValueError(
-            'The dataset passed to "adapt()" has an infinite number of '
-            'elements. Please use dataset.take(...) to make the number '
+            'The dataset passed to `adapt()` has an infinite number of '
+            'elements. Please use `dataset.take(...)` to make the number '
             'of elements finite.')
       next_data = self._get_dataset_iterator(data)
     else:
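For reference, a sketch of what `adapt()` accepts after this change (`layer` here is hypothetical, standing in for any CombinerPreprocessingLayer instance; the Dataset must be batched and finite):

    import numpy as np
    from tensorflow.python.data.ops import dataset_ops

    layer.adapt(np.ones((8, 4)))  # NumPy array: accepted
    ds = dataset_ops.Dataset.from_tensor_slices(np.ones((8, 4))).batch(2)
    layer.adapt(ds)               # batched, finite Dataset: accepted
    # layer.adapt(ds.repeat())    # infinite Dataset: raises ValueError
    # layer.adapt([1, 2, 3])      # plain list: raises ValueError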

tensorflow/python/keras/engine/base_preprocessing_layer_test.py

@@ -135,7 +135,7 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
     input_dataset = [1, 2, 3, 4, 5]
     layer = get_layer()
-    with self.assertRaisesRegex(ValueError, ".*a Dataset or a Numpy.*"):
+    with self.assertRaisesRegex(ValueError, "requires a"):
       layer.adapt(input_dataset)
 
   def test_adapt_infinite_dataset_fails(self):

tensorflow/python/keras/engine/base_preprocessing_layer_v1.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import base_preprocessing_layer
@@ -54,11 +53,6 @@ class CombinerPreprocessingLayer(
       data_dict[name] = K.get_session().run(var)
     return data_dict
 
-  def _dataset_is_infinite(self, dataset):
-    """True if the passed dataset is infinite."""
-    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
-    return dataset_size == cardinality.INFINITE
-
   def _get_dataset_iterator(self, dataset):
     """Gets an iterator from a tf.data.Dataset."""
     iterator = dataset_ops.make_one_shot_iterator(dataset)

tensorflow/python/keras/layers/preprocessing/BUILD

@@ -23,6 +23,7 @@ py_library(
         ":hashing",
         ":image_preprocessing",
         ":normalization",
+        ":preprocessing_stage",
        ":preprocessing_test_utils",
        ":reduction",
        ":text_vectorization",
@@ -204,6 +205,20 @@ py_library(
     ],
 )
 
+py_library(
+    name = "preprocessing_stage",
+    srcs = [
+        "preprocessing_stage.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/keras:engine",
+        "//tensorflow/python/keras/engine:base_preprocessing_layer",
+    ],
+)
+
 py_library(
     name = "preprocessing_test_utils",
     srcs = ["preprocessing_test_utils.py"],
@@ -344,3 +359,16 @@ tf_py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+tf_py_test(
+    name = "preprocessing_stage_test",
+    srcs = ["preprocessing_stage_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":preprocessing_stage",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/keras",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
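With the target above in place, the new test should be runnable with a standard Bazel invocation (sketch; exact flags vary by environment):

    bazel test //tensorflow/python/keras/layers/preprocessing:preprocessing_stage_test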

tensorflow/python/keras/layers/preprocessing/preprocessing_stage.py

@@ -0,0 +1,96 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing stage."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.utils import tf_utils


class PreprocessingStage(base_preprocessing_layer.PreprocessingLayer,
                         sequential.Sequential):
  """A sequential preprocessing stage.

  This preprocessing stage wraps a list of preprocessing layers into a
  Sequential-like object that enables you to `adapt()` the whole list via
  a single `adapt()` call on the preprocessing stage.

  Arguments:
    layers: List of layers. Can include layers that aren't preprocessing
      layers.
    name: String. Optional name for the preprocessing stage object.
  """

  def adapt(self, data, reset_state=True):
    """Adapt the state of the layers of the preprocessing stage to the data.

    Arguments:
      data: A batched Dataset object, or a NumPy array, or an EagerTensor.
        Data to be iterated over to adapt the state of the layers in this
        preprocessing stage.
      reset_state: Whether this call to `adapt` should reset the state of
        the layers in this preprocessing stage.
    """
    if not isinstance(data,
                      (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
      raise ValueError(
          '`adapt()` requires a batched Dataset, an EagerTensor, '
          'or a Numpy array as input, '
          'got {}'.format(type(data)))
    if isinstance(data, dataset_ops.DatasetV2):
      # Validate the datasets to try and ensure we haven't been passed one with
      # infinite size. That would cause an infinite loop here.
      if tf_utils.dataset_is_infinite(data):
        raise ValueError(
            'The dataset passed to `adapt()` has an infinite number of '
            'elements. Please use dataset.take(...) to make the number '
            'of elements finite.')

    for current_layer_index in range(0, len(self.layers)):
      if not hasattr(self.layers[current_layer_index], 'adapt'):
        # Skip any layer that does not need adapting.
        continue

      def map_fn(x):
        """Maps `PreprocessingStage` inputs to inputs at `current_layer_index`.

        Args:
          x: Batch of inputs seen in entry of the `PreprocessingStage` instance.

        Returns:
          Batch of inputs to be processed by layer
          `self.layers[current_layer_index]`
        """
        if current_layer_index == 0:  # pylint: disable=cell-var-from-loop
          return x
        for i in range(current_layer_index):  # pylint: disable=cell-var-from-loop
          x = self.layers[i](x)
        return x

      if isinstance(data, dataset_ops.DatasetV2):
        current_layer_data = data.map(map_fn)
      else:
        current_layer_data = map_fn(data)
      self.layers[current_layer_index].adapt(current_layer_data,
                                             reset_state=reset_state)
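Conceptually, the loop above lets each adaptable layer see the data as transformed by every layer before it (each of which has already been adapted by that point). A rough equivalent for the array case (sketch only; the Dataset path instead uses `data.map(map_fn)` to keep the pipeline lazy):

    x = data
    for layer in stage.layers:
      if hasattr(layer, 'adapt'):
        layer.adapt(x, reset_state=reset_state)
      x = layer(x)  # later layers adapt on already-transformed batches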

tensorflow/python/keras/layers/preprocessing/preprocessing_stage_test.py

@@ -0,0 +1,105 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing stage tests."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.keras.layers.preprocessing import preprocessing_stage
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class PreprocessingStageTest(
    keras_parameterized.TestCase,
    preprocessing_test_utils.PreprocessingLayerTest):

  def test_adapt(self):

    class PL(base_preprocessing_layer.PreprocessingLayer):

      def __init__(self, **kwargs):
        self.adapt_time = None
        self.adapt_count = 0
        super(PL, self).__init__(**kwargs)

      def adapt(self, data, reset_state=True):
        self.adapt_time = time.time()
        self.adapt_count += 1

      def call(self, inputs):
        return inputs + 1.

    # Test with NumPy array
    stage = preprocessing_stage.PreprocessingStage([
        PL(),
        PL(),
        PL(),
    ])
    stage.adapt(np.ones((3, 4)))
    self.assertEqual(stage.layers[0].adapt_count, 1)
    self.assertEqual(stage.layers[1].adapt_count, 1)
    self.assertEqual(stage.layers[2].adapt_count, 1)
    self.assertLess(stage.layers[0].adapt_time, stage.layers[1].adapt_time)
    self.assertLess(stage.layers[1].adapt_time, stage.layers[2].adapt_time)

    # Check call
    y = stage(array_ops.ones((3, 4)))
    self.assertAllClose(y, np.ones((3, 4)) + 3.)

    # Test with dataset
    adapt_data = dataset_ops.Dataset.from_tensor_slices(np.ones((3, 10)))
    adapt_data = adapt_data.batch(2)  # 2 batches: one of 2 samples, one of 1

    stage.adapt(adapt_data)
    self.assertEqual(stage.layers[0].adapt_count, 2)
    self.assertEqual(stage.layers[1].adapt_count, 2)
    self.assertEqual(stage.layers[2].adapt_count, 2)
    self.assertLess(stage.layers[0].adapt_time, stage.layers[1].adapt_time)
    self.assertLess(stage.layers[1].adapt_time, stage.layers[2].adapt_time)

    # Test error with bad data
    with self.assertRaisesRegex(ValueError, 'requires a '):
      stage.adapt(None)

  def test_mixing_preprocessing_and_regular_layers(self):
    stage = preprocessing_stage.PreprocessingStage([
        image_preprocessing.CenterCrop(16, 16),
        normalization.Normalization(),
        convolutional.Conv2D(4, 3)
    ])
    data = np.ones((16, 20, 20, 3), dtype='float32')
    stage.adapt(data)
    _ = stage(data)
    stage.compile('rmsprop', 'mse')
    stage.fit(data, np.ones((16, 14, 14, 4)))
    _ = stage.evaluate(data, np.ones((16, 14, 14, 4)))
    _ = stage.predict(data)


if __name__ == '__main__':
  test.main()

tensorflow/python/keras/utils/tf_utils.py

@@ -19,6 +19,7 @@ from __future__ import print_function
 import six
 
+from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.eager import context
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import ops
@@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import type_spec
 from tensorflow.python.keras import backend as K
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.util import nest
 from tensorflow.python.util import object_identity
@@ -452,3 +454,13 @@ def graph_context_for_symbolic_tensors(*args, **kwargs):
       yield
   else:
     yield
+
+
+def dataset_is_infinite(dataset):
+  """True if the passed dataset is infinite."""
+  if ops.executing_eagerly_outside_functions():
+    return math_ops.equal(
+        cardinality.cardinality(dataset), cardinality.INFINITE)
+  else:
+    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
+    return dataset_size == cardinality.INFINITE
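A quick sketch of the new helper in use (eager mode; mirrors the guard added to `adapt()` above):

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.keras.utils import tf_utils

    ds = dataset_ops.Dataset.from_tensor_slices([1., 2., 3.]).repeat()  # infinite
    if tf_utils.dataset_is_infinite(ds):
      ds = ds.take(100)  # cap the element count, as adapt() requires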