# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base PreprocessingLayer and a subclass that uses Combiners."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import numpy as np
import six

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import training_generator_v1
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export


keras_kpl_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras/layers/preprocessing',
    'keras preprocessing layers usage', 'method')

@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
@six.add_metaclass(abc.ABCMeta)
class PreprocessingLayer(Layer):
  """Base class for PreprocessingLayers."""
  _must_restore_from_config = True

  def adapt(self, data, reset_state=True):
    # TODO(momernick): Add examples.
    """Fits the state of the preprocessing layer to the data being passed.

    Args:
      data: The data to train on. It can be passed either as a tf.data
        Dataset, or as a numpy array.
      reset_state: Optional argument specifying whether to clear the state of
        the layer at the start of the call to `adapt`, or whether to start
        from the existing state. This argument may not be relevant to all
        preprocessing layers: a subclass of PreprocessingLayer may choose to
        throw if `reset_state` is set to False.
    """
    pass

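# Illustrative usage sketch (not executed as part of this module): concrete
# subclasses implement `adapt` so that layer state can be fit to data before
# training. Assuming the `Normalization` layer from
# `tf.keras.layers.experimental.preprocessing`:
#
#   layer = tf.keras.layers.experimental.preprocessing.Normalization()
#   layer.adapt(np.array([[1.0], [2.0], [3.0]]))  # Fits the mean and variance.
#   outputs = layer(np.array([[2.0]]))            # Applies the adapted state.
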
class CombinerPreprocessingLayer(PreprocessingLayer):
  """Base class for PreprocessingLayers that do computation using a Combiner.

  This class provides several helper methods to make creating a
  PreprocessingLayer easier. It assumes that the core of your computation will
  be done via a Combiner object. Subclassing this class to create a
  PreprocessingLayer allows your layer to be compatible with distributed
  computation.

  This class is compatible with TensorFlow 2.0+.
  """

  def __init__(self, combiner, **kwargs):
    super(CombinerPreprocessingLayer, self).__init__(**kwargs)
    self._combiner = combiner
    self._previously_updated = False
    self.state_variables = collections.OrderedDict()

  def _add_state_variable(self,
                          name,
                          shape,
                          dtype,
                          initializer=None,
                          partitioner=None,
                          use_resource=None,
                          **kwargs):
    """Add a variable that can hold state which is updated during adapt().

    Args:
      name: Variable name.
      shape: Variable shape. Defaults to scalar if unspecified.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: Initializer instance (callable).
      partitioner: Partitioner to be passed to the `Trackable` API.
      use_resource: Whether to use a `ResourceVariable`.
      **kwargs: Additional keyword arguments. Accepted values are `getter` and
        `collections`.

    Returns:
      The created variable.
    """
    weight = self.add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=None,
        trainable=False,
        constraint=None,
        partitioner=partitioner,
        use_resource=use_resource,
        **kwargs)
    # TODO(momernick): Do not allow collisions here.
    self.state_variables[name] = weight
    return weight

  def _restore_updates(self):
    """Recreates a dict of updates from the layer's weights."""
    data_dict = {}
    for name, var in self.state_variables.items():
      data_dict[name] = var.numpy()
    return data_dict

  def _get_dataset_iterator(self, dataset):
    """Returns a callable that yields successive elements of `dataset`."""
    return dataset_ops.make_one_shot_iterator(dataset).get_next

  def adapt(self, data, reset_state=True):
    """Fits the state of the preprocessing layer to the data being passed.

    Args:
      data: The data to train on. It can be passed either as a tf.data Dataset,
        or as a numpy array.
      reset_state: Optional argument specifying whether to clear the state of
        the layer at the start of the call to `adapt`, or whether to start from
        the existing state. Subclasses may choose to throw if `reset_state` is
        set to False.
    """
    if reset_state:
      accumulator = None
    else:
      accumulator = self._combiner.restore(self._restore_updates())
    if isinstance(data, (list, tuple)):
      data = ops.convert_to_tensor_v2_with_dispatch(data)
    if not isinstance(data,
                      (dataset_ops.DatasetV2,
                       np.ndarray,
                       ops.Tensor,
                       ragged_tensor.RaggedTensor)):
      raise ValueError(
          '`adapt()` requires a batched Dataset, a Tensor, '
          'or a Numpy array as input, '
          'got {}'.format(type(data)))

    if isinstance(data, dataset_ops.DatasetV2):
      # Validate that the dataset only contains single-tensor elements.
      if not isinstance(data.element_spec, type_spec.TypeSpec):
        raise TypeError(
            'The dataset should yield single-Tensor elements. Use '
            '`dataset.map` to select the element of interest.\n'
            'Got dataset.element_spec=' + str(data.element_spec))
      # Validate the dataset to try and ensure we haven't been passed one with
      # infinite size. That would cause an infinite loop here.
      if tf_utils.dataset_is_infinite(data):
        raise ValueError(
            'The dataset passed to `adapt()` has an infinite number of '
            'elements. Please use `dataset.take(...)` to make the number '
            'of elements finite.')
      next_data = self._get_dataset_iterator(data)
      # TODO(fchollet): consider checking if the dataset is already batched
      # and otherwise batching it.
    elif isinstance(data, (ops.Tensor, ragged_tensor.RaggedTensor)):
      next_data = self._get_dataset_iterator(
          dataset_ops.Dataset.from_tensor_slices(data).batch(512))
    else:
      generator, _ = training_generator_v1.convert_to_generator_like(
          data, batch_size=512)
      # If the data is not a dataset, we can iterate over it using next(foo);
      # here, we wrap that into a callable.
      next_data = lambda: next(generator)

    # TODO(momernick): Some sort of status bar?
    # TODO(momernick): Implement parallel processing here?
    try:
      data_element = next_data()

      # First, see if the layer is built or not. If it is not, then we must
      # build it.
      if not self.built:
        try:
          # If this is a Numpy array or tensor, we can get shape from .shape.
          # If not, an attribute error will be thrown.
          data_shape = data_element.shape
          data_shape_nones = tuple([None] * len(data_element.shape))
        except AttributeError:
          # The input has an unknown number of dimensions.
          data_shape = None
          data_shape_nones = None

        # TODO (b/159261555): move this to base layer build.
        batch_input_shape = getattr(self, '_batch_input_shape', None)
        if batch_input_shape is None:
          # Set the number of dimensions.
          self._batch_input_shape = data_shape_nones

        self.build(data_shape)

      # Once we have built the Layer, we can process the input data. We do so
      # until we've gotten an exception indicating that we have no more data.
      while True:
        accumulator = self._combiner.compute(data_element, accumulator)
        data_element = next_data()
    # Note that this belongs to the outer indentation of 'try' - we need to
    # catch exceptions resulting from the first 'next_data()' invocation as
    # well.
    except (StopIteration, errors.OutOfRangeError):
      pass

    updates = self._combiner.extract(accumulator)
    self._set_state_variables(updates)

  def _set_state_variables(self, updates):
    """Directly update the internal state of this Layer.

    This method expects a string-keyed dict of {state_variable_name: state}.
    The precise nature of the state, and the names associated, are described
    by the subclasses of CombinerPreprocessingLayer.

    Args:
      updates: A string-keyed dict of weights to update.

    Raises:
      RuntimeError: if `build()` was not called before `_set_state_variables()`.
    """
    # TODO(momernick): Do we need to do any more input sanitization?
    if not self.built:
      raise RuntimeError('_set_state_variables() must be called after build().')

    with ops.init_scope():
      for var_name, value in updates.items():
        self.state_variables[var_name].assign(value)

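# Illustrative sketch of adapting from a tf.data.Dataset (variable names are
# assumed): `adapt` above requires a finite, batched dataset that yields
# single-Tensor elements, so an infinite or multi-component dataset should be
# mapped, trimmed, and batched first:
#
#   dataset = tf.data.Dataset.from_tensor_slices((features, labels))
#   dataset = dataset.map(lambda x, y: x)     # Keep only the element to adapt on.
#   dataset = dataset.take(1000).batch(512)   # Make it finite and batched.
#   layer.adapt(dataset)
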
def convert_to_list(values, sparse_default_value=None):
  """Convert a TensorLike, CompositeTensor, or ndarray into a Python list."""
  if tf_utils.is_ragged(values):
    # There is a corner case when dealing with ragged tensors: if you get an
    # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode,
    # you can't call to_list() on it without evaluating it first. However,
    # because we don't yet fully support composite tensors across Keras,
    # K.get_value() won't evaluate the tensor.
    # TODO(momernick): Get Keras to recognize composite tensors as Tensors
    # and then replace this with a call to K.get_value.
    if (isinstance(values, ragged_tensor.RaggedTensor) and
        not context.executing_eagerly()):
      values = K.get_session(values).run(values)
    values = values.to_list()

  if isinstance(values,
                (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
    if sparse_default_value is None:
      if dtypes.as_dtype(values.values.dtype) == dtypes.string:
        sparse_default_value = ''
      else:
        sparse_default_value = -1
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        values, default_value=sparse_default_value)
    values = K.get_value(dense_tensor)

  if isinstance(values, ops.Tensor):
    values = K.get_value(values)

  # We may get passed an ndarray or the code above may give us an ndarray.
  # In either case, we want to force it into a standard Python list.
  if isinstance(values, np.ndarray):
    values = values.tolist()

  return values

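# Behavior sketch for `convert_to_list` (illustrative): sparse inputs are
# densified with a type-appropriate default value before conversion, e.g.
#
#   st = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
#   convert_to_list(st)  # -> [[1, -1], [-1, 2]] (the numeric default is -1).
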
class Combiner(object):
  """Functional object that defines a shardable computation.

  This object defines functions required to create and manipulate data objects.
  These data objects, referred to below as 'accumulators', are computation-
  specific and may be implemented alongside concrete subclasses of Combiner
  (if necessary - some computations may be simple enough that standard Python
  types can be used as accumulators).

  The intent for this class is that by describing computations in this way, we
  can arbitrarily shard a dataset, perform computations on a subset, and then
  merge the computation into a final result. This enables distributed
  computation.

  The combiner itself does not own any state - all computational state is owned
  by the accumulator objects. This is so that we can have an arbitrary number
  of Combiners (thus sharding the computation N ways) without risking any
  change to the underlying computation. These accumulator objects are uniquely
  associated with each Combiner; a Combiner defines what the accumulator object
  should be and will only work with accumulators of that type.
  """
  __metaclass__ = abc.ABCMeta

  def __repr__(self):
    return '<{}>'.format(self.__class__.__name__)

  @abc.abstractmethod
  def compute(self, batch_values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    This method computes a step of the computation described by this Combiner.
    If an accumulator is passed, the data in that accumulator is also used; so
    compute(batch_values) results in f(batch_values), while
    compute(batch_values, accumulator) results in
    merge(f(batch_values), accumulator).

    Args:
      batch_values: A list of ndarrays representing the values of the inputs
        for this step of the computation.
      accumulator: The current accumulator. Can be None.

    Returns:
      An accumulator that includes the passed batch of inputs.
    """
    pass

  @abc.abstractmethod
  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator.

    This method takes the partial values in several accumulators and combines
    them into a single accumulator. This computation must not be order-specific
    (that is, merge([a, b]) must return the same result as merge([b, a])).

    Args:
      accumulators: The accumulators to merge, as a list.

    Returns:
      A merged accumulator.
    """
    pass

  @abc.abstractmethod
  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: The accumulator to convert.

    Returns:
      A dict of ndarrays representing the data in this accumulator.
    """
    pass

  @abc.abstractmethod
  def restore(self, output):
    """Create an accumulator based on 'output'.

    This method creates a new accumulator with identical internal state to the
    one used to create the data in 'output'. This means that if you do

      output_data = combiner.extract(accumulator_1)
      accumulator_2 = combiner.restore(output_data)

    then accumulator_1 and accumulator_2 will have identical internal state,
    and computations using either of them will be equivalent.

    Args:
      output: The data output from a previous computation. Should be in the
        same form as provided by 'extract'.

    Returns:
      A new accumulator.
    """
    pass

  @abc.abstractmethod
  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    This function serializes an accumulator to be sent to a remote process.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A byte string representing the passed accumulator.
    """
    pass

  @abc.abstractmethod
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    This function deserializes an accumulator serialized by 'serialize()'.

    Args:
      encoded_accumulator: A byte string representing an accumulator.

    Returns:
      The accumulator represented by the passed byte string.
    """
    pass

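# Minimal Combiner sketch (illustrative only; `_SumCombiner` is a hypothetical
# name, not part of this module). It accumulates a per-feature sum, assuming
# `batch_values` is a single 2-D ndarray of shape (batch, features) and the
# accumulator is a 1-D float64 ndarray:
#
#   class _SumCombiner(Combiner):
#
#     def compute(self, batch_values, accumulator=None):
#       batch_sum = np.sum(batch_values, axis=0)
#       return batch_sum if accumulator is None else accumulator + batch_sum
#
#     def merge(self, accumulators):
#       # Order-independent, since summation commutes.
#       return np.sum(accumulators, axis=0)
#
#     def extract(self, accumulator):
#       return {'sum': accumulator}
#
#     def restore(self, output):
#       return output['sum']
#
#     def serialize(self, accumulator):
#       return np.asarray(accumulator, dtype=np.float64).tobytes()
#
#     def deserialize(self, encoded_accumulator):
#       return np.frombuffer(encoded_accumulator, dtype=np.float64)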