Update the usages of ragged_concat_ops.concat to array_ops.concat in Keras codebase.

PiperOrigin-RevId: 340678363
Change-Id: Ic366ee468156c3a89ea2121c2c4eeec5857b6c06
This commit is contained in:
Yanhui Liang 2020-11-04 10:18:54 -08:00 committed by TensorFlower Gardener
parent a3c9e996fb
commit de5e35d7fe
4 changed files with 4 additions and 10 deletions

View File

@ -77,7 +77,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import moving_averages
@ -3093,7 +3092,7 @@ def concatenate(tensors, axis=-1):
if py_all(is_sparse(x) for x in tensors):
return sparse_ops.sparse_concat(axis, tensors)
elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
return ragged_concat_ops.concat(tensors, axis)
return array_ops.concat(tensors, axis)
else:
return array_ops.concat([to_dense(x) for x in tensors], axis)

View File

@ -42,11 +42,11 @@ from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -1154,5 +1154,5 @@ def concat_along_batch_dimension(outputs):
if isinstance(outputs[0], sparse_tensor.SparseTensor):
return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
if isinstance(outputs[0], ragged_tensor.RaggedTensor):
return ragged_concat_ops.concat(outputs, axis=0)
return array_ops.concat(outputs, axis=0)
return np.concatenate(outputs)

View File

@ -70,8 +70,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.training import checkpoint_management
@ -2724,8 +2722,6 @@ def concat(tensors, axis=0):
"""Concats `tensor`s along `axis`."""
if isinstance(tensors[0], sparse_tensor.SparseTensor):
return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
if isinstance(tensors[0], ragged_tensor.RaggedTensor):
return ragged_concat_ops.concat(tensors, axis=axis)
return array_ops.concat(tensors, axis=axis)

View File

@ -37,7 +37,6 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
@ -1205,7 +1204,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
if merge_mode == 'ave':
merge_func = lambda y, y_rev: (y + y_rev) / 2
elif merge_mode == 'concat':
merge_func = lambda y, y_rev: ragged_concat_ops.concat(
merge_func = lambda y, y_rev: array_ops.concat(
(y, y_rev), axis=-1)
elif merge_mode == 'mul':
merge_func = lambda y, y_rev: (y * y_rev)