Use TensorShape.assert_is_compatible_with instead of merge_with
This commit is contained in:
parent
2b31ba7d0a
commit
2ef80f2295
@ -146,10 +146,10 @@ class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
|
||||
dense_shape_shape = dense_shape.shape.with_rank(1)
|
||||
|
||||
# Assert number of rows in indices match the number of elements in values.
|
||||
indices_shape.dims[0].merge_with(values_shape.dims[0])
|
||||
indices_shape.dims[0].assert_is_compatible_with(values_shape.dims[0])
|
||||
# Assert number of columns in indices matches the number of elements in
|
||||
# dense_shape.
|
||||
indices_shape.dims[1].merge_with(dense_shape_shape.dims[0])
|
||||
indices_shape.dims[1].assert_is_compatible_with(dense_shape_shape.dims[0])
|
||||
|
||||
def get_shape(self):
|
||||
"""Get the `TensorShape` representing the shape of the dense tensor.
|
||||
|
@ -111,10 +111,10 @@ def clip_by_value(t, clip_value_min, clip_value_max,
|
||||
t_min = math_ops.minimum(values, clip_value_max)
|
||||
# Assert that the shape is compatible with the initial shape,
|
||||
# to prevent unintentional broadcasting.
|
||||
_ = values.shape.merge_with(t_min.shape)
|
||||
values.shape.assert_is_compatible_with(t_min.shape)
|
||||
|
||||
t_max = math_ops.maximum(t_min, clip_value_min, name=name)
|
||||
_ = values.shape.merge_with(t_max.shape)
|
||||
values.shape.assert_is_compatible_with(t_max.shape)
|
||||
|
||||
if isinstance(t, ops.IndexedSlices):
|
||||
t_max = ops.IndexedSlices(t_max, t.indices, t.dense_shape)
|
||||
@ -225,7 +225,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
|
||||
intermediate = values * clip_norm
|
||||
# Assert that the shape is compatible with the initial shape,
|
||||
# to prevent unintentional broadcasting.
|
||||
_ = values.shape.merge_with(intermediate.shape)
|
||||
values.shape.assert_is_compatible_with(intermediate.shape)
|
||||
values_clip = array_ops.identity(
|
||||
intermediate / math_ops.maximum(l2norm, clip_norm), name=name)
|
||||
|
||||
|
@ -675,7 +675,7 @@ def scan(fn,
|
||||
tensor_shape.dimension_value(
|
||||
elems_flat[0].get_shape().with_rank_at_least(1)[0]))
|
||||
for elem in elems_flat[1:]:
|
||||
n_static.merge_with(
|
||||
n_static.assert_is_compatible_with(
|
||||
tensor_shape.Dimension(
|
||||
tensor_shape.dimension_value(
|
||||
elem.get_shape().with_rank_at_least(1)[0])))
|
||||
|
@ -445,7 +445,7 @@ def map_fn(fn,
|
||||
tensor_shape.dimension_value(
|
||||
elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
|
||||
for tensor in elems_batchable[1:]:
|
||||
n_static.merge_with(
|
||||
n_static.assert_is_compatible_with(
|
||||
tensor_shape.Dimension(
|
||||
tensor_shape.dimension_value(
|
||||
tensor.get_shape().with_rank_at_least(1)[0])))
|
||||
|
@ -88,7 +88,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
|
||||
log_input = ops.convert_to_tensor(log_input, name="log_input")
|
||||
targets = ops.convert_to_tensor(targets, name="targets")
|
||||
try:
|
||||
targets.get_shape().merge_with(log_input.get_shape())
|
||||
targets.get_shape().assert_is_compatible_with(log_input.get_shape())
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"log_input and targets must have the same shape (%s vs %s)" %
|
||||
@ -168,7 +168,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name
|
||||
logits = ops.convert_to_tensor(logits, name="logits")
|
||||
labels = ops.convert_to_tensor(labels, name="labels")
|
||||
try:
|
||||
labels.get_shape().merge_with(logits.get_shape())
|
||||
labels.get_shape().assert_is_compatible_with(logits.get_shape())
|
||||
except ValueError:
|
||||
raise ValueError("logits and labels must have the same shape (%s vs %s)" %
|
||||
(logits.get_shape(), labels.get_shape()))
|
||||
@ -304,7 +304,7 @@ def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
|
||||
logits = ops.convert_to_tensor(logits, name="logits")
|
||||
labels = ops.convert_to_tensor(labels, name="labels")
|
||||
try:
|
||||
labels.get_shape().merge_with(logits.get_shape())
|
||||
labels.get_shape().assert_is_compatible_with(logits.get_shape())
|
||||
except ValueError:
|
||||
raise ValueError("logits and labels must have the same shape (%s vs %s)" %
|
||||
(logits.get_shape(), labels.get_shape()))
|
||||
|
@ -318,7 +318,7 @@ def _reverse_seq(input_seq, lengths):
|
||||
for sequence in zip(*flat_input_seq):
|
||||
input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank)
|
||||
for input_ in sequence:
|
||||
input_shape.merge_with(input_.get_shape())
|
||||
input_shape.assert_is_compatible_with(input_.get_shape())
|
||||
input_.set_shape(input_shape)
|
||||
|
||||
# Join into (time, batch_size, depth)
|
||||
@ -1112,7 +1112,7 @@ def raw_rnn(cell,
|
||||
|
||||
for input_shape_i in input_shape:
|
||||
# Static verification that batch sizes all match
|
||||
static_batch_size.merge_with(
|
||||
static_batch_size.assert_is_compatible_with(
|
||||
tensor_shape.dimension_at_index(input_shape_i, 0))
|
||||
|
||||
batch_size = tensor_shape.dimension_value(static_batch_size)
|
||||
@ -1339,7 +1339,7 @@ def static_rnn(cell,
|
||||
input_shape = flat_input.get_shape().with_rank_at_least(2)
|
||||
batch_size, input_size = tensor_shape.dimension_at_index(
|
||||
input_shape, 0), input_shape[1:]
|
||||
fixed_batch_size.merge_with(batch_size)
|
||||
fixed_batch_size.assert_is_compatible_with(batch_size)
|
||||
for i, size in enumerate(input_size.dims):
|
||||
if tensor_shape.dimension_value(size) is None:
|
||||
raise ValueError(
|
||||
|
@ -1905,7 +1905,7 @@ def sparse_retain(sp_input, to_retain):
|
||||
retain_shape = to_retain.get_shape()
|
||||
retain_shape.assert_has_rank(1)
|
||||
if sp_input.values.get_shape().dims is not None:
|
||||
sp_input.values.get_shape().dims[0].merge_with(
|
||||
sp_input.values.get_shape().dims[0].assert_is_compatible_with(
|
||||
tensor_shape.dimension_at_index(retain_shape, 0))
|
||||
|
||||
where_true = array_ops.reshape(array_ops.where_v2(to_retain), [-1])
|
||||
@ -1993,7 +1993,7 @@ def sparse_reset_shape(sp_input, new_shape=None):
|
||||
# For cases when shape is known during graph construction, this catches the
|
||||
# error before the sparse_tensor.SparseTensor catches it.
|
||||
if output_shape_tensor.get_shape().rank is not None:
|
||||
output_shape_tensor.get_shape().dims[0].merge_with(
|
||||
output_shape_tensor.get_shape().dims[0].assert_is_compatible_with(
|
||||
in_shape.get_shape().dims[0])
|
||||
|
||||
output_shape_tensor_const = tensor_util.constant_value(output_shape_tensor)
|
||||
|
Loading…
Reference in New Issue
Block a user