Do not use deprecated get_shape

PiperOrigin-RevId: 270957229
commit 8e571e4306 (parent 8e3f8ae473)
Author:    Alexandre Passos
Committer: TensorFlower Gardener
Date:      2019-09-24 12:05:00 -07:00

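The change is mechanical: in TensorFlow, Tensor.get_shape() is an alias for
the Tensor.shape property, and both return a tf.TensorShape, so each call
site below swaps the deprecated spelling for the property. A minimal sketch
of the equivalence:

    import tensorflow as tf

    x = tf.zeros([2, 3])
    # get_shape() is the older alias; the .shape property replaces it.
    assert x.shape == x.get_shape()  # both return a tf.TensorShape
    print(x.shape.as_list())         # [2, 3]
    print(len(x.shape))              # 2: rank of a fully-known shape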

@@ -252,7 +252,7 @@ def einsum(equation, *inputs, **kwargs):
         [format(key) for key in sorted(list(kwargs.keys()))]))
   with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
     inputs = list(inputs)
-    input_shapes = [x.get_shape() for x in inputs]
+    input_shapes = [x.shape for x in inputs]
     input_axis_labels, output_axis_labels = _einsum_parse_and_resolve_equation(
         equation, input_shapes)
@@ -417,14 +417,14 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
       `t0_axis_labels`, or that of `t1` does not match the length of
       `t1_axis_labels`.
   """
-  if len(t0_axis_labels) != len(t0.get_shape()):
+  if len(t0_axis_labels) != len(t0.shape):
     raise ValueError(
         'Tensor t0 of rank %d does not match einsum reduction of length %d' %
-        (len(t0.get_shape()), len(t0_axis_labels)))
-  if len(t1_axis_labels) != len(t1.get_shape()):
+        (len(t0.shape), len(t0_axis_labels)))
+  if len(t1_axis_labels) != len(t1.shape):
     raise ValueError(
         'Tensor t1 of rank %d does not match einsum reduction of length %d' %
-        (len(t1.get_shape()), len(t1_axis_labels)))
+        (len(t1.shape), len(t1_axis_labels)))

   # This function computes the result of a two-argument einsum() using batch
   # matrix multiplication.  This involves
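For a tf.TensorShape of known rank, len(shape) is the rank, so
len(t0.shape) behaves exactly like the deprecated len(t0.get_shape()).
A standalone sketch of the validation above (_check_rank is a hypothetical
helper name, not part of this file):

    import tensorflow as tf

    def _check_rank(t, axis_labels, name):
      # The tensor's rank must equal the number of einsum axis labels
      # (e.g. 'abc' -> rank 3), mirroring the checks in _einsum_reduction.
      if len(axis_labels) != len(t.shape):
        raise ValueError(
            'Tensor %s of rank %d does not match einsum reduction of '
            'length %d' % (name, len(t.shape), len(axis_labels)))

    _check_rank(tf.zeros([4, 5, 6]), 'abc', 't0')  # rank 3, 3 labels: ok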
@@ -531,7 +531,7 @@ def _reshape_if_necessary(tensor, new_shape):
   """Like reshape(), but avoids creating a new tensor if possible."""
   # Accept None as an alias for -1 in new_shape.
   new_shape = tuple(-1 if x is None else x for x in new_shape)
-  cur_shape = tuple(x.value for x in tensor.get_shape().dims)
+  cur_shape = tuple(x.value for x in tensor.shape.dims)
   if (len(new_shape) == len(cur_shape) and
       all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
           for d0, d1 in zip(cur_shape, new_shape))):
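TensorShape.dims is untouched by the migration: tensor.shape.dims still
yields Dimension objects whose .value is an int, or None for a statically
unknown dimension, which is what the short-circuit comparison relies on.
A quick sketch:

    import tensorflow as tf

    x = tf.zeros([2, 3])
    # .dims yields Dimension objects; .value is an int, or None if unknown.
    cur_shape = tuple(d.value for d in x.shape.dims)
    print(cur_shape)  # (2, 3)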
@@ -544,7 +544,7 @@ def _get_shape(tensor):
   """Like get_shape().as_list(), but explicitly queries the shape of a tensor
   if necessary to ensure that the returned value contains no unknown value."""

-  shape = tensor.get_shape().as_list()
+  shape = tensor.shape.as_list()
   none_indices = [i for i, d in enumerate(shape) if d is None]
   if none_indices:
     # Query the shape if shape contains None values
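_get_shape falls back to the dynamic shape op for any dimension whose
static value is None. A self-contained sketch of the same pattern, using
the public tf.shape in place of the internal array_ops module:

    import tensorflow as tf

    def static_or_dynamic_shape(tensor):
      # Start from the static shape and query the graph only for the
      # dimensions that are statically unknown (None).
      shape = tensor.shape.as_list()
      dynamic = tf.shape(tensor)
      return [dynamic[i] if dim is None else dim
              for i, dim in enumerate(shape)]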
@@ -569,7 +569,7 @@ def _total_size(shape_values):

 def _exponential_space_einsum(equation, *inputs):
   """Fallback implementation that supports summing an index over > 2 inputs."""
   inputs = list(inputs)
-  input_shapes = [x.get_shape() for x in inputs]
+  input_shapes = [x.shape for x in inputs]
   idx_in, idx_out = _einsum_parse_and_resolve_equation(equation, input_shapes)
   idx_all = set(''.join(idx_in) + idx_out)
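_einsum_parse_and_resolve_equation is private, but what it produces is easy
to illustrate: per-input axis labels plus the output labels. A rough sketch
for the explicit form of an equation (the real parser also handles implicit
equations and validates labels against the input shapes):

    equation = 'ij,jk->ik'
    lhs, idx_out = equation.split('->')
    idx_in = lhs.split(',')
    print(idx_in, idx_out)  # ['ij', 'jk'] ik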
@@ -588,11 +588,11 @@ def _exponential_space_einsum(equation, *inputs):

   # transpose inputs so axes are in order
   for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
-    if input_.get_shape().ndims != len(axes_):
+    if input_.shape.ndims != len(axes_):
       raise ValueError(
           'Input %d with axes %s has incorrect' \
           ' number of dimensions (expected %d, got %d)' % (
-              i, axes_, len(axes_), input_.get_shape().ndims
+              i, axes_, len(axes_), input_.shape.ndims
           )
       )
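TensorShape.ndims also survives the migration unchanged. Unlike len(),
which raises on a shape of unknown rank, .ndims returns None there, so the
comparison against len(axes_) fails cleanly instead of throwing:

    import tensorflow as tf

    known = tf.TensorShape([2, 3])
    unknown = tf.TensorShape(None)  # unknown rank
    print(known.ndims, len(known))  # 2 2
    print(unknown.ndims)            # None (len(unknown) would raise)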
@@ -609,7 +609,7 @@ def _exponential_space_einsum(equation, *inputs):

   reduction_idx = []
   shapes = [[dim if dim else -1
-             for dim in tensor.get_shape().as_list()]
+             for dim in tensor.shape.as_list()]
            for tensor in inputs]

   # validate shapes for broadcasting
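The comprehension maps statically unknown (None) dimensions to -1 before
the broadcasting checks; note that "dim if dim else -1" would also map a
literal zero-sized dimension to -1, since 0 is falsy. A minimal
illustration:

    import tensorflow as tf

    shape = tf.TensorShape([None, 3])
    dims = [dim if dim else -1 for dim in shape.as_list()]
    print(dims)  # [-1, 3]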