pfor: Stop tiling/untiling TensorLists

TensorLists are vectorized by vectorizing the list components, so tiling/untiling the variant tensor itself is unnecessary. It also interferes with gradients, since we don't have variant kernels for the ops the gradient definitions need (Tile's gradient needs Sum, Gather's needs UnsortedSegmentSum).
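
As context, a minimal sketch (assumed public-API usage, not taken from this change) of the kind of computation affected: differentiating through TensorList-backed control flow under tf.vectorized_map, which previously could fail on the missing variant gradient kernels.

  import tensorflow as tf

  data = tf.constant([1., 2., 3.])

  def log_prob(x):
    # tf.scan lowers to while_loop with TensorArrays, which are TensorList-backed in TF2.
    return tf.reduce_sum(
        tf.scan(lambda acc, yi: (x - yi) ** 2, data,
                initializer=tf.constant(0.)))

  x = tf.ones([2])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.vectorized_map(log_prob, x)
  grad = tape.gradient(y, x)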

Unfortunately, tracing back to find the pre-tile tensor when untiling (rather than actually running Gather) still ends up with tiled tensors returned from control flow, and so still triggers Tile gradients.

Keeps tiling for non-TensorList variants, since presumably we need that to fall back to while_loop. This means pfor is more reliant on handle data to identify variants than it was previously.

Adds handle data to TensorLists created eagerly. This is just for sanity while writing/maintaining unit tests; I don't expect users to create TensorLists manually.
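
As a rough illustration (this relies on the private _handle_data attribute that _set_handle_data in list_ops.py populates; not something user code should depend on), an eagerly created TensorList now carries element dtype/shape metadata:

  from tensorflow.python.framework import dtypes
  from tensorflow.python.ops import list_ops

  # With eager execution, the handle is an EagerTensor and _set_handle_data
  # attaches a CppShapeInferenceResult.HandleData proto to it.
  l = list_ops.tensor_list_reserve(
      element_shape=[], num_elements=2, element_dtype=dtypes.int32)
  print(l._handle_data.is_set)                   # True
  print(l._handle_data.shape_and_type[0].dtype)  # 3 (DT_INT32)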

PiperOrigin-RevId: 336177069
Change-Id: If37774ede39ddfd631bd5f02db5b5269bd8b37f5
Author: Allen Lavoie (2020-10-08 15:27:24 -07:00), committed by TensorFlower Gardener
Parent: d483dea0c6
Commit: e01adec56d
10 changed files with 245 additions and 64 deletions


@ -3725,6 +3725,7 @@ py_library(
":gradients",
":gradients_util",
":graph_to_function_def",
":handle_data_util",
":pywrap_tensorflow",
":util",
"//tensorflow/python/compat",
@ -3859,6 +3860,21 @@ py_library(
],
)
py_library(
name = "handle_data_util",
srcs = [
"ops/handle_data_util.py",
],
srcs_version = "PY2AND3",
deps = [
":dtypes",
":framework_ops",
":protos_all_py",
":pywrap_tf_session",
":util",
],
)
py_library(
name = "gradients",
srcs = [
@ -3869,6 +3885,7 @@ py_library(
deps = [
":gradients_impl",
":gradients_util",
":handle_data_util",
":pywrap_tf_session",
":unconnected_gradients",
"//tensorflow/python/eager:forwardprop",
@ -4263,6 +4280,7 @@ py_library(
":auto_control_deps_utils",
":dtypes",
":framework_ops",
":handle_data_util",
":pywrap_tf_session",
":resource_variable_ops_gen",
":tensor_shape",
@ -4296,6 +4314,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":array_ops",
":handle_data_util",
":list_ops_gen",
],
)


@ -589,6 +589,8 @@ class _EagerDefinedFunction(object):
config=config,
executor_type=executor_type)
for i, func_graph_output in enumerate(self._func_graph_outputs):
custom_gradient.copy_handle_data(func_graph_output, outputs[i])
if executing_eagerly:
return outputs
else:
@ -597,8 +599,6 @@ class _EagerDefinedFunction(object):
# once that's done.
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
for i, func_graph_output in enumerate(self._func_graph_outputs):
custom_gradient.copy_handle_data(func_graph_output, outputs[i])
return outputs


@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
@ -42,6 +43,7 @@ from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
@ -286,6 +288,7 @@ def _build_cond(pred,
# Prevent fetching since the variant outputs can't be fetched directly.
if_op.graph.prevent_fetching(if_op)
_copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
# Return identities for each output of the If op, rather than the output of
# the If op directly. This makes pruning work if the output of cond() is
# fetched: the lowering pass converts the If outputs into IdentityN outputs,
@ -813,6 +816,32 @@ def _get_output_shapes(*branch_graph_outputs):
return output_shapes
def _copy_handle_data(external_tensors, *branch_graph_outputs):
"""Combines shapes in handle data and sets metadata on `external_tensors`."""
for tensors in zip(external_tensors, *branch_graph_outputs):
external = tensors[0]
internal = tensors[1:]
internal_handle_data = []
for tensor in internal:
handle_data = handle_data_util.get_resource_handle_data(tensor)
# NOTE: Assumes handle data has only one ShapeAndType entry. It's
# unclear how to combine different lengths across branches.
if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
break
internal_handle_data.append(handle_data)
else: # There is handle data, so we need to combine it.
combined_shape = tensor_shape.TensorShape(None)
for handle_data in internal_handle_data:
handle_shape = tensor_shape.TensorShape(
handle_data.shape_and_type[0].shape)
combined_shape = combined_shape.most_specific_compatible_shape(
handle_shape)
combined_handle_data = internal_handle_data[0]
combined_handle_data.shape_and_type[0].shape.CopyFrom(
combined_shape.as_proto())
handle_data_util.set_handle_data(external, combined_handle_data)
def verify_captures(op_type, branch_graphs):
"""Verify that a branch's tensor is not accessed in another branch fn."""
# Note: It is technically not possible for lower-branch_index branches to
@ -1143,6 +1172,7 @@ def _build_case(branch_index,
# Prevent fetching since the variant outputs can't be fetched directly.
case_op.graph.prevent_fetching(case_op)
_copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
# Return identities for each output of the Case op, rather than the output of
# the Case op directly. This makes pruning work if the output of switch_case()
# is fetched: the lowering pass converts the Case outputs into IdentityN


@ -17,7 +17,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
@ -25,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@ -42,38 +42,8 @@ VAR_OP_TYPES = [
]
def copy_handle_data(source_t, target_t):
"""Copies HandleData for variant and resource type tensors if available.
The CppShapeInferenceResult::HandleData proto contains information about the
shapes and types of the element tensors of resource/variant type tensors.
We need to copy this across function boundaries, i.e., when capturing a
placeholder or when returning a function tensor as output. If we don't do this
the element tensors will have unknown shapes, e.g., if a TensorList variant
tensor is captured as a placeholder, elements popped from that list would have
unknown shape.
Args:
source_t: The tensor to copy HandleData from.
target_t: The tensor to copy HandleData to.
"""
if (target_t.dtype == dtypes.resource or
target_t.dtype == dtypes.variant):
if isinstance(source_t, ops.EagerTensor):
handle_data = source_t._handle_data # pylint: disable=protected-access
else:
handle_data = resource_variable_ops.get_resource_handle_data(source_t)
if (handle_data is not None
and handle_data.is_set
and handle_data.shape_and_type):
# pylint: disable=protected-access
if isinstance(target_t, ops.EagerTensor):
target_t._handle_data = handle_data
return
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
target_t._as_tf_output(),
handle_data.SerializeToString())
# pylint: enable=protected-access
# TODO(allenl): Remove this alias and migrate callers.
copy_handle_data = handle_data_util.copy_handle_data
@tf_export("custom_gradient")


@ -0,0 +1,73 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Decorator to overrides the gradient for a function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import compat
def get_resource_handle_data(graph_op):
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
handle_data = pywrap_tf_session.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
compat.as_bytes(handle_data))
def copy_handle_data(source_t, target_t):
"""Copies HandleData for variant and resource type tensors if available.
The CppShapeInferenceResult::HandleData proto contains information about the
shapes and types of the element tensors of resource/variant type tensors.
We need to copy this across function boundaries, i.e., when capturing a
placeholder or when returning a function tensor as output. If we don't do this
the element tensors will have unknown shapes, e.g., if a TensorList variant
tensor is captured as a placeholder, elements popped from that list would have
unknown shape.
Args:
source_t: The tensor to copy HandleData from.
target_t: The tensor to copy HandleData to.
"""
if (target_t.dtype == dtypes.resource or
target_t.dtype == dtypes.variant):
if isinstance(source_t, ops.EagerTensor):
handle_data = source_t._handle_data # pylint: disable=protected-access
else:
handle_data = get_resource_handle_data(source_t)
if (handle_data is not None
and handle_data.is_set
and handle_data.shape_and_type):
set_handle_data(target_t, handle_data)
def set_handle_data(target_t, handle_data):
# pylint: disable=protected-access
if isinstance(target_t, ops.EagerTensor):
target_t._handle_data = handle_data
return
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
target_t._as_tf_output(),
handle_data.SerializeToString())
# pylint: enable=protected-access


@ -19,11 +19,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import handle_data_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
@ -56,19 +60,45 @@ def empty_tensor_list(element_shape,
name=name)
def _set_handle_data(list_handle, element_shape, element_dtype):
"""Sets type information on `list_handle` for consistency with graphs."""
# TODO(b/169968286): It would be better if we had a consistent story for
# creating handle data from eager operations (shared with VarHandleOp).
if isinstance(list_handle, ops.EagerTensor):
if tensor_util.is_tensor(element_shape):
element_shape = tensor_shape.TensorShape(None)
elif not isinstance(element_shape, tensor_shape.TensorShape):
element_shape = tensor_shape.TensorShape(element_shape)
handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
handle_data.is_set = True
handle_data.shape_and_type.append(
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
shape=element_shape.as_proto(),
dtype=element_dtype.as_datatype_enum,
specialized_type=types_pb2.ST_TENSOR_LIST))
list_handle._handle_data = handle_data # pylint: disable=protected-access
def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
return gen_list_ops.tensor_list_reserve(
result = gen_list_ops.tensor_list_reserve(
element_shape=_build_element_shape(element_shape),
num_elements=num_elements,
element_dtype=element_dtype,
name=name)
# TODO(b/169968286): gen_ops needs to ensure the metadata is properly
# populated for eager operations.
_set_handle_data(result, element_shape, element_dtype)
return result
def tensor_list_from_tensor(tensor, element_shape, name=None):
return gen_list_ops.tensor_list_from_tensor(
tensor = ops.convert_to_tensor(tensor)
result = gen_list_ops.tensor_list_from_tensor(
tensor=tensor,
element_shape=_build_element_shape(element_shape),
name=name)
_set_handle_data(result, tensor.shape, tensor.dtype)
return result
def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
@ -107,16 +137,22 @@ def tensor_list_scatter(tensor,
element_shape=None,
input_handle=None,
name=None):
"""Returns a TensorList created or updated by scattering `tensor`."""
tensor = ops.convert_to_tensor(tensor)
if input_handle is not None:
return gen_list_ops.tensor_list_scatter_into_existing_list(
output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
input_handle=input_handle, tensor=tensor, indices=indices, name=name)
handle_data_util.copy_handle_data(input_handle, output_handle)
return output_handle
else:
return gen_list_ops.tensor_list_scatter_v2(
output_handle = gen_list_ops.tensor_list_scatter_v2(
tensor=tensor,
indices=indices,
element_shape=_build_element_shape(element_shape),
num_elements=-1,
name=name)
_set_handle_data(output_handle, element_shape, tensor.dtype)
return output_handle
def tensor_list_stack(input_handle,
@ -167,8 +203,10 @@ def tensor_list_set_item(input_handle,
lambda: gen_list_ops.tensor_list_resize( # pylint: disable=g-long-lambda
input_handle, index + 1),
lambda: input_handle)
return gen_list_ops.tensor_list_set_item(
output_handle = gen_list_ops.tensor_list_set_item(
input_handle=input_handle, index=index, item=item, name=name)
handle_data_util.copy_handle_data(input_handle, output_handle)
return output_handle
@ops.RegisterGradient("TensorListPushBack")


@ -450,6 +450,13 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
results of applying fn to tensors unpacked from elems along the first
dimension, from first to last.
Although they are less common as user-visible inputs and outputs, note that
tensors of type `tf.variant` which represent tensor lists (for example from
`tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list
contents rather than the variant itself, and so the container tensor will
have a scalar shape when returned rather than the usual stacked shape. This
improves the performance of control flow gradient vectorization.
Raises:
ValueError: If vectorization fails and fallback_to_while_loop is False.
"""


@ -43,10 +43,11 @@ from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import list_ops
@ -1092,13 +1093,8 @@ class TensorListTest(PForTestCase):
def loop_fn(i):
l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
l2 = list_ops.tensor_list_set_item(l2, 1, i)
l1_graph = array_ops.identity(l1)
# TODO(b/169968286): Typically TensorLists are both created and used in a
# graph; creating TensorLists eagerly with handle data doesn't work at the
# moment. Copying the handle data manually reproduces the expected case.
custom_gradient.copy_handle_data(l2, l1_graph)
return list_ops.tensor_list_stack(
math_ops.add_n([l1_graph, l2]), dtypes.int32)
math_ops.add_n([l1, l2]), dtypes.int32)
self._test_loop_fn(loop_fn, 2)
@ -1571,6 +1567,23 @@ class WhileV2Test(PForTestCase):
y = constant_op.constant(np.random.uniform(size=(3, 3)))
self.assertAllClose(_f(x, y, True), _f(x, y, False))
def test_scan(self):
np.random.seed(seed=42)
data = np.random.randn(3).astype(np.float32)
def log_prob(x):
return math_ops.reduce_sum(functional_ops.scan_v2(
lambda _, yi: (x - yi)**2,
elems=data,
initializer=constant_op.constant(0.)))
x = variables.Variable(array_ops.ones([2]))
self.evaluate(x.initializer)
v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
theoretical, numerical = gradient_checker_v2.compute_gradient(
v_log_prob, (x,), delta=1e-3)
self.assertAllClose(theoretical, numerical, rtol=1e-2)
@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):


@ -78,6 +78,29 @@ flags.DEFINE_bool(
"DEPRECATED: Flag is ignored.")
def _variant_handle_data(t):
"""Fetches handle data for a variant tensor `t`, or None if unavailable."""
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
if not handle_data.is_set:
return None
if len(handle_data.shape_and_type) != 1:
raise ValueError("Expected handle data of length 1, got {!r} of length {}"
.format(handle_data, len(handle_data.shape_and_type)))
return handle_data.shape_and_type[0]
def _is_tensor_list(t):
"""True if `t` is a TensorList, False if it isn't, None if unknown."""
if t.dtype != dtypes.variant:
return False
shape_and_type = _variant_handle_data(t)
if shape_and_type is None:
# TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
# can make this an error instead of assuming TensorLists have handle data.
return None # Presumed not a TensorList
return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
def _stack(t, length):
"""stacks `t` `length` times."""
# Note that this stacking may currently be triggered, for example, when a
@ -86,13 +109,9 @@ def _stack(t, length):
# suitable since operations on stacked handles may expect a vectorized version
# of the variant.
if t.dtype == dtypes.variant:
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
if not handle_data.is_set:
shape_and_type = _variant_handle_data(t)
if shape_and_type is None:
raise ValueError("Required handle data not set for {!r}".format(t))
if len(handle_data.shape_and_type) != 1:
raise ValueError("Expected handle data of length 1, got {!r} of length {}"
.format(handle_data, len(handle_data.shape_and_type)))
shape_and_type = handle_data.shape_and_type[0]
if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
return wrap(
_stack_tensor_list(t, shape_and_type.dtype, length),
@ -1606,7 +1625,10 @@ class PFor(object):
else:
batch_dim = tensor_shape.TensorShape(loop_len)
output_shape = batch_dim.concatenate(output_shape)
new_output.t.set_shape(output_shape)
if _is_tensor_list(new_output.t):
new_output.t.set_shape([])
else:
new_output.t.set_shape(output_shape)
self._add_conversion(old_output, new_output)
stack.pop(0)
@ -3576,6 +3598,9 @@ def _stack_tensor_list_shape(shape, first_dim):
def _tile_variant_with_length(t, length):
"""stacks `t` `length` times."""
if _is_tensor_list(t):
# The content of TensorLists is vectorized, not the variant itself.
return t
original_tensor = t
t.set_shape([])
t = array_ops.reshape(t, [-1])
@ -3593,6 +3618,13 @@ def _tile_variant(t, pfor_input):
def _untile_variant(t):
if _is_tensor_list(t):
# The content of TensorLists is vectorized, not the variant itself.
if not t.shape.is_compatible_with([]):
raise AssertionError(
"Unexpectedly saw a TensorList with non-scalar shape: {!r}"
.format(t))
return t
return array_ops.gather(t, 0)
@ -4201,8 +4233,12 @@ class WhileV2(object):
shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
for i, shape in enumerate(shapes):
shape = shape.merge_with(output_shapes[i])
if self._pfor_input.input(i).is_stacked:
shape = tensor_shape.TensorShape([None]).concatenate(shape)
pfor_input = self._pfor_input.input(i)
if pfor_input.is_stacked:
if _is_tensor_list(pfor_input.t):
shape = tensor_shape.TensorShape([]).concatenate(shape)
else:
shape = tensor_shape.TensorShape([None]).concatenate(shape)
output_shapes[i] = shape
assert len(output_shapes) == self._pfor_input.num_inputs
return output_shapes


@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
@ -62,14 +63,8 @@ acd.register_read_only_resource_op("ResourceGatherNd")
acd.register_read_only_resource_op("_ReadVariablesOp")
def get_resource_handle_data(graph_op):
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
handle_data = pywrap_tf_session.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
compat.as_bytes(handle_data))
# TODO(allenl): Remove this alias and migrate callers.
get_resource_handle_data = handle_data_util.get_resource_handle_data
def get_eager_safe_handle_data(handle):