Update the usages of ragged_concat_ops.concat
to array_ops.concat
in Keras codebase.
PiperOrigin-RevId: 340678363 Change-Id: Ic366ee468156c3a89ea2121c2c4eeec5857b6c06
This commit is contained in:
parent
a3c9e996fb
commit
de5e35d7fe
@ -77,7 +77,6 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import moving_averages
|
||||
@ -3093,7 +3092,7 @@ def concatenate(tensors, axis=-1):
|
||||
if py_all(is_sparse(x) for x in tensors):
|
||||
return sparse_ops.sparse_concat(axis, tensors)
|
||||
elif py_all(isinstance(x, ragged_tensor.RaggedTensor) for x in tensors):
|
||||
return ragged_concat_ops.concat(tensors, axis)
|
||||
return array_ops.concat(tensors, axis)
|
||||
else:
|
||||
return array_ops.concat([to_dense(x) for x in tensors], axis)
|
||||
|
||||
|
@ -42,11 +42,11 @@ from tensorflow.python.keras.engine import training_utils_v1
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
@ -1154,5 +1154,5 @@ def concat_along_batch_dimension(outputs):
|
||||
if isinstance(outputs[0], sparse_tensor.SparseTensor):
|
||||
return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
|
||||
if isinstance(outputs[0], ragged_tensor.RaggedTensor):
|
||||
return ragged_concat_ops.concat(outputs, axis=0)
|
||||
return array_ops.concat(outputs, axis=0)
|
||||
return np.concatenate(outputs)
|
||||
|
@ -70,8 +70,6 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import trace
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
@ -2724,8 +2722,6 @@ def concat(tensors, axis=0):
|
||||
"""Concats `tensor`s along `axis`."""
|
||||
if isinstance(tensors[0], sparse_tensor.SparseTensor):
|
||||
return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
|
||||
if isinstance(tensors[0], ragged_tensor.RaggedTensor):
|
||||
return ragged_concat_ops.concat(tensors, axis=axis)
|
||||
return array_ops.concat(tensors, axis=axis)
|
||||
|
||||
|
||||
|
@ -37,7 +37,6 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
@ -1205,7 +1204,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
|
||||
if merge_mode == 'ave':
|
||||
merge_func = lambda y, y_rev: (y + y_rev) / 2
|
||||
elif merge_mode == 'concat':
|
||||
merge_func = lambda y, y_rev: ragged_concat_ops.concat(
|
||||
merge_func = lambda y, y_rev: array_ops.concat(
|
||||
(y, y_rev), axis=-1)
|
||||
elif merge_mode == 'mul':
|
||||
merge_func = lambda y, y_rev: (y * y_rev)
|
||||
|
Loading…
x
Reference in New Issue
Block a user