STT-tensorflow/tensorflow/python/ops/handle_data_util.py
Allen Lavoie e01adec56d pfor: Stop tiling/untiling TensorLists
TensorLists are vectorized by vectorizing the list components, so tiling them is unnecessary. Tiling also interferes with gradients, since we don't have the kernels to run the gradient definitions (Tile needs Sum, Gather needs UnsortedSegmentSum).

Unfortunately, tracing back to find the pre-tile tensor when untiling (rather than actually running Gather) still leaves tiled tensors as outputs of control flow, so tile gradients are still triggered.

Keeps tiling for non-TensorList variants, since presumably we need that to fall back to while_loop. This means pfor is more reliant on handle data to identify variants than it was previously.

Adds handle data to TensorLists created eagerly. This is just for sanity while writing/maintaining unit tests; I don't expect users to create TensorLists manually.

PiperOrigin-RevId: 336177069
Change-Id: If37774ede39ddfd631bd5f02db5b5269bd8b37f5
2020-10-08 15:45:56 -07:00


# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for reading, copying and setting tensor handle data.

Handle data records the shapes and dtypes of the element tensors held by
resource and variant tensors (for example, TensorLists).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import compat


def get_resource_handle_data(graph_op):
  """Returns the HandleData proto stored for a graph tensor.

  Args:
    graph_op: A graph-mode `ops.Tensor`; eager tensors are not accepted.

  Returns:
    A `CppShapeInferenceResult.HandleData` proto parsed from the handle shape
    and type that the graph records for `graph_op`.
  """
  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck

  handle_data = pywrap_tf_session.GetHandleShapeAndType(
      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access

  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
      compat.as_bytes(handle_data))


def copy_handle_data(source_t, target_t):
  """Copies HandleData for variant and resource type tensors if available.

  The CppShapeInferenceResult::HandleData proto contains information about the
  shapes and types of the element tensors of resource/variant type tensors.
  We need to copy this across function boundaries, for example when a tensor
  is captured as a placeholder or when a function tensor is returned as an
  output. Without this, the element tensors will have unknown shapes: if a
  TensorList variant tensor is captured as a placeholder, elements popped from
  that list would have unknown shape.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  if (target_t.dtype == dtypes.resource or
      target_t.dtype == dtypes.variant):
    if isinstance(source_t, ops.EagerTensor):
      handle_data = source_t._handle_data  # pylint: disable=protected-access
    else:
      handle_data = get_resource_handle_data(source_t)
    if (handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type):
      set_handle_data(target_t, handle_data)


def set_handle_data(target_t, handle_data):
  """Attaches `handle_data` to `target_t`.

  Eager tensors keep the proto on the Python object; for graph tensors it is
  serialized and stored in the graph through the C API.
  """
  # pylint: disable=protected-access
  if isinstance(target_t, ops.EagerTensor):
    target_t._handle_data = handle_data
    return
  pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
                                          target_t._as_tf_output(),
                                          handle_data.SerializeToString())
  # pylint: enable=protected-access
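

# Illustrative sketch only (hypothetical, not part of TensorFlow): the
# `_Sketch*` classes and `_sketch_*` functions below are invented stand-ins
# that model, without TensorFlow, the guards copy_handle_data applies. They
# assume plain-string dtypes ("variant", "resource", ...) and a list of
# (shape, dtype) pairs in place of the HandleData proto. The sketch shows the
# failure mode described in the copy_handle_data docstring: a placeholder that
# captures a TensorList variant starts without handle data, so the shape of a
# popped element is unknown until the handle data is copied across.
#
# Usage (stand-in objects only):
#   outer = _SketchTensor(
#       "variant", _SketchHandleData(True, [((2,), "float32")]))
#   captured = _SketchTensor("variant")      # fresh placeholder, no handle data
#   _sketch_popped_element_shape(captured)   # None: element shape unknown
#   _sketch_copy_handle_data(outer, captured)
#   _sketch_popped_element_shape(captured)   # (2,)


class _SketchHandleData(object):
  """Hypothetical stand-in for CppShapeInferenceResult.HandleData."""

  def __init__(self, is_set=False, shape_and_type=None):
    self.is_set = is_set
    # One (shape, dtype) pair per element tensor held by the handle.
    self.shape_and_type = list(shape_and_type or [])


class _SketchTensor(object):
  """Hypothetical stand-in for a tensor that may carry handle data."""

  def __init__(self, dtype, handle_data=None):
    self.dtype = dtype
    self.handle_data = (handle_data if handle_data is not None
                        else _SketchHandleData())


def _sketch_copy_handle_data(source_t, target_t):
  """Applies the same guards as copy_handle_data to the stand-in classes."""
  # Only resource and variant tensors carry handle data.
  if target_t.dtype not in ("resource", "variant"):
    return
  handle_data = source_t.handle_data
  # Copy only handle data that is present, set, and non-empty, so an
  # uninformative source never overwrites the target. The real code calls
  # set_handle_data here; the sketch just assigns the attribute.
  if (handle_data is not None and handle_data.is_set and
      handle_data.shape_and_type):
    target_t.handle_data = handle_data


def _sketch_popped_element_shape(list_t):
  """Returns the element shape a pop would report, or None if unknown."""
  if list_t.handle_data.shape_and_type:
    return list_t.handle_data.shape_and_type[0][0]
  return None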