Fork composite_tensor_utils.append_composite_tensor into Keras so that Keras does not need to depend on the internal TF method. Facilitates splitting Keras into its own repository.
PiperOrigin-RevId: 342336991 Change-Id: Icc4d31db296c51289eede0d3d5d3b32a8bf00a1e
This commit is contained in:
parent
0dd94c4ad3
commit
31c20f9e8a
@ -37,7 +37,6 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import composite_tensor_utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
@ -56,6 +55,7 @@ from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -155,6 +155,119 @@ class MetricsAggregator(Aggregator):
|
||||
self.results[0] /= (self.num_samples or self.steps)
|
||||
|
||||
|
||||
def _append_sparse_tensor_value(target, to_append):
|
||||
"""Append sparse tensor value objects."""
|
||||
# Make sure the sparse tensors are of the same size (except for the 0th dim).
|
||||
if len(target.dense_shape) != len(to_append.dense_shape):
|
||||
raise RuntimeError(
|
||||
'Unable to concatenate %s and %s. The inner dense shapes do not '
|
||||
'have the same number of dimensions (%s vs %s)' %
|
||||
(target, to_append, target.dense_shape, to_append.dense_shape))
|
||||
|
||||
if target.dense_shape[1:] != to_append.dense_shape[1:]:
|
||||
raise RuntimeError(
|
||||
'Unable to concatenate %s and %s. The inner dense shapes do not '
|
||||
'match inner dimensions (%s vs %s)' %
|
||||
(target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))
|
||||
|
||||
# Add the to_append indices to target, updating the 0th value, and keeping
|
||||
# track of the maximum so we know the final dense_shape of this tensor.
|
||||
base_dim0_value = target.dense_shape[0]
|
||||
max_dim0_value = target.dense_shape[0]
|
||||
new_indices = target.indices
|
||||
for index in to_append.indices:
|
||||
# Here, we iterate through the sparse indices of the tensor to append. For
|
||||
# each index, we update its zeroth value (the batch index) by adding the
|
||||
# number of batch items in the tensor we are appending to (so an index
|
||||
# of [0, 0, 1] for a value that is being appended to a tensor with 0th dim
|
||||
# size 3 would become [3, 0, 1].)
|
||||
index[0] += base_dim0_value
|
||||
max_dim0_value = max(max_dim0_value, index[0])
|
||||
new_indices = np.append(new_indices, [index], axis=0)
|
||||
|
||||
# Extend the values array to contain all of the appended values. These will
|
||||
# be in the same order as the indices added above.
|
||||
new_values = np.concatenate((target.values, to_append.values), axis=0)
|
||||
|
||||
# Create a new dense shape by replacing the value for the 0th dimension
|
||||
# with the new max dim0 value.
|
||||
new_dense_shape = list(target.dense_shape)
|
||||
new_dense_shape[0] = max_dim0_value + 1
|
||||
new_dense_shape = tuple(new_dense_shape)
|
||||
|
||||
return sparse_tensor.SparseTensorValue(
|
||||
indices=new_indices, values=new_values, dense_shape=new_dense_shape)
|
||||
|
||||
|
||||
def _append_ragged_tensor_value(target, to_append):
|
||||
"""Append ragged tensor value objects."""
|
||||
# Make sure the ragged tensors are of the same size (save for the 0th dim).
|
||||
if len(target.shape) != len(to_append.shape):
|
||||
raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))
|
||||
|
||||
if target.shape[1:] != to_append.shape[1:]:
|
||||
raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))
|
||||
|
||||
adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1]
|
||||
new_row_splits = np.append(target.row_splits, adjusted_row_splits)
|
||||
if isinstance(target.values, ragged_tensor_value.RaggedTensorValue):
|
||||
new_values = _append_ragged_tensor_value(target.values, to_append.values)
|
||||
else:
|
||||
new_values = np.concatenate((target.values, to_append.values), axis=0)
|
||||
|
||||
return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits)
|
||||
|
||||
|
||||
def _append_composite_tensor(target, to_append):
|
||||
"""Helper function to append composite tensors to each other in the 0 axis.
|
||||
|
||||
In order to support batching within a fit/evaluate/predict call, we need
|
||||
to be able to aggregate within a CompositeTensor. Unfortunately, the CT
|
||||
API currently does not make this easy - especially in V1 mode, where we're
|
||||
working with CompositeTensor Value objects that have no connection with the
|
||||
CompositeTensors that created them.
|
||||
|
||||
Arguments:
|
||||
target: CompositeTensor or CompositeTensor value object that will be
|
||||
appended to.
|
||||
to_append: CompositeTensor or CompositeTensor value object to append to.
|
||||
'target'.
|
||||
|
||||
Returns:
|
||||
A CompositeTensor or CompositeTensor value object.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if concatenation is not possible.
|
||||
"""
|
||||
if type(target) is not type(to_append):
|
||||
raise RuntimeError('Unable to concatenate %s and %s' %
|
||||
(type(target), type(to_append)))
|
||||
|
||||
# Perform type-specific concatenation.
|
||||
# TODO(b/125094323): This should be replaced by a simple call to
|
||||
# target.append() that should work on all of the below classes.
|
||||
|
||||
# If we're seeing a CompositeTensor here, we know it's because we're in
|
||||
# Eager mode (or else we'd have evaluated the CT to a CT Value object
|
||||
# already). Therefore, it's safe to call concat() on it without evaluating
|
||||
# the result any further. If not - that is, if we're seeing a
|
||||
# SparseTensorValue or a RaggedTensorValue - we need to hand-update it
|
||||
# since we're outside of the graph anyways.
|
||||
if isinstance(target, sparse_tensor.SparseTensor):
|
||||
# We need to invoke the sparse version of concatenate here - tf.concat
|
||||
# won't work.
|
||||
return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)
|
||||
elif isinstance(target, ragged_tensor.RaggedTensor):
|
||||
return array_ops.concat([target, to_append], axis=0)
|
||||
elif isinstance(target, sparse_tensor.SparseTensorValue):
|
||||
return _append_sparse_tensor_value(target, to_append)
|
||||
elif isinstance(target, ragged_tensor_value.RaggedTensorValue):
|
||||
return _append_ragged_tensor_value(target, to_append)
|
||||
else:
|
||||
raise RuntimeError('Attempted to concatenate unsupported object %s.' %
|
||||
type(target))
|
||||
|
||||
|
||||
class ConcatAggregator(Aggregator):
|
||||
"""Combine tensor-likes which cannot be merged on the fly.
|
||||
|
||||
@ -191,7 +304,7 @@ class ConcatAggregator(Aggregator):
|
||||
# TODO(taylorrobie): efficiently concatenate.
|
||||
results = self.results[0]
|
||||
for r in self.results[1:]:
|
||||
results = composite_tensor_utils.append_composite_tensor(results, r)
|
||||
results = _append_composite_tensor(results, r)
|
||||
self.results = results
|
||||
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user