Use TensorShape.assert_is_compatible_with instead of merge_with

Lukas Geiger 2020-10-09 11:28:09 +01:00
parent 2b31ba7d0a
commit 2ef80f2295
7 changed files with 15 additions and 15 deletions
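
Context: `TensorShape.merge_with` (and `Dimension.merge_with`) builds and returns a new merged shape, raising `ValueError` on incompatibility, while `assert_is_compatible_with` performs only the check. Every call site below discards the merged result, so the change is behavior-preserving and makes the intent explicit. A minimal standalone sketch of the difference (shapes invented for illustration):

    import tensorflow as tf

    a = tf.TensorShape([None, 3])
    b = tf.TensorShape([5, 3])

    merged = a.merge_with(b)        # returns the most specific shape, (5, 3)
    a.assert_is_compatible_with(b)  # returns None; only checks

    # Both raise ValueError on a mismatch:
    # a.merge_with(tf.TensorShape([5, 4]))
    # a.assert_is_compatible_with(tf.TensorShape([5, 4]))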

tensorflow/python/framework/sparse_tensor.py

@@ -146,10 +146,10 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
     dense_shape_shape = dense_shape.shape.with_rank(1)

     # Assert number of rows in indices match the number of elements in values.
-    indices_shape.dims[0].merge_with(values_shape.dims[0])
+    indices_shape.dims[0].assert_is_compatible_with(values_shape.dims[0])
     # Assert number of columns in indices matches the number of elements in
     # dense_shape.
-    indices_shape.dims[1].merge_with(dense_shape_shape.dims[0])
+    indices_shape.dims[1].assert_is_compatible_with(dense_shape_shape.dims[0])

   def get_shape(self):
     """Get the `TensorShape` representing the shape of the dense tensor.

tensorflow/python/ops/clip_ops.py

@@ -111,10 +111,10 @@ def clip_by_value(t, clip_value_min, clip_value_max,
     t_min = math_ops.minimum(values, clip_value_max)
     # Assert that the shape is compatible with the initial shape,
     # to prevent unintentional broadcasting.
-    _ = values.shape.merge_with(t_min.shape)
+    values.shape.assert_is_compatible_with(t_min.shape)

     t_max = math_ops.maximum(t_min, clip_value_min, name=name)
-    _ = values.shape.merge_with(t_max.shape)
+    values.shape.assert_is_compatible_with(t_max.shape)

     if isinstance(t, ops.IndexedSlices):
       t_max = ops.IndexedSlices(t_max, t.indices, t.dense_shape)
@@ -225,7 +225,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
     intermediate = values * clip_norm
     # Assert that the shape is compatible with the initial shape,
     # to prevent unintentional broadcasting.
-    _ = values.shape.merge_with(intermediate.shape)
+    values.shape.assert_is_compatible_with(intermediate.shape)

     values_clip = array_ops.identity(
         intermediate / math_ops.maximum(l2norm, clip_norm), name=name)
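
In these call sites the old code already threw away the merged shape (`_ = ...`), so only the check mattered. The "unintentional broadcasting" the comment guards against is a clip value whose shape silently changes the result shape; a standalone illustration (tensors invented for the example):

    import tensorflow as tf

    values = tf.zeros([2, 3])
    t_min = tf.minimum(values, tf.zeros([4, 2, 3]))      # broadcasts to [4, 2, 3]
    values.shape.assert_is_compatible_with(t_min.shape)  # raises ValueError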

tensorflow/python/ops/functional_ops.py

@@ -675,7 +675,7 @@ def scan(fn,
         tensor_shape.dimension_value(
             elems_flat[0].get_shape().with_rank_at_least(1)[0]))
     for elem in elems_flat[1:]:
-      n_static.merge_with(
+      n_static.assert_is_compatible_with(
           tensor_shape.Dimension(
               tensor_shape.dimension_value(
                   elem.get_shape().with_rank_at_least(1)[0])))

tensorflow/python/ops/map_fn.py

@@ -445,7 +445,7 @@ def map_fn(fn,
         tensor_shape.dimension_value(
             elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
     for tensor in elems_batchable[1:]:
-      n_static.merge_with(
+      n_static.assert_is_compatible_with(
           tensor_shape.Dimension(
               tensor_shape.dimension_value(
                   tensor.get_shape().with_rank_at_least(1)[0])))
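
In the `scan` and `map_fn` hunks the receiver is a `tensor_shape.Dimension`, not a `TensorShape`. `Dimension.merge_with` likewise returns a new `Dimension` rather than mutating `n_static`, and that result was never used, so only the compatibility check survives. A small sketch of the `Dimension` semantics:

    from tensorflow.python.framework import tensor_shape

    n_static = tensor_shape.Dimension(8)
    n_static.assert_is_compatible_with(tensor_shape.Dimension(None))  # unknown is compatible
    n_static.assert_is_compatible_with(tensor_shape.Dimension(8))     # equal sizes pass
    # n_static.assert_is_compatible_with(tensor_shape.Dimension(9))   # ValueError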

tensorflow/python/ops/nn_impl.py

@@ -88,7 +88,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
     log_input = ops.convert_to_tensor(log_input, name="log_input")
     targets = ops.convert_to_tensor(targets, name="targets")
     try:
-      targets.get_shape().merge_with(log_input.get_shape())
+      targets.get_shape().assert_is_compatible_with(log_input.get_shape())
     except ValueError:
       raise ValueError(
           "log_input and targets must have the same shape (%s vs %s)" %
@@ -168,7 +168,7 @@ def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
     logits = ops.convert_to_tensor(logits, name="logits")
     labels = ops.convert_to_tensor(labels, name="labels")
     try:
-      labels.get_shape().merge_with(logits.get_shape())
+      labels.get_shape().assert_is_compatible_with(logits.get_shape())
     except ValueError:
       raise ValueError("logits and labels must have the same shape (%s vs %s)" %
                        (logits.get_shape(), labels.get_shape()))
@@ -304,7 +304,7 @@ def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
     logits = ops.convert_to_tensor(logits, name="logits")
     labels = ops.convert_to_tensor(labels, name="labels")
     try:
-      labels.get_shape().merge_with(logits.get_shape())
+      labels.get_shape().assert_is_compatible_with(logits.get_shape())
     except ValueError:
       raise ValueError("logits and labels must have the same shape (%s vs %s)" %
                        (logits.get_shape(), labels.get_shape()))
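
All three call sites wrap the check in try/except to re-raise with a friendlier message; since `assert_is_compatible_with` raises the same `ValueError` that `merge_with` did, the pattern is unchanged. For example (shapes invented for illustration):

    import tensorflow as tf

    labels = tf.zeros([2, 3])
    logits = tf.zeros([2, 4])
    try:
      labels.get_shape().assert_is_compatible_with(logits.get_shape())
    except ValueError:
      raise ValueError("logits and labels must have the same shape (%s vs %s)" %
                       (logits.get_shape(), labels.get_shape()))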

tensorflow/python/ops/rnn.py

@@ -318,7 +318,7 @@ def _reverse_seq(input_seq, lengths):
   for sequence in zip(*flat_input_seq):
     input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank)
     for input_ in sequence:
-      input_shape.merge_with(input_.get_shape())
+      input_shape.assert_is_compatible_with(input_.get_shape())
       input_.set_shape(input_shape)

     # Join into (time, batch_size, depth)
@@ -1112,7 +1112,7 @@ def raw_rnn(cell,
     for input_shape_i in input_shape:
       # Static verification that batch sizes all match
-      static_batch_size.merge_with(
+      static_batch_size.assert_is_compatible_with(
           tensor_shape.dimension_at_index(input_shape_i, 0))

     batch_size = tensor_shape.dimension_value(static_batch_size)
@@ -1339,7 +1339,7 @@ def static_rnn(cell,
       input_shape = flat_input.get_shape().with_rank_at_least(2)
       batch_size, input_size = tensor_shape.dimension_at_index(
           input_shape, 0), input_shape[1:]
-      fixed_batch_size.merge_with(batch_size)
+      fixed_batch_size.assert_is_compatible_with(batch_size)
       for i, size in enumerate(input_size.dims):
         if tensor_shape.dimension_value(size) is None:
           raise ValueError(
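
The RNN sites statically verify that every input carries the same batch dimension. Because an unknown dimension (`None`) is compatible with any size, partially known shapes still pass. A sketch with assumed shapes:

    from tensorflow.python.framework import tensor_shape

    static_batch_size = tensor_shape.Dimension(32)
    for shape in [tensor_shape.TensorShape([32, 128]),
                  tensor_shape.TensorShape([None, 256])]:  # unknown batch passes
      static_batch_size.assert_is_compatible_with(
          tensor_shape.dimension_at_index(shape, 0))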

tensorflow/python/ops/sparse_ops.py

@@ -1905,7 +1905,7 @@ def sparse_retain(sp_input, to_retain):
   retain_shape = to_retain.get_shape()
   retain_shape.assert_has_rank(1)
   if sp_input.values.get_shape().dims is not None:
-    sp_input.values.get_shape().dims[0].merge_with(
+    sp_input.values.get_shape().dims[0].assert_is_compatible_with(
         tensor_shape.dimension_at_index(retain_shape, 0))

   where_true = array_ops.reshape(array_ops.where_v2(to_retain), [-1])
@@ -1993,7 +1993,7 @@ def sparse_reset_shape(sp_input, new_shape=None):
     # For cases when shape is known during graph construction, this catches the
     # error before the sparse_tensor.SparseTensor catches it.
     if output_shape_tensor.get_shape().rank is not None:
-      output_shape_tensor.get_shape().dims[0].merge_with(
+      output_shape_tensor.get_shape().dims[0].assert_is_compatible_with(
           in_shape.get_shape().dims[0])

     output_shape_tensor_const = tensor_util.constant_value(output_shape_tensor)
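
Both sparse ops keep their graph-construction-time checks: `sparse_retain` requires `to_retain` to be 1-D with one element per value, and `sparse_reset_shape` requires the new shape to have the same rank as the old one. A usage example for the former (example tensors):

    import tensorflow as tf

    sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 1], [2, 2]],
                                values=[1, 2, 3], dense_shape=[3, 3])
    kept = tf.sparse.retain(sp, [True, False, True])  # keeps values [1, 3]
    # tf.sparse.retain(sp, [True, False])             # ValueError: length mismatch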