From 2ef80f22951331b80347dfe1ffe84fb1be1b62f4 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Fri, 9 Oct 2020 11:28:09 +0100 Subject: [PATCH] Use TensorShape.assert_is_compatible_with instead of merge_with --- tensorflow/python/framework/sparse_tensor.py | 4 ++-- tensorflow/python/ops/clip_ops.py | 6 +++--- tensorflow/python/ops/functional_ops.py | 2 +- tensorflow/python/ops/map_fn.py | 2 +- tensorflow/python/ops/nn_impl.py | 6 +++--- tensorflow/python/ops/rnn.py | 6 +++--- tensorflow/python/ops/sparse_ops.py | 4 ++-- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 5704563a92e..e7e8ea33b01 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -146,10 +146,10 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor): dense_shape_shape = dense_shape.shape.with_rank(1) # Assert number of rows in indices match the number of elements in values. - indices_shape.dims[0].merge_with(values_shape.dims[0]) + indices_shape.dims[0].assert_is_compatible_with(values_shape.dims[0]) # Assert number of columns in indices matches the number of elements in # dense_shape. - indices_shape.dims[1].merge_with(dense_shape_shape.dims[0]) + indices_shape.dims[1].assert_is_compatible_with(dense_shape_shape.dims[0]) def get_shape(self): """Get the `TensorShape` representing the shape of the dense tensor. diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 1045ff692ea..3a42ae0ff45 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -111,10 +111,10 @@ def clip_by_value(t, clip_value_min, clip_value_max, t_min = math_ops.minimum(values, clip_value_max) # Assert that the shape is compatible with the initial shape, # to prevent unintentional broadcasting. - _ = values.shape.merge_with(t_min.shape) + values.shape.assert_is_compatible_with(t_min.shape) t_max = math_ops.maximum(t_min, clip_value_min, name=name) - _ = values.shape.merge_with(t_max.shape) + values.shape.assert_is_compatible_with(t_max.shape) if isinstance(t, ops.IndexedSlices): t_max = ops.IndexedSlices(t_max, t.indices, t.dense_shape) @@ -225,7 +225,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None): intermediate = values * clip_norm # Assert that the shape is compatible with the initial shape, # to prevent unintentional broadcasting. - _ = values.shape.merge_with(intermediate.shape) + values.shape.assert_is_compatible_with(intermediate.shape) values_clip = array_ops.identity( intermediate / math_ops.maximum(l2norm, clip_norm), name=name) diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index b51d1baa6c0..bdd20cda991 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -675,7 +675,7 @@ def scan(fn, tensor_shape.dimension_value( elems_flat[0].get_shape().with_rank_at_least(1)[0])) for elem in elems_flat[1:]: - n_static.merge_with( + n_static.assert_is_compatible_with( tensor_shape.Dimension( tensor_shape.dimension_value( elem.get_shape().with_rank_at_least(1)[0]))) diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index edea769f663..af681592e67 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -445,7 +445,7 @@ def map_fn(fn, tensor_shape.dimension_value( elems_batchable[0].get_shape().with_rank_at_least(1)[0])) for tensor in elems_batchable[1:]: - n_static.merge_with( + n_static.assert_is_compatible_with( tensor_shape.Dimension( tensor_shape.dimension_value( tensor.get_shape().with_rank_at_least(1)[0]))) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index d22fbf3fa4e..5ec95b6646d 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -88,7 +88,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): log_input = ops.convert_to_tensor(log_input, name="log_input") targets = ops.convert_to_tensor(targets, name="targets") try: - targets.get_shape().merge_with(log_input.get_shape()) + targets.get_shape().assert_is_compatible_with(log_input.get_shape()) except ValueError: raise ValueError( "log_input and targets must have the same shape (%s vs %s)" % @@ -168,7 +168,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name logits = ops.convert_to_tensor(logits, name="logits") labels = ops.convert_to_tensor(labels, name="labels") try: - labels.get_shape().merge_with(logits.get_shape()) + labels.get_shape().assert_is_compatible_with(logits.get_shape()) except ValueError: raise ValueError("logits and labels must have the same shape (%s vs %s)" % (logits.get_shape(), labels.get_shape())) @@ -304,7 +304,7 @@ def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, logits = ops.convert_to_tensor(logits, name="logits") labels = ops.convert_to_tensor(labels, name="labels") try: - labels.get_shape().merge_with(logits.get_shape()) + labels.get_shape().assert_is_compatible_with(logits.get_shape()) except ValueError: raise ValueError("logits and labels must have the same shape (%s vs %s)" % (logits.get_shape(), labels.get_shape())) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 6c11ebefb1c..32dc9e38cb0 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -318,7 +318,7 @@ def _reverse_seq(input_seq, lengths): for sequence in zip(*flat_input_seq): input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank) for input_ in sequence: - input_shape.merge_with(input_.get_shape()) + input_shape.assert_is_compatible_with(input_.get_shape()) input_.set_shape(input_shape) # Join into (time, batch_size, depth) @@ -1112,7 +1112,7 @@ def raw_rnn(cell, for input_shape_i in input_shape: # Static verification that batch sizes all match - static_batch_size.merge_with( + static_batch_size.assert_is_compatible_with( tensor_shape.dimension_at_index(input_shape_i, 0)) batch_size = tensor_shape.dimension_value(static_batch_size) @@ -1339,7 +1339,7 @@ def static_rnn(cell, input_shape = flat_input.get_shape().with_rank_at_least(2) batch_size, input_size = tensor_shape.dimension_at_index( input_shape, 0), input_shape[1:] - fixed_batch_size.merge_with(batch_size) + fixed_batch_size.assert_is_compatible_with(batch_size) for i, size in enumerate(input_size.dims): if tensor_shape.dimension_value(size) is None: raise ValueError( diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 18b7561b113..3e3751e3ca6 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1905,7 +1905,7 @@ def sparse_retain(sp_input, to_retain): retain_shape = to_retain.get_shape() retain_shape.assert_has_rank(1) if sp_input.values.get_shape().dims is not None: - sp_input.values.get_shape().dims[0].merge_with( + sp_input.values.get_shape().dims[0].assert_is_compatible_with( tensor_shape.dimension_at_index(retain_shape, 0)) where_true = array_ops.reshape(array_ops.where_v2(to_retain), [-1]) @@ -1993,7 +1993,7 @@ def sparse_reset_shape(sp_input, new_shape=None): # For cases when shape is known during graph construction, this catches the # error before the sparse_tensor.SparseTensor catches it. if output_shape_tensor.get_shape().rank is not None: - output_shape_tensor.get_shape().dims[0].merge_with( + output_shape_tensor.get_shape().dims[0].assert_is_compatible_with( in_shape.get_shape().dims[0]) output_shape_tensor_const = tensor_util.constant_value(output_shape_tensor)