Cleaned up a few *TensorArray._element_shape leaks
PiperOrigin-RevId: 254749223
This commit is contained in:
parent
0bd27b7fd4
commit
1cd4a2b5c8
@ -503,29 +503,15 @@ class BeamSearchDecoderMixin(object):
|
||||
"""
|
||||
if not isinstance(t, tensor_array_ops.TensorArray):
|
||||
return t
|
||||
# pylint: disable=protected-access
|
||||
# This is a bad hack due to the implementation detail of eager/graph TA.
|
||||
# TODO(b/124374427): Update this to use public property of TensorArray.
|
||||
if context.executing_eagerly():
|
||||
element_shape = t._element_shape
|
||||
else:
|
||||
element_shape = t._element_shape[0]
|
||||
if (not t._infer_shape
|
||||
or not t._element_shape
|
||||
or element_shape.ndims is None
|
||||
or element_shape.ndims < 1):
|
||||
shape = (
|
||||
element_shape if t._infer_shape and t._element_shape
|
||||
else tensor_shape.TensorShape(None))
|
||||
if t.element_shape.ndims is None or t.element_shape.ndims < 1:
|
||||
tf_logging.warn("The TensorArray %s in the cell state is not amenable to "
|
||||
"sorting based on the beam search result. For a "
|
||||
"TensorArray to be sorted, its elements shape must be "
|
||||
"defined and have at least a rank of 1, but saw shape: %s"
|
||||
% (t.handle.name, shape))
|
||||
% (t.handle.name, t.element_shape))
|
||||
return t
|
||||
# pylint: enable=protected-access
|
||||
if not _check_static_batch_beam_maybe(
|
||||
element_shape, tensor_util.constant_value(self._batch_size),
|
||||
t.element_shape, tensor_util.constant_value(self._batch_size),
|
||||
self._beam_width):
|
||||
return t
|
||||
t = t.stack()
|
||||
|
@ -1065,6 +1065,15 @@ class TensorArrayTest(test.TestCase):
|
||||
grad = gradients_impl.gradients(loop(x), [x])[0]
|
||||
self.assertAllClose(31.0, self.evaluate(grad))
|
||||
|
||||
def testShapeAfterWhileLoop(self):
|
||||
size = 10
|
||||
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size)
|
||||
_, ta = control_flow_ops.while_loop(
|
||||
lambda i, _: i < size,
|
||||
lambda i, ta: (i + 1, ta.write(i, [[0.]])), [0, ta],
|
||||
parallel_iterations=1)
|
||||
self.assertIsNotNone(ta.element_shape.dims)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testSkipEagerSumOfTwoReadVariablesWithoutRepeatGrad(self):
|
||||
with self.session(use_gpu=True) as session:
|
||||
|
@ -432,27 +432,13 @@ def _convert_tensorarray_to_flow(tensor_or_tensor_array):
|
||||
return tensor_or_tensor_array
|
||||
|
||||
|
||||
def _make_tensor_array(ta, t_or_flow):
|
||||
# pylint: disable=protected-access
|
||||
new_ta = tensor_array_ops.TensorArray(
|
||||
dtype=ta.dtype,
|
||||
handle=ta.handle,
|
||||
flow=t_or_flow,
|
||||
infer_shape=ta._infer_shape,
|
||||
colocate_with_first_write_call=ta._colocate_with_first_write_call)
|
||||
new_ta._colocate_with = ta._colocate_with
|
||||
new_ta._element_shape = ta._element_shape
|
||||
# pylint: enable=protected-access
|
||||
return new_ta
|
||||
|
||||
|
||||
def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
|
||||
if len(tensors_or_tensorarrays) != len(tensors_or_flows):
|
||||
raise ValueError(
|
||||
"Lengths of original Tensor list and new list do not match: %d vs. %d" %
|
||||
(len(tensors_or_tensorarrays), len(tensors_or_flows)))
|
||||
return [
|
||||
_make_tensor_array(ta, t_or_flow) if isinstance(
|
||||
tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance(
|
||||
ta, tensor_array_ops.TensorArray) else t_or_flow
|
||||
for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
|
||||
]
|
||||
|
@ -134,12 +134,8 @@ class _GraphTensorArray(object):
|
||||
# shape is defined either by `element_shape` or the shape of the tensor
|
||||
# of the first write. If `infer_shape` is true, all writes checks for
|
||||
# shape equality.
|
||||
if element_shape is None:
|
||||
self._infer_shape = infer_shape
|
||||
self._element_shape = []
|
||||
else:
|
||||
self._infer_shape = True
|
||||
self._element_shape = [tensor_shape.as_shape(element_shape)]
|
||||
self._element_shape = [tensor_shape.as_shape(element_shape)]
|
||||
self._infer_shape = element_shape is not None or infer_shape
|
||||
with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
|
||||
if handle is not None:
|
||||
self._handle = handle
|
||||
@ -181,10 +177,7 @@ class _GraphTensorArray(object):
|
||||
|
||||
@property
|
||||
def element_shape(self):
|
||||
if self._element_shape:
|
||||
return self._element_shape[0]
|
||||
else:
|
||||
return tensor_shape.unknown_shape(None)
|
||||
return self._element_shape[0]
|
||||
|
||||
def _merge_element_shape(self, shape):
|
||||
"""Changes the element shape of the array given a shape to merge with.
|
||||
@ -196,15 +189,11 @@ class _GraphTensorArray(object):
|
||||
ValueError: if the provided shape is incompatible with the current
|
||||
element shape of the `TensorArray`.
|
||||
"""
|
||||
|
||||
if self._element_shape:
|
||||
if not shape.is_compatible_with(self._element_shape[0]):
|
||||
raise ValueError(
|
||||
"Inconsistent shapes: saw %s but expected %s "
|
||||
"(and infer_shape=True)" % (shape, self._element_shape[0]))
|
||||
self._element_shape[0] = self._element_shape[0].merge_with(shape)
|
||||
else:
|
||||
self._element_shape.append(shape)
|
||||
if not shape.is_compatible_with(self.element_shape):
|
||||
raise ValueError(
|
||||
"Inconsistent shapes: saw %s but expected %s "
|
||||
"(and infer_shape=True)" % (shape, self.element_shape))
|
||||
self._element_shape[0] = self.element_shape.merge_with(shape)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _maybe_colocate_with(self, value):
|
||||
@ -230,16 +219,7 @@ class _GraphTensorArray(object):
|
||||
def identity(self):
|
||||
"""See TensorArray."""
|
||||
flow = array_ops.identity(self._flow)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype,
|
||||
handle=self._handle,
|
||||
flow=flow,
|
||||
infer_shape=self._infer_shape,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
ta._dynamic_size = self._dynamic_size
|
||||
return ta
|
||||
return build_ta_with_new_flow(self, flow)
|
||||
|
||||
def grad(self, source, flow=None, name=None):
|
||||
"""See TensorArray."""
|
||||
@ -261,7 +241,9 @@ class _GraphTensorArray(object):
|
||||
flow=flow,
|
||||
infer_shape=self._infer_shape,
|
||||
colocate_with_first_write_call=False)
|
||||
g._element_shape = self._element_shape
|
||||
# pylint: disable=protected-access
|
||||
g._implementation._element_shape = self._element_shape
|
||||
# pylint: enable=protected-access
|
||||
return g
|
||||
|
||||
def read(self, index, name=None):
|
||||
@ -293,16 +275,7 @@ class _GraphTensorArray(object):
|
||||
value=value,
|
||||
flow_in=self._flow,
|
||||
name=name)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype,
|
||||
handle=self._handle,
|
||||
flow=flow_out,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._infer_shape = self._infer_shape
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
ta._dynamic_size = self._dynamic_size
|
||||
return ta
|
||||
return build_ta_with_new_flow(self, flow_out)
|
||||
|
||||
def stack(self, name=None):
|
||||
"""See TensorArray."""
|
||||
@ -323,25 +296,20 @@ class _GraphTensorArray(object):
|
||||
dtype=self._dtype,
|
||||
name=name,
|
||||
element_shape=element_shape)
|
||||
if self._element_shape and self._element_shape[0].dims is not None:
|
||||
value.set_shape([None] + self._element_shape[0].dims)
|
||||
if self.element_shape:
|
||||
value.set_shape([None] + self.element_shape.dims)
|
||||
return value
|
||||
|
||||
def concat(self, name=None):
|
||||
"""See TensorArray."""
|
||||
if self._element_shape and self._element_shape[0].dims is not None:
|
||||
element_shape_except0 = (
|
||||
tensor_shape.TensorShape(self._element_shape[0].dims[1:]))
|
||||
else:
|
||||
element_shape_except0 = tensor_shape.TensorShape(None)
|
||||
value, _ = gen_data_flow_ops.tensor_array_concat_v3(
|
||||
handle=self._handle,
|
||||
flow_in=self._flow,
|
||||
dtype=self._dtype,
|
||||
name=name,
|
||||
element_shape_except0=element_shape_except0)
|
||||
if self._element_shape and self._element_shape[0].dims is not None:
|
||||
value.set_shape([None] + self._element_shape[0].dims[1:])
|
||||
element_shape_except0=self.element_shape[1:])
|
||||
if self.element_shape:
|
||||
value.set_shape([None] + self.element_shape.dims[1:])
|
||||
return value
|
||||
|
||||
@tf_should_use.should_use_result
|
||||
@ -370,16 +338,7 @@ class _GraphTensorArray(object):
|
||||
value=value,
|
||||
flow_in=self._flow,
|
||||
name=name)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype,
|
||||
handle=self._handle,
|
||||
flow=flow_out,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._infer_shape = self._infer_shape
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
ta._dynamic_size = self._dynamic_size
|
||||
return ta
|
||||
return build_ta_with_new_flow(self, flow_out)
|
||||
|
||||
@tf_should_use.should_use_result
|
||||
def split(self, value, lengths, name=None):
|
||||
@ -402,16 +361,7 @@ class _GraphTensorArray(object):
|
||||
lengths=lengths_64,
|
||||
flow_in=self._flow,
|
||||
name=name)
|
||||
ta = TensorArray(
|
||||
dtype=self._dtype,
|
||||
handle=self._handle,
|
||||
flow=flow_out,
|
||||
colocate_with_first_write_call=self._colocate_with_first_write_call)
|
||||
ta._infer_shape = self._infer_shape
|
||||
ta._element_shape = self._element_shape
|
||||
ta._colocate_with = self._colocate_with
|
||||
ta._dynamic_size = self._dynamic_size
|
||||
return ta
|
||||
return build_ta_with_new_flow(self, flow_out)
|
||||
|
||||
def size(self, name=None):
|
||||
"""See TensorArray."""
|
||||
@ -496,12 +446,8 @@ class _GraphTensorArrayV2(object):
|
||||
# shape is defined either by `element_shape` or the shape of the tensor
|
||||
# of the first write. If `infer_shape` is true, all writes checks for
|
||||
# shape equality.
|
||||
if element_shape is None:
|
||||
self._infer_shape = infer_shape
|
||||
self._element_shape = []
|
||||
else:
|
||||
self._infer_shape = True
|
||||
self._element_shape = [tensor_shape.as_shape(element_shape)]
|
||||
self._element_shape = [tensor_shape.as_shape(element_shape)]
|
||||
self._infer_shape = element_shape is not None or infer_shape
|
||||
with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
|
||||
if flow is None:
|
||||
self._flow = list_ops.tensor_list_reserve(
|
||||
@ -526,10 +472,7 @@ class _GraphTensorArrayV2(object):
|
||||
|
||||
@property
|
||||
def element_shape(self):
|
||||
if self._element_shape:
|
||||
return self._element_shape[0]
|
||||
else:
|
||||
return tensor_shape.unknown_shape(None)
|
||||
return self._element_shape[0]
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
@ -547,15 +490,11 @@ class _GraphTensorArrayV2(object):
|
||||
ValueError: if the provided shape is incompatible with the current
|
||||
element shape of the `TensorArray`.
|
||||
"""
|
||||
|
||||
if self._element_shape:
|
||||
if not shape.is_compatible_with(self._element_shape[0]):
|
||||
raise ValueError(
|
||||
"Inconsistent shapes: saw %s but expected %s "
|
||||
"(and infer_shape=True)" % (shape, self._element_shape[0]))
|
||||
self._element_shape[0] = self._element_shape[0].merge_with(shape)
|
||||
else:
|
||||
self._element_shape.append(shape)
|
||||
if not shape.is_compatible_with(self.element_shape):
|
||||
raise ValueError(
|
||||
"Inconsistent shapes: saw %s but expected %s "
|
||||
"(and infer_shape=True)" % (shape, self.element_shape))
|
||||
self._element_shape[0] = self.element_shape.merge_with(shape)
|
||||
|
||||
def identity(self):
|
||||
"""See TensorArray."""
|
||||
@ -569,15 +508,11 @@ class _GraphTensorArrayV2(object):
|
||||
def read(self, index, name=None):
|
||||
"""See TensorArray."""
|
||||
with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
|
||||
if self._element_shape:
|
||||
element_shape = self._element_shape[0]
|
||||
else:
|
||||
element_shape = tensor_shape.unknown_shape(None)
|
||||
value = list_ops.tensor_list_get_item(
|
||||
input_handle=self._flow,
|
||||
index=index,
|
||||
element_dtype=self._dtype,
|
||||
element_shape=element_shape,
|
||||
element_shape=self.element_shape,
|
||||
name=name)
|
||||
return value
|
||||
|
||||
@ -602,34 +537,26 @@ class _GraphTensorArrayV2(object):
|
||||
def stack(self, name=None):
|
||||
"""See TensorArray."""
|
||||
with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
|
||||
if self._element_shape:
|
||||
element_shape = self._element_shape[0]
|
||||
else:
|
||||
element_shape = tensor_shape.unknown_shape(None)
|
||||
value = list_ops.tensor_list_stack(
|
||||
input_handle=self._flow,
|
||||
element_dtype=self._dtype,
|
||||
element_shape=element_shape)
|
||||
element_shape=self.element_shape)
|
||||
return value
|
||||
|
||||
def gather(self, indices, name=None):
|
||||
"""See TensorArray."""
|
||||
if self._element_shape:
|
||||
element_shape = self._element_shape[0]
|
||||
else:
|
||||
element_shape = tensor_shape.unknown_shape(None)
|
||||
value = list_ops.tensor_list_gather(
|
||||
input_handle=self._flow,
|
||||
indices=indices,
|
||||
element_dtype=self._dtype,
|
||||
element_shape=element_shape,
|
||||
element_shape=self.element_shape,
|
||||
name=name)
|
||||
return value
|
||||
|
||||
def concat(self, name=None):
|
||||
"""See TensorArray."""
|
||||
if self._element_shape and self._element_shape[0].dims is not None:
|
||||
element_shape = [None] + self._element_shape[0].dims[1:]
|
||||
if self.element_shape:
|
||||
element_shape = [None] + self.element_shape.dims[1:]
|
||||
else:
|
||||
element_shape = None
|
||||
|
||||
@ -665,9 +592,9 @@ class _GraphTensorArrayV2(object):
|
||||
_check_dtypes(value, self._dtype)
|
||||
if self._infer_shape and not context.executing_eagerly():
|
||||
self._merge_element_shape(value.shape[1:])
|
||||
element_shape = self._element_shape[0] if self._element_shape else None
|
||||
flow_out = list_ops.tensor_list_scatter(
|
||||
tensor=value, indices=indices, input_handle=self._flow)
|
||||
tensor=value, indices=indices, element_shape=self.element_shape,
|
||||
input_handle=self._flow)
|
||||
return build_ta_with_new_flow(self, flow_out)
|
||||
|
||||
@tf_should_use.should_use_result
|
||||
@ -689,7 +616,7 @@ class _GraphTensorArrayV2(object):
|
||||
flow_out = list_ops.tensor_list_split(
|
||||
tensor=value,
|
||||
lengths=lengths_64,
|
||||
element_shape=self._element_shape[0] if self._element_shape else None,
|
||||
element_shape=self.element_shape,
|
||||
name=name)
|
||||
return build_ta_with_new_flow(self, flow_out)
|
||||
|
||||
@ -761,7 +688,7 @@ class _EagerTensorArray(object):
|
||||
# we assign a dummy value to _flow in case other code assumes it to be
|
||||
# a Tensor
|
||||
self._flow = constant_op.constant(0, dtype=dtypes.int32)
|
||||
self._infer_shape = infer_shape
|
||||
self._infer_shape = element_shape is not None or infer_shape
|
||||
self._element_shape = tensor_shape.as_shape(element_shape)
|
||||
self._colocate_with_first_write_call = colocate_with_first_write_call
|
||||
|
||||
@ -791,10 +718,7 @@ class _EagerTensorArray(object):
|
||||
|
||||
@property
|
||||
def element_shape(self):
|
||||
if not self._element_shape:
|
||||
return tensor_shape.unknown_shape(None)
|
||||
else:
|
||||
return
|
||||
return self._element_shape
|
||||
|
||||
def identity(self):
|
||||
"""See TensorArray."""
|
||||
@ -1114,42 +1038,13 @@ class TensorArray(object):
|
||||
"""Python bool; if `True` the TensorArray can grow dynamically."""
|
||||
return self._implementation._dynamic_size
|
||||
|
||||
@property
|
||||
def _dynamic_size(self):
|
||||
return self._implementation._dynamic_size
|
||||
|
||||
@_dynamic_size.setter
|
||||
def _dynamic_size(self, dynamic_size):
|
||||
self._implementation._dynamic_size = dynamic_size
|
||||
|
||||
@property
|
||||
def _infer_shape(self):
|
||||
# TODO(slebedev): consider making public or changing TensorArrayStructure
|
||||
# to access _implementation directly. Note that dynamic_size is also
|
||||
# only used by TensorArrayStructure.
|
||||
return self._implementation._infer_shape
|
||||
|
||||
@_infer_shape.setter
|
||||
def _infer_shape(self, infer_shape):
|
||||
self._implementation._infer_shape = infer_shape
|
||||
|
||||
@property
|
||||
def _element_shape(self):
|
||||
return self._implementation._element_shape
|
||||
|
||||
@_element_shape.setter
|
||||
def _element_shape(self, element_shape):
|
||||
self._implementation._element_shape = element_shape
|
||||
|
||||
@property
|
||||
def _colocate_with_first_write_call(self):
|
||||
return self._implementation._colocate_with_first_write_call
|
||||
|
||||
@property
|
||||
def _colocate_with(self):
|
||||
return self._implementation._colocate_with
|
||||
|
||||
@_colocate_with.setter
|
||||
def _colocate_with(self, colocate_with):
|
||||
self._implementation._colocate_with = colocate_with
|
||||
|
||||
def identity(self):
|
||||
"""Returns a TensorArray with the same content and properties.
|
||||
|
||||
@ -1307,11 +1202,12 @@ class TensorArray(object):
|
||||
|
||||
def build_ta_with_new_flow(old_ta, flow):
|
||||
"""Builds a TensorArray with a new `flow` tensor."""
|
||||
# Sometimes we get old_ta as the implementation, sometimes it's the
|
||||
# TensorArray wrapper object.
|
||||
impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
|
||||
else old_ta)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# Sometimes we get old_ta as the implementation, sometimes it's the
|
||||
# TensorArray wrapper object.
|
||||
impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
|
||||
else old_ta)
|
||||
if (not isinstance(impl, _GraphTensorArrayV2) and
|
||||
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
|
||||
raise NotImplementedError("Attempting to build a graph-mode TF2-style "
|
||||
@ -1322,16 +1218,17 @@ def build_ta_with_new_flow(old_ta, flow):
|
||||
"inside a tf.function or tf.data map function. "
|
||||
"Instead, construct a new TensorArray inside "
|
||||
"the function.")
|
||||
ta = TensorArray(
|
||||
dtype=old_ta.dtype,
|
||||
dynamic_size=old_ta._dynamic_size,
|
||||
handle=old_ta.handle,
|
||||
new_ta = TensorArray(
|
||||
dtype=impl.dtype,
|
||||
handle=impl.handle,
|
||||
flow=flow,
|
||||
infer_shape=old_ta._infer_shape,
|
||||
colocate_with_first_write_call=old_ta._colocate_with_first_write_call)
|
||||
ta._colocate_with = old_ta._colocate_with
|
||||
ta._element_shape = old_ta._element_shape
|
||||
return ta
|
||||
infer_shape=impl._infer_shape,
|
||||
colocate_with_first_write_call=impl._colocate_with_first_write_call)
|
||||
new_impl = new_ta._implementation
|
||||
new_impl._dynamic_size = impl._dynamic_size
|
||||
new_impl._colocate_with = impl._colocate_with
|
||||
new_impl._element_shape = impl._element_shape # Share _element_shape.
|
||||
return new_ta
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user