From de5e35d7feb4252a172cd454347d13f66693f478 Mon Sep 17 00:00:00 2001 From: Yanhui Liang Date: Wed, 4 Nov 2020 10:18:54 -0800 Subject: [PATCH] Update the usages of `ragged_concat_ops.concat` to `array_ops.concat` in Keras codebase. PiperOrigin-RevId: 340678363 Change-Id: Ic366ee468156c3a89ea2121c2c4eeec5857b6c06 --- tensorflow/python/keras/backend.py | 3 +-- .../python/keras/distribute/distributed_training_utils_v1.py | 4 ++-- tensorflow/python/keras/engine/training.py | 4 ---- tensorflow/python/keras/layers/wrappers_test.py | 3 +-- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index e6b2b65b27b..14bc13f9709 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -77,7 +77,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables as variables_module -from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import moving_averages @@ -3093,7 +3092,7 @@ def concatenate(tensors, axis=-1): if py_all(is_sparse(x) for x in tensors): return sparse_ops.sparse_concat(axis, tensors) elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors): - return ragged_concat_ops.concat(tensors, axis) + return array_ops.concat(tensors, axis) else: return array_ops.concat([to_dense(x) for x in tensors], axis) diff --git a/tensorflow/python/keras/distribute/distributed_training_utils_v1.py b/tensorflow/python/keras/distribute/distributed_training_utils_v1.py index 2e7a8299e43..c631ae07b19 100644 --- a/tensorflow/python/keras/distribute/distributed_training_utils_v1.py +++ b/tensorflow/python/keras/distribute/distributed_training_utils_v1.py @@ -42,11 +42,11 @@ from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils.mode_keys import ModeKeys +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables -from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -1154,5 +1154,5 @@ def concat_along_batch_dimension(outputs): if isinstance(outputs[0], sparse_tensor.SparseTensor): return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs) if isinstance(outputs[0], ragged_tensor.RaggedTensor): - return ragged_concat_ops.concat(outputs, axis=0) + return array_ops.concat(outputs, axis=0) return np.concatenate(outputs) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 7ccfd39c82a..96a4e1e23cc 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -70,8 +70,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables -from tensorflow.python.ops.ragged import ragged_concat_ops -from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler import trace from tensorflow.python.training import checkpoint_management @@ -2724,8 +2722,6 @@ def concat(tensors, axis=0): """Concats `tensor`s along `axis`.""" if isinstance(tensors[0], sparse_tensor.SparseTensor): return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors) - if isinstance(tensors[0], ragged_tensor.RaggedTensor): - return ragged_concat_ops.concat(tensors, axis=axis) return array_ops.concat(tensors, axis=axis) diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index f1412975cc3..c60e950794f 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -37,7 +37,6 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops.ragged import ragged_concat_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -1205,7 +1204,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): if merge_mode == 'ave': merge_func = lambda y, y_rev: (y + y_rev) / 2 elif merge_mode == 'concat': - merge_func = lambda y, y_rev: ragged_concat_ops.concat( + merge_func = lambda y, y_rev: array_ops.concat( (y, y_rev), axis=-1) elif merge_mode == 'mul': merge_func = lambda y, y_rev: (y * y_rev)