diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f323a6c918e..47fb5bfb4d0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3725,6 +3725,7 @@ py_library( ":gradients", ":gradients_util", ":graph_to_function_def", + ":handle_data_util", ":pywrap_tensorflow", ":util", "//tensorflow/python/compat", @@ -3859,6 +3860,21 @@ py_library( ], ) +py_library( + name = "handle_data_util", + srcs = [ + "ops/handle_data_util.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":framework_ops", + ":protos_all_py", + ":pywrap_tf_session", + ":util", + ], +) + py_library( name = "gradients", srcs = [ @@ -3869,6 +3885,7 @@ py_library( deps = [ ":gradients_impl", ":gradients_util", + ":handle_data_util", ":pywrap_tf_session", ":unconnected_gradients", "//tensorflow/python/eager:forwardprop", @@ -4263,6 +4280,7 @@ py_library( ":auto_control_deps_utils", ":dtypes", ":framework_ops", + ":handle_data_util", ":pywrap_tf_session", ":resource_variable_ops_gen", ":tensor_shape", @@ -4296,6 +4314,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":handle_data_util", ":list_ops_gen", ], ) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 46d73613d62..afd5cb31374 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -589,6 +589,8 @@ class _EagerDefinedFunction(object): config=config, executor_type=executor_type) + for i, func_graph_output in enumerate(self._func_graph_outputs): + custom_gradient.copy_handle_data(func_graph_output, outputs[i]) if executing_eagerly: return outputs else: @@ -597,8 +599,6 @@ class _EagerDefinedFunction(object): # once that's done. for i, shape in enumerate(self._output_shapes): outputs[i].set_shape(shape) - for i, func_graph_output in enumerate(self._func_graph_outputs): - custom_gradient.copy_handle_data(func_graph_output, outputs[i]) return outputs diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 163f0fb7077..5bdd2494e91 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util @@ -42,6 +43,7 @@ from tensorflow.python.ops import default_gradient from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_functional_ops from tensorflow.python.ops import gradients_util +from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import math_ops from tensorflow.python.util import nest @@ -286,6 +288,7 @@ def _build_cond(pred, # Prevent fetching since the variant outputs can't be fetched directly. if_op.graph.prevent_fetching(if_op) + _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) # Return identities for each output of the If op, rather than the output of # the If op directly. 
This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, @@ -813,6 +816,32 @@ def _get_output_shapes(*branch_graph_outputs): return output_shapes +def _copy_handle_data(external_tensors, *branch_graph_outputs): + """Combines shapes in handle data and sets metadata on `external_tensors`.""" + for tensors in zip(external_tensors, *branch_graph_outputs): + external = tensors[0] + internal = tensors[1:] + internal_handle_data = [] + for tensor in internal: + handle_data = handle_data_util.get_resource_handle_data(tensor) + # NOTE: Assumes handle data has only one ShapeAndType entry. It's + # unclear how to combine different lengths across branches. + if not handle_data.is_set or len(handle_data.shape_and_type) != 1: + break + internal_handle_data.append(handle_data) + else: # There is handle data, so we need to combine it. + combined_shape = tensor_shape.TensorShape(None) + for handle_data in internal_handle_data: + handle_shape = tensor_shape.TensorShape( + handle_data.shape_and_type[0].shape) + combined_shape = combined_shape.most_specific_compatible_shape( + handle_shape) + combined_handle_data = internal_handle_data[0] + combined_handle_data.shape_and_type[0].shape.CopyFrom( + combined_shape.as_proto()) + handle_data_util.set_handle_data(external, combined_handle_data) + + def verify_captures(op_type, branch_graphs): """Verify that a branch's tensor is not accessed in another branch fn.""" # Note: It is technically not possible for lower-branch_index branches to @@ -1143,6 +1172,7 @@ def _build_case(branch_index, # Prevent fetching since the variant outputs can't be fetched directly. case_op.graph.prevent_fetching(case_op) + _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) # Return identities for each output of the Case op, rather than the output of # the Case op directly. This makes pruning work if the output of switch_case() # is fetched: the lowering pass converts the Case outputs into IdentityN diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index 33156f7c9c7..3e38f68a0f7 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import pywrap_tf_session from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape as tape_lib @@ -25,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -42,38 +42,8 @@ VAR_OP_TYPES = [ ] -def copy_handle_data(source_t, target_t): - """Copies HandleData for variant and resource type tensors if available. - - The CppShapeInferenceResult::HandleData proto contains information about the - shapes and types of the element tensors of resource/variant type tensors. - We need to copy this across function boundaries, i.e., when capturing a - placeholder or when returning a function tensor as output. 
If we don't do this - the element tensors will have unknown shapes, e.g., if a TensorList variant - tensor is captured as a placeholder, elements popped from that list would have - unknown shape. - - Args: - source_t: The tensor to copy HandleData from. - target_t: The tensor to copy HandleData to. - """ - if (target_t.dtype == dtypes.resource or - target_t.dtype == dtypes.variant): - if isinstance(source_t, ops.EagerTensor): - handle_data = source_t._handle_data # pylint: disable=protected-access - else: - handle_data = resource_variable_ops.get_resource_handle_data(source_t) - if (handle_data is not None - and handle_data.is_set - and handle_data.shape_and_type): - # pylint: disable=protected-access - if isinstance(target_t, ops.EagerTensor): - target_t._handle_data = handle_data - return - pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph, - target_t._as_tf_output(), - handle_data.SerializeToString()) - # pylint: enable=protected-access +# TODO(allenl): Remove this alias and migrate callers. +copy_handle_data = handle_data_util.copy_handle_data @tf_export("custom_gradient") diff --git a/tensorflow/python/ops/handle_data_util.py b/tensorflow/python/ops/handle_data_util.py new file mode 100644 index 00000000000..d83bea3cb18 --- /dev/null +++ b/tensorflow/python/ops/handle_data_util.py @@ -0,0 +1,73 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for getting, setting, and copying tensor handle data.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import pywrap_tf_session +from tensorflow.python.framework import cpp_shape_inference_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.util import compat + + +def get_resource_handle_data(graph_op): + assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck + + handle_data = pywrap_tf_session.GetHandleShapeAndType( + graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access + + return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( + compat.as_bytes(handle_data)) + + +def copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to. + """ + if (target_t.dtype == dtypes.resource or + target_t.dtype == dtypes.variant): + if isinstance(source_t, ops.EagerTensor): + handle_data = source_t._handle_data # pylint: disable=protected-access + else: + handle_data = get_resource_handle_data(source_t) + if (handle_data is not None + and handle_data.is_set + and handle_data.shape_and_type): + set_handle_data(target_t, handle_data) + + +def set_handle_data(target_t, handle_data): + # pylint: disable=protected-access + if isinstance(target_t, ops.EagerTensor): + target_t._handle_data = handle_data + return + pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph, + target_t._as_tf_output(), + handle_data.SerializeToString()) + # pylint: enable=protected-access diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index 3e7c116ec97..8379a26a260 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -19,11 +19,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_list_ops +from tensorflow.python.ops import handle_data_util # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_list_ops import * @@ -56,19 +60,45 @@ def empty_tensor_list(element_shape, name=name) +def _set_handle_data(list_handle, element_shape, element_dtype): + """Sets type information on `list_handle` for consistency with graphs.""" + # TODO(b/169968286): It would be better if we had a consistent story for + # creating handle data from eager operations (shared with VarHandleOp). + if isinstance(list_handle, ops.EagerTensor): + if tensor_util.is_tensor(element_shape): + element_shape = tensor_shape.TensorShape(None) + elif not isinstance(element_shape, tensor_shape.TensorShape): + element_shape = tensor_shape.TensorShape(element_shape) + handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() + handle_data.is_set = True + handle_data.shape_and_type.append( + cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( + shape=element_shape.as_proto(), + dtype=element_dtype.as_datatype_enum, + specialized_type=types_pb2.ST_TENSOR_LIST)) + list_handle._handle_data = handle_data # pylint: disable=protected-access + + def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None): - return gen_list_ops.tensor_list_reserve( + result = gen_list_ops.tensor_list_reserve( element_shape=_build_element_shape(element_shape), num_elements=num_elements, element_dtype=element_dtype, name=name) + # TODO(b/169968286): gen_ops needs to ensure the metadata is properly + # populated for eager operations. 
+ _set_handle_data(result, element_shape, element_dtype) + return result def tensor_list_from_tensor(tensor, element_shape, name=None): - return gen_list_ops.tensor_list_from_tensor( + tensor = ops.convert_to_tensor(tensor) + result = gen_list_ops.tensor_list_from_tensor( tensor=tensor, element_shape=_build_element_shape(element_shape), name=name) + _set_handle_data(result, tensor.shape, tensor.dtype) + return result def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None, @@ -107,16 +137,22 @@ def tensor_list_scatter(tensor, element_shape=None, input_handle=None, name=None): + """Returns a TensorList created or updated by scattering `tensor`.""" + tensor = ops.convert_to_tensor(tensor) if input_handle is not None: - return gen_list_ops.tensor_list_scatter_into_existing_list( + output_handle = gen_list_ops.tensor_list_scatter_into_existing_list( input_handle=input_handle, tensor=tensor, indices=indices, name=name) + handle_data_util.copy_handle_data(input_handle, output_handle) + return output_handle else: - return gen_list_ops.tensor_list_scatter_v2( + output_handle = gen_list_ops.tensor_list_scatter_v2( tensor=tensor, indices=indices, element_shape=_build_element_shape(element_shape), num_elements=-1, name=name) + _set_handle_data(output_handle, element_shape, tensor.dtype) + return output_handle def tensor_list_stack(input_handle, @@ -167,8 +203,10 @@ def tensor_list_set_item(input_handle, lambda: gen_list_ops.tensor_list_resize( # pylint: disable=g-long-lambda input_handle, index + 1), lambda: input_handle) - return gen_list_ops.tensor_list_set_item( + output_handle = gen_list_ops.tensor_list_set_item( input_handle=input_handle, index=index, item=item, name=name) + handle_data_util.copy_handle_data(input_handle, output_handle) + return output_handle @ops.RegisterGradient("TensorListPushBack") diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index b60bc210e9b..504574385c6 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -450,6 +450,13 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True): results of applying fn to tensors unpacked from elems along the first dimension, from first to last. + Although they are less common as user-visible inputs and outputs, note that + tensors of type `tf.variant` which represent tensor lists (for example from + `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list + contents rather than the variant itself, and so the container tensor will + have a scalar shape when returned rather than the usual stacked shape. This + improves the performance of control flow gradient vectorization. + Raises: ValueError: If vectorization fails and fallback_to_while_loop is False. 
""" diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 63b99b28f5e..3a0c6cf1a14 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -43,10 +43,11 @@ from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_v2_toggles -from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_list_ops from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import gradients as gradient_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import list_ops @@ -1092,13 +1093,8 @@ class TensorListTest(PForTestCase): def loop_fn(i): l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) l2 = list_ops.tensor_list_set_item(l2, 1, i) - l1_graph = array_ops.identity(l1) - # TODO(b/169968286): Typically TensorLists are both created and used in a - # graph; creating TensorLists eagerly with handle data doesn't work at the - # moment. Copying the handle data manually reproduces the expected case. - custom_gradient.copy_handle_data(l2, l1_graph) return list_ops.tensor_list_stack( - math_ops.add_n([l1_graph, l2]), dtypes.int32) + math_ops.add_n([l1, l2]), dtypes.int32) self._test_loop_fn(loop_fn, 2) @@ -1571,6 +1567,23 @@ class WhileV2Test(PForTestCase): y = constant_op.constant(np.random.uniform(size=(3, 3))) self.assertAllClose(_f(x, y, True), _f(x, y, False)) + def test_scan(self): + np.random.seed(seed=42) + data = np.random.randn(3).astype(np.float32) + + def log_prob(x): + return math_ops.reduce_sum(functional_ops.scan_v2( + lambda _, yi: (x - yi)**2, + elems=data, + initializer=constant_op.constant(0.))) + + x = variables.Variable(array_ops.ones([2])) + self.evaluate(x.initializer) + v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x) + theoretical, numerical = gradient_checker_v2.compute_gradient( + v_log_prob, (x,), delta=1e-3) + self.assertAllClose(theoretical, numerical, rtol=1e-2) + @test_util.run_all_in_graph_and_eager_modes class NestedControlFlowTest(PForTestCase): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 2489ecd713f..7e460176c61 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -78,6 +78,29 @@ flags.DEFINE_bool( "DEPRECATED: Flag is ignored.") +def _variant_handle_data(t): + """Fetches handle data for a variant tensor `t`, or None if unavailable.""" + handle_data = resource_variable_ops.get_eager_safe_handle_data(t) + if not handle_data.is_set: + return None + if len(handle_data.shape_and_type) != 1: + raise ValueError("Expected handle data of length 1, got {!r} of length {}" + .format(handle_data, len(handle_data.shape_and_type))) + return handle_data.shape_and_type[0] + + +def _is_tensor_list(t): + """True if `t` is a TensorList, False if it isn't, None if unknown.""" + if t.dtype != dtypes.variant: + return False + shape_and_type = _variant_handle_data(t) + if shape_and_type is None: + # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we + # can make this an error instead of assuming TensorLists have handle data. 
+ return None # Presumed not a TensorList + return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST + + def _stack(t, length): """stacks `t` `length` times.""" # Note that this stacking may currently be triggered, for example, when a @@ -86,13 +109,9 @@ def _stack(t, length): # suitable since operations on stacked handles may expect a vectorized version # of the variant. if t.dtype == dtypes.variant: - handle_data = resource_variable_ops.get_eager_safe_handle_data(t) - if not handle_data.is_set: + shape_and_type = _variant_handle_data(t) + if shape_and_type is None: raise ValueError("Required handle data not set for {!r}".format(t)) - if len(handle_data.shape_and_type) != 1: - raise ValueError("Expected handle data of length 1, got {!r} of length {}" - .format(handle_data, len(handle_data.shape_and_type))) - shape_and_type = handle_data.shape_and_type[0] if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST: return wrap( _stack_tensor_list(t, shape_and_type.dtype, length), @@ -1606,7 +1625,10 @@ class PFor(object): else: batch_dim = tensor_shape.TensorShape(loop_len) output_shape = batch_dim.concatenate(output_shape) - new_output.t.set_shape(output_shape) + if _is_tensor_list(new_output.t): + new_output.t.set_shape([]) + else: + new_output.t.set_shape(output_shape) self._add_conversion(old_output, new_output) stack.pop(0) @@ -3576,6 +3598,9 @@ def _stack_tensor_list_shape(shape, first_dim): def _tile_variant_with_length(t, length): """stacks `t` `length` times.""" + if _is_tensor_list(t): + # The content of TensorLists is vectorized, not the variant itself. + return t original_tensor = t t.set_shape([]) t = array_ops.reshape(t, [-1]) @@ -3593,6 +3618,13 @@ def _tile_variant(t, pfor_input): def _untile_variant(t): + if _is_tensor_list(t): + # The content of TensorLists is vectorized, not the variant itself. 
+ if not t.shape.is_compatible_with([]): + raise AssertionError( + "Unexpectedly saw a TensorList with non-scalar shape: {!r}" + .format(t)) + return t return array_ops.gather(t, 0) @@ -4201,8 +4233,12 @@ class WhileV2(object): shapes = [tensor_shape.TensorShape(shape) for shape in shapes] for i, shape in enumerate(shapes): shape = shape.merge_with(output_shapes[i]) - if self._pfor_input.input(i).is_stacked: - shape = tensor_shape.TensorShape([None]).concatenate(shape) + pfor_input = self._pfor_input.input(i) + if pfor_input.is_stacked: + if _is_tensor_list(pfor_input.t): + shape = tensor_shape.TensorShape([]).concatenate(shape) + else: + shape = tensor_shape.TensorShape([None]).concatenate(shape) output_shapes[i] = shape assert len(output_shapes) == self._pfor_input.num_inputs return output_shapes diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 162f4057ff0..6cda36d556e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops +from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -62,14 +63,8 @@ acd.register_read_only_resource_op("ResourceGatherNd") acd.register_read_only_resource_op("_ReadVariablesOp") -def get_resource_handle_data(graph_op): - assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck - - handle_data = pywrap_tf_session.GetHandleShapeAndType( - graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access - - return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( - compat.as_bytes(handle_data)) +# TODO(allenl): Remove this alias and migrate callers. +get_resource_handle_data = handle_data_util.get_resource_handle_data def get_eager_safe_handle_data(handle):
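For reference, the shape merging done by the new `_copy_handle_data` in cond_v2.py reduces to `TensorShape.most_specific_compatible_shape`; a minimal sketch with invented branch element shapes:

from tensorflow.python.framework import tensor_shape

# Hypothetical element shapes carried by the handle data of the true and
# false branch outputs of a cond.
true_branch_shape = tensor_shape.TensorShape([None, 3])
false_branch_shape = tensor_shape.TensorShape([2, 3])

# Dimensions the branches agree on are kept; mismatched or unknown
# dimensions become None, so the merged shape is valid for either branch.
merged = true_branch_shape.most_specific_compatible_shape(false_branch_shape)
print(merged.as_list())  # [None, 3]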
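A small usage sketch of the relocated helpers in handle_data_util.py (internal APIs; graph construction is assumed so that C++ shape inference populates the source's handle data, and the tensors here are purely illustrative):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import list_ops

with ops.Graph().as_default():
  # A TensorList whose element shape is statically known; shape inference
  # attaches handle data to the resulting variant tensor.
  source = list_ops.tensor_list_reserve(
      element_shape=[4], num_elements=2, element_dtype=dtypes.float32)
  # Stand-in for a tensor on the other side of a function/graph boundary.
  target = array_ops.placeholder(dtypes.variant, shape=[])
  handle_data_util.copy_handle_data(source, target)
  # `target` now reports float32 elements of shape [4] in its handle data.
  print(handle_data_util.get_resource_handle_data(target))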
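The `_set_handle_data` helper added to list_ops.py gives eagerly created lists the element metadata that graph tensors already get from C++ shape inference; an illustrative sketch, assuming TF 2.x eager execution:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops

# An eagerly created TensorList handle now carries element metadata.
handle = list_ops.tensor_list_reserve(
    element_shape=[2], num_elements=3, element_dtype=dtypes.int32)
# The handle data records ST_TENSOR_LIST elements of shape [2] / int32, so
# consumers such as pfor can identify the list without rebuilding a graph.
print(handle._handle_data.shape_and_type[0])  # pylint: disable=protected-access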
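Relatedly, `_is_tensor_list` in pfor.py keys off the `specialized_type` field of that same handle data; a sketch of the distinction it draws, again assuming eager execution:

from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import resource_variable_ops

lst = list_ops.tensor_list_reserve([], 2, dtypes.int32)
data = resource_variable_ops.get_eager_safe_handle_data(lst)
# A TensorList advertises itself via ST_TENSOR_LIST; a variant with no
# handle data at all is treated as "unknown" rather than as an error.
assert data.is_set
assert data.shape_and_type[0].specialized_type == types_pb2.ST_TENSOR_LIST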
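Finally, a sketch of the `vectorized_map` behavior documented in parallel_for/control_flow_ops.py for variant TensorLists, assuming TF 2.x eager execution (the helper name and values are invented; the point of interest is the scalar handle shape):

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import list_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

def rows_to_list(row):
  return list_ops.tensor_list_from_tensor(row, element_shape=[])

elems = constant_op.constant([[1., 2.], [3., 4.]])
# The result is a single scalar-shaped variant whose list *contents* are
# batched, rather than a stacked vector of per-iteration list handles.
batched_list = pfor_control_flow_ops.vectorized_map(rows_to_list, elems)
print(batched_list.shape)  # expected: ()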