diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 867a4b7b3c8..cbb811054af 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -107,7 +107,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): out_grads = [] if isinstance(grad, ops.Tensor): - if context.executing_eagerly(): + if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor): # Using mod here for convenience since concat_dim is already verified # in concat implementation to be within the allowed [-rank, rank) range. non_neg_concat_dim = (