Add PreprocessingStage (sequential) class.
PiperOrigin-RevId: 295230713 Change-Id: Id2f50b7f3cb0c11e6a59e311d768a9edf9cdc057
This commit is contained in:
parent
79ed5077ce
commit
6fc343905a
@ -22,7 +22,6 @@ import collections
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import cardinality
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -32,7 +31,7 @@ from tensorflow.python.framework import sparse_tensor
|
|||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine import training_generator
|
from tensorflow.python.keras.engine import training_generator
|
||||||
from tensorflow.python.keras.engine.base_layer import Layer
|
from tensorflow.python.keras.engine.base_layer import Layer
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
@ -123,11 +122,6 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
|
|||||||
data_dict[name] = var.numpy()
|
data_dict[name] = var.numpy()
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
def _dataset_is_infinite(self, dataset):
|
|
||||||
"""True if the passed dataset is infinite."""
|
|
||||||
return math_ops.equal(
|
|
||||||
cardinality.cardinality(dataset), cardinality.INFINITE)
|
|
||||||
|
|
||||||
def _get_dataset_iterator(self, dataset):
|
def _get_dataset_iterator(self, dataset):
|
||||||
"""Gets an iterator from a tf.data.Dataset."""
|
"""Gets an iterator from a tf.data.Dataset."""
|
||||||
return dataset_ops.make_one_shot_iterator(dataset).get_next
|
return dataset_ops.make_one_shot_iterator(dataset).get_next
|
||||||
@ -148,18 +142,20 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
|
|||||||
else:
|
else:
|
||||||
accumulator = self._combiner.restore(self._restore_updates())
|
accumulator = self._combiner.restore(self._restore_updates())
|
||||||
|
|
||||||
if not isinstance(data, (dataset_ops.DatasetV2, np.ndarray)):
|
if not isinstance(data,
|
||||||
|
(dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'adapt() requires a Dataset or a Numpy array as input, got {}'.format(
|
'`adapt()` requires a batched Dataset, an EagerTensor, '
|
||||||
type(data)))
|
'or a Numpy array as input, '
|
||||||
|
'got {}'.format(type(data)))
|
||||||
|
|
||||||
if isinstance(data, dataset_ops.DatasetV2):
|
if isinstance(data, dataset_ops.DatasetV2):
|
||||||
# Validate the datasets to try and ensure we haven't been passed one with
|
# Validate the datasets to try and ensure we haven't been passed one with
|
||||||
# infinite size. That would cause an infinite loop here.
|
# infinite size. That would cause an infinite loop here.
|
||||||
if self._dataset_is_infinite(data):
|
if tf_utils.dataset_is_infinite(data):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The dataset passed to "adapt()" has an infinite number of '
|
'The dataset passed to `adapt()` has an infinite number of '
|
||||||
'elements. Please use dataset.take(...) to make the number '
|
'elements. Please use `dataset.take(...)` to make the number '
|
||||||
'of elements finite.')
|
'of elements finite.')
|
||||||
next_data = self._get_dataset_iterator(data)
|
next_data = self._get_dataset_iterator(data)
|
||||||
else:
|
else:
|
||||||
|
@ -135,7 +135,7 @@ class PreprocessingLayerTest(keras_parameterized.TestCase):
|
|||||||
input_dataset = [1, 2, 3, 4, 5]
|
input_dataset = [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
layer = get_layer()
|
layer = get_layer()
|
||||||
with self.assertRaisesRegex(ValueError, ".*a Dataset or a Numpy.*"):
|
with self.assertRaisesRegex(ValueError, "requires a"):
|
||||||
layer.adapt(input_dataset)
|
layer.adapt(input_dataset)
|
||||||
|
|
||||||
def test_adapt_infinite_dataset_fails(self):
|
def test_adapt_infinite_dataset_fails(self):
|
||||||
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import cardinality
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine import base_preprocessing_layer
|
from tensorflow.python.keras.engine import base_preprocessing_layer
|
||||||
@ -54,11 +53,6 @@ class CombinerPreprocessingLayer(
|
|||||||
data_dict[name] = K.get_session().run(var)
|
data_dict[name] = K.get_session().run(var)
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
def _dataset_is_infinite(self, dataset):
|
|
||||||
"""True if the passed dataset is infinite."""
|
|
||||||
dataset_size = K.get_session().run(cardinality.cardinality(dataset))
|
|
||||||
return dataset_size == cardinality.INFINITE
|
|
||||||
|
|
||||||
def _get_dataset_iterator(self, dataset):
|
def _get_dataset_iterator(self, dataset):
|
||||||
"""Gets an iterator from a tf.data.Dataset."""
|
"""Gets an iterator from a tf.data.Dataset."""
|
||||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||||
|
@ -23,6 +23,7 @@ py_library(
|
|||||||
":hashing",
|
":hashing",
|
||||||
":image_preprocessing",
|
":image_preprocessing",
|
||||||
":normalization",
|
":normalization",
|
||||||
|
":preprocessing_stage",
|
||||||
":preprocessing_test_utils",
|
":preprocessing_test_utils",
|
||||||
":reduction",
|
":reduction",
|
||||||
":text_vectorization",
|
":text_vectorization",
|
||||||
@ -204,6 +205,20 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Library target for the sequential `PreprocessingStage` class.
py_library(
    name = "preprocessing_stage",
    srcs = [
        "preprocessing_stage.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_ops",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/keras:engine",
        "//tensorflow/python/keras/engine:base_preprocessing_layer",
        # preprocessing_stage.py imports numpy directly, so depend on it
        # explicitly instead of relying on a transitive dependency.
        "//third_party/py/numpy",
    ],
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "preprocessing_test_utils",
|
name = "preprocessing_test_utils",
|
||||||
srcs = ["preprocessing_test_utils.py"],
|
srcs = ["preprocessing_test_utils.py"],
|
||||||
@ -344,3 +359,16 @@ tf_py_test(
|
|||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Unit tests for `PreprocessingStage`.
tf_py_test(
    name = "preprocessing_stage_test",
    srcs = ["preprocessing_stage_test.py"],
    python_version = "PY3",
    deps = [
        # Modules imported directly by the test are listed explicitly rather
        # than being picked up transitively.
        ":image_preprocessing",
        ":normalization",
        ":preprocessing_stage",
        ":preprocessing_test_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras/engine:base_preprocessing_layer",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Preprocessing stage."""
|
||||||
|
# pylint: disable=g-classes-have-attributes
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.keras.engine import base_preprocessing_layer
|
||||||
|
from tensorflow.python.keras.engine import sequential
|
||||||
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessingStage(base_preprocessing_layer.PreprocessingLayer,
                         sequential.Sequential):
  """A sequential preprocessing stage.

  This preprocessing stage wraps a list of preprocessing layers into a
  Sequential-like object that enables you to `adapt()` the whole list via
  a single `adapt()` call on the preprocessing stage.

  Arguments:
    layers: List of layers. Can include layers that aren't preprocessing layers.
    name: String. Optional name for the preprocessing stage object.
  """

  def adapt(self, data, reset_state=True):
    """Adapt the state of the layers of the preprocessing stage to the data.

    Layers are visited in order. Every layer that exposes an `adapt` attribute
    is adapted on `data` after it has been passed through all the layers that
    precede it in the stage; layers without `adapt` are skipped.

    Arguments:
      data: A batched Dataset object, or a NumPy array, or an EagerTensor.
        Data to be iterated over to adapt the state of the layers in this
        preprocessing stage.
      reset_state: Whether this call to `adapt` should reset the state of
        the layers in this preprocessing stage.

    Raises:
      ValueError: If `data` is not a Dataset, a NumPy array or an EagerTensor,
        or if `data` is a Dataset with an infinite number of elements.
    """
    if not isinstance(data,
                      (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
      raise ValueError(
          '`adapt()` requires a batched Dataset, an EagerTensor, '
          'or a Numpy array as input, '
          'got {}'.format(type(data)))

    # The input type cannot change while adapting, so check it only once.
    data_is_dataset = isinstance(data, dataset_ops.DatasetV2)

    # Validate the dataset to try and ensure we haven't been passed one with
    # infinite size. That would cause an infinite loop here.
    if data_is_dataset and tf_utils.dataset_is_infinite(data):
      raise ValueError(
          'The dataset passed to `adapt()` has an infinite number of '
          'elements. Please use `dataset.take(...)` to make the number '
          'of elements finite.')

    for current_layer_index, current_layer in enumerate(self.layers):
      if not hasattr(current_layer, 'adapt'):
        # Skip any layer that does not need adapting.
        continue

      def map_fn(x):
        """Maps `PreprocessingStage` inputs to inputs at `current_layer_index`.

        Args:
          x: Batch of inputs seen in entry of the `PreprocessingStage` instance.

        Returns:
          Batch of inputs to be processed by layer
            `self.layers[current_layer_index]`
        """
        # For the first layer the slice is empty and `x` is returned as-is.
        for upstream_layer in self.layers[:current_layer_index]:  # pylint: disable=cell-var-from-loop
          x = upstream_layer(x)
        return x

      # `map_fn` is consumed within this iteration, before the loop variable
      # it closes over is rebound.
      if data_is_dataset:
        current_layer_data = data.map(map_fn)
      else:
        current_layer_data = map_fn(data)
      current_layer.adapt(current_layer_data, reset_state=reset_state)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,105 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Preprocessing stage tests."""
|
||||||
|
# pylint: disable=g-classes-have-attributes
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from tensorflow.python.keras.engine import base_preprocessing_layer
|
||||||
|
from tensorflow.python.keras.layers import convolutional
|
||||||
|
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||||
|
from tensorflow.python.keras.layers.preprocessing import normalization
|
||||||
|
from tensorflow.python.keras.layers.preprocessing import preprocessing_stage
|
||||||
|
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class PreprocessingStageTest(
    keras_parameterized.TestCase,
    preprocessing_test_utils.PreprocessingLayerTest):
  """Tests for `preprocessing_stage.PreprocessingStage`."""

  def test_adapt(self):
    """Checks that `adapt()` adapts every layer once per call, in order."""
    # Records the layers in the order their `adapt()` is invoked. A call
    # counter is used instead of wall-clock timestamps so that the ordering
    # assertions cannot fail when two calls fall within the same clock tick.
    adapt_log = []

    class PL(base_preprocessing_layer.PreprocessingLayer):
      """Minimal preprocessing layer that records its `adapt()` calls."""

      def __init__(self, **kwargs):
        self.adapt_order = None
        self.adapt_count = 0
        super(PL, self).__init__(**kwargs)

      def adapt(self, data, reset_state=True):
        adapt_log.append(self)
        # Position of this call among all `adapt()` calls seen so far.
        self.adapt_order = len(adapt_log)
        self.adapt_count += 1

      def call(self, inputs):
        return inputs + 1.

    # Test with NumPy array
    stage = preprocessing_stage.PreprocessingStage([
        PL(),
        PL(),
        PL(),
    ])
    stage.adapt(np.ones((3, 4)))
    self.assertEqual(stage.layers[0].adapt_count, 1)
    self.assertEqual(stage.layers[1].adapt_count, 1)
    self.assertEqual(stage.layers[2].adapt_count, 1)
    self.assertLess(stage.layers[0].adapt_order, stage.layers[1].adapt_order)
    self.assertLess(stage.layers[1].adapt_order, stage.layers[2].adapt_order)

    # Check call
    y = stage(array_ops.ones((3, 4)))
    self.assertAllClose(y, np.ones((3, 4)) + 3.)

    # Test with dataset
    adapt_data = dataset_ops.Dataset.from_tensor_slices(np.ones((3, 10)))
    adapt_data = adapt_data.batch(2)  # 3 samples -> 2 batches (2, then 1)

    stage.adapt(adapt_data)
    self.assertEqual(stage.layers[0].adapt_count, 2)
    self.assertEqual(stage.layers[1].adapt_count, 2)
    self.assertEqual(stage.layers[2].adapt_count, 2)
    self.assertLess(stage.layers[0].adapt_order, stage.layers[1].adapt_order)
    self.assertLess(stage.layers[1].adapt_order, stage.layers[2].adapt_order)

    # Test error with bad data
    with self.assertRaisesRegex(ValueError, 'requires a '):
      stage.adapt(None)

  def test_mixing_preprocessing_and_regular_layers(self):
    """Checks a stage mixing preprocessing and regular layers end to end."""
    stage = preprocessing_stage.PreprocessingStage([
        image_preprocessing.CenterCrop(16, 16),
        normalization.Normalization(),
        convolutional.Conv2D(4, 3)
    ])
    data = np.ones((16, 20, 20, 3), dtype='float32')
    stage.adapt(data)
    _ = stage(data)
    stage.compile('rmsprop', 'mse')
    stage.fit(data, np.ones((16, 14, 14, 4)))
    _ = stage.evaluate(data, np.ones((16, 14, 14, 4)))
    _ = stage.predict(data)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.data.experimental.ops import cardinality
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import composite_tensor
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.framework import type_spec
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
from tensorflow.python.util import object_identity
|
||||||
@ -452,3 +454,13 @@ def graph_context_for_symbolic_tensors(*args, **kwargs):
|
|||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_is_infinite(dataset):
  """Tells whether the passed dataset has an infinite number of elements.

  Arguments:
    dataset: The dataset whose cardinality is inspected.

  Returns:
    When executing eagerly outside of functions, the result of
    `math_ops.equal` between the dataset cardinality and
    `cardinality.INFINITE`. Otherwise, the result of comparing the cardinality
    value evaluated through the Keras backend session with
    `cardinality.INFINITE`.
  """
  if not ops.executing_eagerly_outside_functions():
    # Graph mode: the cardinality has to be evaluated in the backend session
    # before it can be compared.
    evaluated_size = K.get_session().run(cardinality.cardinality(dataset))
    return evaluated_size == cardinality.INFINITE

  size = cardinality.cardinality(dataset)
  return math_ops.equal(size, cardinality.INFINITE)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user