Cleaned up a few *TensorArray._element_shape leaks

PiperOrigin-RevId: 254749223
Sergei Lebedev 2019-06-24 06:56:12 -07:00 committed by TensorFlower Gardener
parent 0bd27b7fd4
commit 1cd4a2b5c8
4 changed files with 69 additions and 191 deletions
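For context, a minimal sketch (using the public tf.TensorArray API; the printed value is illustrative) of the private access pattern this commit removes versus the public property it standardizes on:

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=3, element_shape=tf.TensorShape([2]))

# Before this change, callers reached into the private backing field, whose
# layout differed by mode: a one-element list in graph mode versus a plain
# TensorShape in eager mode (see the first hunk below):
#   graph mode: shape = ta._element_shape[0]
#   eager mode: shape = ta._element_shape
# The public property hides that difference and always returns a TensorShape,
# possibly fully unknown:
print(ta.element_shape)  # (2,)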


@@ -503,29 +503,15 @@ class BeamSearchDecoderMixin(object):
"""
if not isinstance(t, tensor_array_ops.TensorArray):
return t
# pylint: disable=protected-access
# This is a bad hack due to the implementation detail of eager/graph TA.
# TODO(b/124374427): Update this to use public property of TensorArray.
if context.executing_eagerly():
element_shape = t._element_shape
else:
element_shape = t._element_shape[0]
if (not t._infer_shape
or not t._element_shape
or element_shape.ndims is None
or element_shape.ndims < 1):
shape = (
element_shape if t._infer_shape and t._element_shape
else tensor_shape.TensorShape(None))
if t.element_shape.ndims is None or t.element_shape.ndims < 1:
tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
"sorting based on the beam search result. For a "
"TensorArray to be sorted, its elements shape must be "
"defined and have at least a rank of 1, but saw shape: %s"
% (t.handle.name, shape))
% (t.handle.name, t.element_shape))
return t
# pylint: enable=protected-access
if not _check_static_batch_beam_maybe(
element_shape, tensor_util.constant_value(self._batch_size),
t.element_shape, tensor_util.constant_value(self._batch_size),
self._beam_width):
return t
t = t.stack()
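A small sketch (assuming standard TensorShape semantics) of how the simplified rank check above behaves for the three relevant cases, unknown, scalar, and rank >= 1 element shapes:

from tensorflow.python.framework import tensor_shape

# element_shape is now always a TensorShape, so a single ndims check covers
# both "shape never inferred" and "scalar elements":
for shape in (tensor_shape.TensorShape(None),    # unknown rank -> warn, skip sorting
              tensor_shape.TensorShape([]),      # scalar       -> warn, skip sorting
              tensor_shape.TensorShape([8, 4])): # rank 2       -> eligible for sorting
  not_sortable = shape.ndims is None or shape.ndims < 1
  print(shape, not_sortable)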


@@ -1065,6 +1065,15 @@ class TensorArrayTest(test.TestCase):
grad = gradients_impl.gradients(loop(x), [x])[0]
self.assertAllClose(31.0, self.evaluate(grad))
def testShapeAfterWhileLoop(self):
size = 10
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size)
_, ta = control_flow_ops.while_loop(
lambda i, _: i < size,
lambda i, ta: (i + 1, ta.write(i, [[0.]])), [0, ta],
parallel_iterations=1)
self.assertIsNotNone(ta.element_shape.dims)
@test_util.deprecated_graph_mode_only
def testSkipEagerSumOfTwoReadVariablesWithoutRepeatGrad(self):
with self.session(use_gpu=True) as session:


@@ -432,27 +432,13 @@ def _convert_tensorarray_to_flow(tensor_or_tensor_array):
return tensor_or_tensor_array
def _make_tensor_array(ta, t_or_flow):
# pylint: disable=protected-access
new_ta = tensor_array_ops.TensorArray(
dtype=ta.dtype,
handle=ta.handle,
flow=t_or_flow,
infer_shape=ta._infer_shape,
colocate_with_first_write_call=ta._colocate_with_first_write_call)
new_ta._colocate_with = ta._colocate_with
new_ta._element_shape = ta._element_shape
# pylint: enable=protected-access
return new_ta
def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
if len(tensors_or_tensorarrays) != len(tensors_or_flows):
raise ValueError(
"Lengths of original Tensor list and new list do not match: %d vs. %d" %
(len(tensors_or_tensorarrays), len(tensors_or_flows)))
return [
_make_tensor_array(ta, t_or_flow) if isinstance(
tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance(
ta, tensor_array_ops.TensorArray) else t_or_flow
for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
]


@@ -134,12 +134,8 @@ class _GraphTensorArray(object):
# shape is defined either by `element_shape` or the shape of the tensor
# of the first write. If `infer_shape` is true, all writes checks for
# shape equality.
if element_shape is None:
self._infer_shape = infer_shape
self._element_shape = []
else:
self._infer_shape = True
self._element_shape = [tensor_shape.as_shape(element_shape)]
self._element_shape = [tensor_shape.as_shape(element_shape)]
self._infer_shape = element_shape is not None or infer_shape
with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
if handle is not None:
self._handle = handle
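The constructor change above relies on tensor_shape.as_shape(None) producing a fully unknown TensorShape, so the backing list always holds exactly one entry; a quick sketch of that assumption:

from tensorflow.python.framework import tensor_shape

# With element_shape=None the backing list still gets one entry, an unknown
# TensorShape, so element_shape no longer needs to special-case an empty list:
print(tensor_shape.as_shape(None))        # <unknown>
print(tensor_shape.as_shape(None).ndims)  # None
print(tensor_shape.as_shape([2, 3]))      # (2, 3)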
@@ -181,10 +177,7 @@ class _GraphTensorArray(object):
@property
def element_shape(self):
if self._element_shape:
return self._element_shape[0]
else:
return tensor_shape.unknown_shape(None)
return self._element_shape[0]
def _merge_element_shape(self, shape):
"""Changes the element shape of the array given a shape to merge with.
@@ -196,15 +189,11 @@ class _GraphTensorArray(object):
ValueError: if the provided shape is incompatible with the current
element shape of the `TensorArray`.
"""
if self._element_shape:
if not shape.is_compatible_with(self._element_shape[0]):
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
"(and infer_shape=True)" % (shape, self._element_shape[0]))
self._element_shape[0] = self._element_shape[0].merge_with(shape)
else:
self._element_shape.append(shape)
if not shape.is_compatible_with(self.element_shape):
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
"(and infer_shape=True)" % (shape, self.element_shape))
self._element_shape[0] = self.element_shape.merge_with(shape)
@contextlib.contextmanager
def _maybe_colocate_with(self, value):
@@ -230,16 +219,7 @@ class _GraphTensorArray(object):
def identity(self):
"""See TensorArray."""
flow = array_ops.identity(self._flow)
ta = TensorArray(
dtype=self._dtype,
handle=self._handle,
flow=flow,
infer_shape=self._infer_shape,
colocate_with_first_write_call=self._colocate_with_first_write_call)
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
return build_ta_with_new_flow(self, flow)
def grad(self, source, flow=None, name=None):
"""See TensorArray."""
@@ -261,7 +241,9 @@ class _GraphTensorArray(object):
flow=flow,
infer_shape=self._infer_shape,
colocate_with_first_write_call=False)
g._element_shape = self._element_shape
# pylint: disable=protected-access
g._implementation._element_shape = self._element_shape
# pylint: enable=protected-access
return g
def read(self, index, name=None):
@@ -293,16 +275,7 @@ class _GraphTensorArray(object):
value=value,
flow_in=self._flow,
name=name)
ta = TensorArray(
dtype=self._dtype,
handle=self._handle,
flow=flow_out,
colocate_with_first_write_call=self._colocate_with_first_write_call)
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
return build_ta_with_new_flow(self, flow_out)
def stack(self, name=None):
"""See TensorArray."""
@@ -323,25 +296,20 @@ class _GraphTensorArray(object):
dtype=self._dtype,
name=name,
element_shape=element_shape)
if self._element_shape and self._element_shape[0].dims is not None:
value.set_shape([None] + self._element_shape[0].dims)
if self.element_shape:
value.set_shape([None] + self.element_shape.dims)
return value
def concat(self, name=None):
"""See TensorArray."""
if self._element_shape and self._element_shape[0].dims is not None:
element_shape_except0 = (
tensor_shape.TensorShape(self._element_shape[0].dims[1:]))
else:
element_shape_except0 = tensor_shape.TensorShape(None)
value, _ = gen_data_flow_ops.tensor_array_concat_v3(
handle=self._handle,
flow_in=self._flow,
dtype=self._dtype,
name=name,
element_shape_except0=element_shape_except0)
if self._element_shape and self._element_shape[0].dims is not None:
value.set_shape([None] + self._element_shape[0].dims[1:])
element_shape_except0=self.element_shape[1:])
if self.element_shape:
value.set_shape([None] + self.element_shape.dims[1:])
return value
@tf_should_use.should_use_result
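The one-line element_shape_except0 above leans on TensorShape slicing; as far as I know, an open-ended slice of an unknown shape yields another unknown shape, which keeps the unguarded [1:] safe:

from tensorflow.python.framework import tensor_shape

known = tensor_shape.TensorShape([5, 2, 3])
unknown = tensor_shape.TensorShape(None)

# element_shape_except0 can be computed with a single slice in both cases:
print(known[1:])    # (2, 3)
print(unknown[1:])  # <unknown>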
@@ -370,16 +338,7 @@ class _GraphTensorArray(object):
value=value,
flow_in=self._flow,
name=name)
ta = TensorArray(
dtype=self._dtype,
handle=self._handle,
flow=flow_out,
colocate_with_first_write_call=self._colocate_with_first_write_call)
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
return build_ta_with_new_flow(self, flow_out)
@tf_should_use.should_use_result
def split(self, value, lengths, name=None):
@@ -402,16 +361,7 @@ class _GraphTensorArray(object):
lengths=lengths_64,
flow_in=self._flow,
name=name)
ta = TensorArray(
dtype=self._dtype,
handle=self._handle,
flow=flow_out,
colocate_with_first_write_call=self._colocate_with_first_write_call)
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
return build_ta_with_new_flow(self, flow_out)
def size(self, name=None):
"""See TensorArray."""
@@ -496,12 +446,8 @@ class _GraphTensorArrayV2(object):
# shape is defined either by `element_shape` or the shape of the tensor
# of the first write. If `infer_shape` is true, all writes checks for
# shape equality.
if element_shape is None:
self._infer_shape = infer_shape
self._element_shape = []
else:
self._infer_shape = True
self._element_shape = [tensor_shape.as_shape(element_shape)]
self._element_shape = [tensor_shape.as_shape(element_shape)]
self._infer_shape = element_shape is not None or infer_shape
with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
if flow is None:
self._flow = list_ops.tensor_list_reserve(
@@ -526,10 +472,7 @@ class _GraphTensorArrayV2(object):
@property
def element_shape(self):
if self._element_shape:
return self._element_shape[0]
else:
return tensor_shape.unknown_shape(None)
return self._element_shape[0]
@property
def handle(self):
@@ -547,15 +490,11 @@ class _GraphTensorArrayV2(object):
ValueError: if the provided shape is incompatible with the current
element shape of the `TensorArray`.
"""
if self._element_shape:
if not shape.is_compatible_with(self._element_shape[0]):
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
"(and infer_shape=True)" % (shape, self._element_shape[0]))
self._element_shape[0] = self._element_shape[0].merge_with(shape)
else:
self._element_shape.append(shape)
if not shape.is_compatible_with(self.element_shape):
raise ValueError(
"Inconsistent shapes: saw %s but expected %s "
"(and infer_shape=True)" % (shape, self.element_shape))
self._element_shape[0] = self.element_shape.merge_with(shape)
def identity(self):
"""See TensorArray."""
@@ -569,15 +508,11 @@ class _GraphTensorArrayV2(object):
def read(self, index, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_get_item(
input_handle=self._flow,
index=index,
element_dtype=self._dtype,
element_shape=element_shape,
element_shape=self.element_shape,
name=name)
return value
@@ -602,34 +537,26 @@ class _GraphTensorArrayV2(object):
def stack(self, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_stack(
input_handle=self._flow,
element_dtype=self._dtype,
element_shape=element_shape)
element_shape=self.element_shape)
return value
def gather(self, indices, name=None):
"""See TensorArray."""
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_gather(
input_handle=self._flow,
indices=indices,
element_dtype=self._dtype,
element_shape=element_shape,
element_shape=self.element_shape,
name=name)
return value
def concat(self, name=None):
"""See TensorArray."""
if self._element_shape and self._element_shape[0].dims is not None:
element_shape = [None] + self._element_shape[0].dims[1:]
if self.element_shape:
element_shape = [None] + self.element_shape.dims[1:]
else:
element_shape = None
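The `if self.element_shape:` guards in the stack and concat paths rely on TensorShape truthiness, which (to my understanding) is True exactly when the rank is known; a short sketch:

from tensorflow.python.framework import tensor_shape

# TensorShape is falsy only when the rank is unknown, so these guards skip
# the static set_shape / dims slicing only for fully unknown element shapes:
print(bool(tensor_shape.TensorShape(None)))    # False: unknown rank
print(bool(tensor_shape.TensorShape([])))      # True:  known scalar shape
print(bool(tensor_shape.TensorShape([None])))  # True:  rank known, dim unknown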
@@ -665,9 +592,9 @@ class _GraphTensorArrayV2(object):
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
element_shape = self._element_shape[0] if self._element_shape else None
flow_out = list_ops.tensor_list_scatter(
tensor=value, indices=indices, input_handle=self._flow)
tensor=value, indices=indices, element_shape=self.element_shape,
input_handle=self._flow)
return build_ta_with_new_flow(self, flow_out)
@tf_should_use.should_use_result
@@ -689,7 +616,7 @@ class _GraphTensorArrayV2(object):
flow_out = list_ops.tensor_list_split(
tensor=value,
lengths=lengths_64,
element_shape=self._element_shape[0] if self._element_shape else None,
element_shape=self.element_shape,
name=name)
return build_ta_with_new_flow(self, flow_out)
@@ -761,7 +688,7 @@ class _EagerTensorArray(object):
# we assign a dummy value to _flow in case other code assumes it to be
# a Tensor
self._flow = constant_op.constant(0, dtype=dtypes.int32)
self._infer_shape = infer_shape
self._infer_shape = element_shape is not None or infer_shape
self._element_shape = tensor_shape.as_shape(element_shape)
self._colocate_with_first_write_call = colocate_with_first_write_call
@@ -791,10 +718,7 @@ class _EagerTensorArray(object):
@property
def element_shape(self):
if not self._element_shape:
return tensor_shape.unknown_shape(None)
else:
return self._element_shape
return self._element_shape
def identity(self):
"""See TensorArray."""
@@ -1114,42 +1038,13 @@ class TensorArray(object):
"""Python bool; if `True` the TensorArray can grow dynamically."""
return self._implementation._dynamic_size
@property
def _dynamic_size(self):
return self._implementation._dynamic_size
@_dynamic_size.setter
def _dynamic_size(self, dynamic_size):
self._implementation._dynamic_size = dynamic_size
@property
def _infer_shape(self):
# TODO(slebedev): consider making public or changing TensorArrayStructure
# to access _implementation directly. Note that dynamic_size is also
# only used by TensorArrayStructure.
return self._implementation._infer_shape
@_infer_shape.setter
def _infer_shape(self, infer_shape):
self._implementation._infer_shape = infer_shape
@property
def _element_shape(self):
return self._implementation._element_shape
@_element_shape.setter
def _element_shape(self, element_shape):
self._implementation._element_shape = element_shape
@property
def _colocate_with_first_write_call(self):
return self._implementation._colocate_with_first_write_call
@property
def _colocate_with(self):
return self._implementation._colocate_with
@_colocate_with.setter
def _colocate_with(self, colocate_with):
self._implementation._colocate_with = colocate_with
def identity(self):
"""Returns a TensorArray with the same content and properties.
@@ -1307,11 +1202,12 @@ class TensorArray(object):
def build_ta_with_new_flow(old_ta, flow):
"""Builds a TensorArray with a new `flow` tensor."""
# Sometimes we get old_ta as the implementation, sometimes it's the
# TensorArray wrapper object.
impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
else old_ta)
if not context.executing_eagerly():
# Sometimes we get old_ta as the implementation, sometimes it's the
# TensorArray wrapper object.
impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
else old_ta)
if (not isinstance(impl, _GraphTensorArrayV2) and
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
raise NotImplementedError("Attempting to build a graph-mode TF2-style "
@@ -1322,16 +1218,17 @@ def build_ta_with_new_flow(old_ta, flow):
"inside a tf.function or tf.data map function. "
"Instead, construct a new TensorArray inside "
"the function.")
ta = TensorArray(
dtype=old_ta.dtype,
dynamic_size=old_ta._dynamic_size,
handle=old_ta.handle,
new_ta = TensorArray(
dtype=impl.dtype,
handle=impl.handle,
flow=flow,
infer_shape=old_ta._infer_shape,
colocate_with_first_write_call=old_ta._colocate_with_first_write_call)
ta._colocate_with = old_ta._colocate_with
ta._element_shape = old_ta._element_shape
return ta
infer_shape=impl._infer_shape,
colocate_with_first_write_call=impl._colocate_with_first_write_call)
new_impl = new_ta._implementation
new_impl._dynamic_size = impl._dynamic_size
new_impl._colocate_with = impl._colocate_with
new_impl._element_shape = impl._element_shape # Share _element_shape.
return new_ta
# pylint: enable=protected-access
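Finally, a hedged end-to-end sketch of what sharing the backing _element_shape list in build_ta_with_new_flow buys; this mirrors the new testShapeAfterWhileLoop above, and the exact inferred shape can vary by TF version and control-flow mode:

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=10)
_, ta = tf.while_loop(
    lambda i, _: i < 10,
    lambda i, ta: (i + 1, ta.write(i, [[0.0]])),
    [0, ta])

# The element shape inferred from the writes inside the loop should now be
# visible on the loop output as well, instead of staying fully unknown.
assert ta.element_shape.dims is not None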