Do not use deprecated get_shape
PiperOrigin-RevId: 270957229
This commit is contained in:
parent
8e3f8ae473
commit
8e571e4306
@ -252,7 +252,7 @@ def einsum(equation, *inputs, **kwargs):
|
||||
[format(key) for key in sorted(list(kwargs.keys()))]))
|
||||
with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
|
||||
inputs = list(inputs)
|
||||
input_shapes = [x.get_shape() for x in inputs]
|
||||
input_shapes = [x.shape for x in inputs]
|
||||
input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
|
||||
equation, input_shapes)
|
||||
|
||||
@ -417,14 +417,14 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
|
||||
`t0_axis_labels`, or that of `t1` does not match the length of
|
||||
`t1_axis_labels`.
|
||||
"""
|
||||
if len(t0_axis_labels) != len(t0.get_shape()):
|
||||
if len(t0_axis_labels) != len(t0.shape):
|
||||
raise ValueError(
|
||||
'Tensor t0 of rank %d does not match einsum reduction of length %d' %
|
||||
(len(t0.get_shape()), len(t0_axis_labels)))
|
||||
if len(t1_axis_labels) != len(t1.get_shape()):
|
||||
(len(t0.shape), len(t0_axis_labels)))
|
||||
if len(t1_axis_labels) != len(t1.shape):
|
||||
raise ValueError(
|
||||
'Tensor t1 of rank %d does not match einsum reduction of length %d' %
|
||||
(len(t1.get_shape()), len(t1_axis_labels)))
|
||||
(len(t1.shape), len(t1_axis_labels)))
|
||||
|
||||
# This function computes the result of a two-argument einsum() using batch
|
||||
# matrix multiplication. This involves
|
||||
@ -531,7 +531,7 @@ def _reshape_if_necessary(tensor, new_shape):
|
||||
"""Like reshape(), but avoids creating a new tensor if possible."""
|
||||
# Accept None as an alias for -1 in new_shape.
|
||||
new_shape = tuple(-1 if x is None else x for x in new_shape)
|
||||
cur_shape = tuple(x.value for x in tensor.get_shape().dims)
|
||||
cur_shape = tuple(x.value for x in tensor.shape.dims)
|
||||
if (len(new_shape) == len(cur_shape) and
|
||||
all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
|
||||
for d0, d1 in zip(cur_shape, new_shape))):
|
||||
@ -544,7 +544,7 @@ def _get_shape(tensor):
|
||||
"""Like get_shape().as_list(), but explicitly queries the shape of a tensor
|
||||
if necessary to ensure that the returned value contains no unknown value."""
|
||||
|
||||
shape = tensor.get_shape().as_list()
|
||||
shape = tensor.shape.as_list()
|
||||
none_indices = [i for i, d in enumerate(shape) if d is None]
|
||||
if none_indices:
|
||||
# Query the shape if shape contains None values
|
||||
@ -569,7 +569,7 @@ def _total_size(shape_values):
|
||||
def _exponential_space_einsum(equation, *inputs):
|
||||
"""Fallback implementation that supports summing an index over > 2 inputs."""
|
||||
inputs = list(inputs)
|
||||
input_shapes = [x.get_shape() for x in inputs]
|
||||
input_shapes = [x.shape for x in inputs]
|
||||
idx_in, idx_out = _einsum_parse_and_resolve_equation(equation, input_shapes)
|
||||
|
||||
idx_all = set(''.join(idx_in) + idx_out)
|
||||
@ -588,11 +588,11 @@ def _exponential_space_einsum(equation, *inputs):
|
||||
|
||||
# transpose inputs so axes are in order
|
||||
for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
|
||||
if input_.get_shape().ndims != len(axes_):
|
||||
if input_.shape.ndims != len(axes_):
|
||||
raise ValueError(
|
||||
'Input %d with axes %s has incorrect' \
|
||||
' number of dimensions (expected %d, got %d)' % (
|
||||
i, axes_, len(axes_), input_.get_shape().ndims
|
||||
i, axes_, len(axes_), input_.shape.ndims
|
||||
)
|
||||
)
|
||||
|
||||
@ -609,7 +609,7 @@ def _exponential_space_einsum(equation, *inputs):
|
||||
|
||||
reduction_idx = []
|
||||
shapes = [[dim if dim else -1
|
||||
for dim in tensor.get_shape().as_list()]
|
||||
for dim in tensor.shape.as_list()]
|
||||
for tensor in inputs]
|
||||
|
||||
# validate shapes for broadcasting
|
||||
|
Loading…
Reference in New Issue
Block a user