From 31c20f9e8aafcc814481fc5e0bf59ffd0cd51b4d Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Fri, 13 Nov 2020 14:18:29 -0800 Subject: [PATCH] Fork composite_tensor_utils.append_composite_tensor into Keras so that Keras does not need to depend on the internal TF method. Facilitates splitting Keras into its own repository. PiperOrigin-RevId: 342336991 Change-Id: Icc4d31db296c51289eede0d3d5d3b32a8bf00a1e --- .../python/keras/engine/training_utils_v1.py | 117 +++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py index a85a20a5313..c2e1b5e652f 100644 --- a/tensorflow/python/keras/engine/training_utils_v1.py +++ b/tensorflow/python/keras/engine/training_utils_v1.py @@ -37,7 +37,6 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import composite_tensor -from tensorflow.python.framework import composite_tensor_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -56,6 +55,7 @@ from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.platform import tf_logging as logging @@ -155,6 +155,119 @@ class MetricsAggregator(Aggregator): self.results[0] /= (self.num_samples or self.steps) +def _append_sparse_tensor_value(target, to_append): + """Append sparse tensor value objects.""" + # Make sure the sparse tensors are of the same size (except for the 0th dim). 
+ if len(target.dense_shape) != len(to_append.dense_shape): + raise RuntimeError( + 'Unable to concatenate %s and %s. The inner dense shapes do not ' + 'have the same number of dimensions (%s vs %s)' % + (target, to_append, target.dense_shape, to_append.dense_shape)) + + if target.dense_shape[1:] != to_append.dense_shape[1:]: + raise RuntimeError( + 'Unable to concatenate %s and %s. The inner dense shapes do not ' + 'match inner dimensions (%s vs %s)' % + (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:])) + + # Add the to_append indices to target, updating the 0th value, and keeping + # track of the maximum so we know the final dense_shape of this tensor. + base_dim0_value = target.dense_shape[0] + max_dim0_value = target.dense_shape[0] + new_indices = target.indices + for index in to_append.indices: + # Here, we iterate through the sparse indices of the tensor to append. For + # each index, we update its zeroth value (the batch index) by adding the + # number of batch items in the tensor we are appending to (so an index + # of [0, 0, 1] for a value that is being appended to a tensor with 0th dim + # size 3 would become [3, 0, 1].) + index[0] += base_dim0_value + max_dim0_value = max(max_dim0_value, index[0]) + new_indices = np.append(new_indices, [index], axis=0) + + # Extend the values array to contain all of the appended values. These will + # be in the same order as the indices added above. + new_values = np.concatenate((target.values, to_append.values), axis=0) + + # Create a new dense shape by replacing the value for the 0th dimension + # with the new max dim0 value. 
+ new_dense_shape = list(target.dense_shape) + new_dense_shape[0] = max_dim0_value + 1 + new_dense_shape = tuple(new_dense_shape) + + return sparse_tensor.SparseTensorValue( + indices=new_indices, values=new_values, dense_shape=new_dense_shape) + + +def _append_ragged_tensor_value(target, to_append): + """Append ragged tensor value objects.""" + # Make sure the ragged tensors are of the same size (save for the 0th dim). + if len(target.shape) != len(to_append.shape): + raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append)) + + if target.shape[1:] != to_append.shape[1:]: + raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append)) + + adjusted_row_splits = to_append.row_splits[1:] + target.row_splits[-1] + new_row_splits = np.append(target.row_splits, adjusted_row_splits) + if isinstance(target.values, ragged_tensor_value.RaggedTensorValue): + new_values = _append_ragged_tensor_value(target.values, to_append.values) + else: + new_values = np.concatenate((target.values, to_append.values), axis=0) + + return ragged_tensor_value.RaggedTensorValue(new_values, new_row_splits) + + +def _append_composite_tensor(target, to_append): + """Helper function to append composite tensors to each other in the 0 axis. + + In order to support batching within a fit/evaluate/predict call, we need + to be able to aggregate within a CompositeTensor. Unfortunately, the CT + API currently does not make this easy - especially in V1 mode, where we're + working with CompositeTensor Value objects that have no connection with the + CompositeTensors that created them. + + Arguments: + target: CompositeTensor or CompositeTensor value object that will be + appended to. + to_append: CompositeTensor or CompositeTensor value object to append to + 'target'. + + Returns: + A CompositeTensor or CompositeTensor value object. + + Raises: + RuntimeError: if concatenation is not possible. 
+ """ + if type(target) is not type(to_append): + raise RuntimeError('Unable to concatenate %s and %s' % + (type(target), type(to_append))) + + # Perform type-specific concatenation. + # TODO(b/125094323): This should be replaced by a simple call to + # target.append() that should work on all of the below classes. + + # If we're seeing a CompositeTensor here, we know it's because we're in + # Eager mode (or else we'd have evaluated the CT to a CT Value object + # already). Therefore, it's safe to call concat() on it without evaluating + # the result any further. If not - that is, if we're seeing a + # SparseTensorValue or a RaggedTensorValue - we need to hand-update it + # since we're outside of the graph anyways. + if isinstance(target, sparse_tensor.SparseTensor): + # We need to invoke the sparse version of concatenate here - tf.concat + # won't work. + return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0) + elif isinstance(target, ragged_tensor.RaggedTensor): + return array_ops.concat([target, to_append], axis=0) + elif isinstance(target, sparse_tensor.SparseTensorValue): + return _append_sparse_tensor_value(target, to_append) + elif isinstance(target, ragged_tensor_value.RaggedTensorValue): + return _append_ragged_tensor_value(target, to_append) + else: + raise RuntimeError('Attempted to concatenate unsupported object %s.' % + type(target)) + + class ConcatAggregator(Aggregator): """Combine tensor-likes which cannot be merged on the fly. @@ -191,7 +304,7 @@ class ConcatAggregator(Aggregator): # TODO(taylorrobie): efficiently concatenate. results = self.results[0] for r in self.results[1:]: - results = composite_tensor_utils.append_composite_tensor(results, r) + results = _append_composite_tensor(results, r) self.results = results else: