pfor: Stop tiling/untiling TensorLists
TensorLists are vecorized by vectorizing the list components, so it's unnecessary. It also interferes with gradients since we don't have the kernels to run the gradient definitions (Tile needs Sum, Gather needs UnsortedSegmentSum). Unfortunately tracing back to find the pre-tile tensor when untiling (rather than actually running gather) still ends up with tiled tensors as returns from control flow, and so still triggers tile gradients. Keeps tiling for non-TensorList variants, since presumably we need that to fall back to while_loop. This means pfor is more reliant on handle data to identify variants than it was previously. Adds handle data to TensorLists created eagerly. This is just for sanity while writing/maintaining unit tests; I don't expect users to create TensorLists manually. PiperOrigin-RevId: 336177069 Change-Id: If37774ede39ddfd631bd5f02db5b5269bd8b37f5
This commit is contained in:
parent
d483dea0c6
commit
e01adec56d
@ -3725,6 +3725,7 @@ py_library(
|
||||
":gradients",
|
||||
":gradients_util",
|
||||
":graph_to_function_def",
|
||||
":handle_data_util",
|
||||
":pywrap_tensorflow",
|
||||
":util",
|
||||
"//tensorflow/python/compat",
|
||||
@ -3859,6 +3860,21 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "handle_data_util",
|
||||
srcs = [
|
||||
"ops/handle_data_util.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":protos_all_py",
|
||||
":pywrap_tf_session",
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gradients",
|
||||
srcs = [
|
||||
@ -3869,6 +3885,7 @@ py_library(
|
||||
deps = [
|
||||
":gradients_impl",
|
||||
":gradients_util",
|
||||
":handle_data_util",
|
||||
":pywrap_tf_session",
|
||||
":unconnected_gradients",
|
||||
"//tensorflow/python/eager:forwardprop",
|
||||
@ -4263,6 +4280,7 @@ py_library(
|
||||
":auto_control_deps_utils",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":handle_data_util",
|
||||
":pywrap_tf_session",
|
||||
":resource_variable_ops_gen",
|
||||
":tensor_shape",
|
||||
@ -4296,6 +4314,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":handle_data_util",
|
||||
":list_ops_gen",
|
||||
],
|
||||
)
|
||||
|
@ -589,6 +589,8 @@ class _EagerDefinedFunction(object):
|
||||
config=config,
|
||||
executor_type=executor_type)
|
||||
|
||||
for i, func_graph_output in enumerate(self._func_graph_outputs):
|
||||
custom_gradient.copy_handle_data(func_graph_output, outputs[i])
|
||||
if executing_eagerly:
|
||||
return outputs
|
||||
else:
|
||||
@ -597,8 +599,6 @@ class _EagerDefinedFunction(object):
|
||||
# once that's done.
|
||||
for i, shape in enumerate(self._output_shapes):
|
||||
outputs[i].set_shape(shape)
|
||||
for i, func_graph_output in enumerate(self._func_graph_outputs):
|
||||
custom_gradient.copy_handle_data(func_graph_output, outputs[i])
|
||||
return outputs
|
||||
|
||||
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
@ -42,6 +43,7 @@ from tensorflow.python.ops import default_gradient
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import gen_functional_ops
|
||||
from tensorflow.python.ops import gradients_util
|
||||
from tensorflow.python.ops import handle_data_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -286,6 +288,7 @@ def _build_cond(pred,
|
||||
# Prevent fetching since the variant outputs can't be fetched directly.
|
||||
if_op.graph.prevent_fetching(if_op)
|
||||
|
||||
_copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
|
||||
# Return identities for each output of the If op, rather than the output of
|
||||
# the If op directly. This makes pruning work if the output of cond() is
|
||||
# fetched: the lowering pass converts the If outputs into IdentityN outputs,
|
||||
@ -813,6 +816,32 @@ def _get_output_shapes(*branch_graph_outputs):
|
||||
return output_shapes
|
||||
|
||||
|
||||
def _copy_handle_data(external_tensors, *branch_graph_outputs):
|
||||
"""Combines shapes in handle data and sets metadata on `external_tensors`."""
|
||||
for tensors in zip(external_tensors, *branch_graph_outputs):
|
||||
external = tensors[0]
|
||||
internal = tensors[1:]
|
||||
internal_handle_data = []
|
||||
for tensor in internal:
|
||||
handle_data = handle_data_util.get_resource_handle_data(tensor)
|
||||
# NOTE: Assumes handle data has only one ShapeAndType entry. It's
|
||||
# unclear how to combine different lengths across branches.
|
||||
if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
|
||||
break
|
||||
internal_handle_data.append(handle_data)
|
||||
else: # There is handle data, so we need to combine it.
|
||||
combined_shape = tensor_shape.TensorShape(None)
|
||||
for handle_data in internal_handle_data:
|
||||
handle_shape = tensor_shape.TensorShape(
|
||||
handle_data.shape_and_type[0].shape)
|
||||
combined_shape = combined_shape.most_specific_compatible_shape(
|
||||
handle_shape)
|
||||
combined_handle_data = internal_handle_data[0]
|
||||
combined_handle_data.shape_and_type[0].shape.CopyFrom(
|
||||
combined_shape.as_proto())
|
||||
handle_data_util.set_handle_data(external, combined_handle_data)
|
||||
|
||||
|
||||
def verify_captures(op_type, branch_graphs):
|
||||
"""Verify that a branch's tensor is not accessed in another branch fn."""
|
||||
# Note: It is technically not possible for lower-branch_index branches to
|
||||
@ -1143,6 +1172,7 @@ def _build_case(branch_index,
|
||||
# Prevent fetching since the variant outputs can't be fetched directly.
|
||||
case_op.graph.prevent_fetching(case_op)
|
||||
|
||||
_copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
|
||||
# Return identities for each output of the Case op, rather than the output of
|
||||
# the Case op directly. This makes pruning work if the output of switch_case()
|
||||
# is fetched: the lowering pass converts the Case outputs into IdentityN
|
||||
|
@ -17,7 +17,6 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import tape as tape_lib
|
||||
@ -25,6 +24,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import handle_data_util
|
||||
from tensorflow.python.ops import op_selector
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -42,38 +42,8 @@ VAR_OP_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
def copy_handle_data(source_t, target_t):
|
||||
"""Copies HandleData for variant and resource type tensors if available.
|
||||
|
||||
The CppShapeInferenceResult::HandleData proto contains information about the
|
||||
shapes and types of the element tensors of resource/variant type tensors.
|
||||
We need to copy this across function boundaries, i.e., when capturing a
|
||||
placeholder or when returning a function tensor as output. If we don't do this
|
||||
the element tensors will have unknown shapes, e.g., if a TensorList variant
|
||||
tensor is captured as a placeholder, elements popped from that list would have
|
||||
unknown shape.
|
||||
|
||||
Args:
|
||||
source_t: The tensor to copy HandleData from.
|
||||
target_t: The tensor to copy HandleData to.
|
||||
"""
|
||||
if (target_t.dtype == dtypes.resource or
|
||||
target_t.dtype == dtypes.variant):
|
||||
if isinstance(source_t, ops.EagerTensor):
|
||||
handle_data = source_t._handle_data # pylint: disable=protected-access
|
||||
else:
|
||||
handle_data = resource_variable_ops.get_resource_handle_data(source_t)
|
||||
if (handle_data is not None
|
||||
and handle_data.is_set
|
||||
and handle_data.shape_and_type):
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(target_t, ops.EagerTensor):
|
||||
target_t._handle_data = handle_data
|
||||
return
|
||||
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
|
||||
target_t._as_tf_output(),
|
||||
handle_data.SerializeToString())
|
||||
# pylint: enable=protected-access
|
||||
# TODO(allenl): Remove this alias and migrate callers.
|
||||
copy_handle_data = handle_data_util.copy_handle_data
|
||||
|
||||
|
||||
@tf_export("custom_gradient")
|
||||
|
73
tensorflow/python/ops/handle_data_util.py
Normal file
73
tensorflow/python/ops/handle_data_util.py
Normal file
@ -0,0 +1,73 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Decorator to overrides the gradient for a function."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
def get_resource_handle_data(graph_op):
|
||||
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
|
||||
|
||||
handle_data = pywrap_tf_session.GetHandleShapeAndType(
|
||||
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
|
||||
|
||||
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
|
||||
compat.as_bytes(handle_data))
|
||||
|
||||
|
||||
def copy_handle_data(source_t, target_t):
|
||||
"""Copies HandleData for variant and resource type tensors if available.
|
||||
|
||||
The CppShapeInferenceResult::HandleData proto contains information about the
|
||||
shapes and types of the element tensors of resource/variant type tensors.
|
||||
We need to copy this across function boundaries, i.e., when capturing a
|
||||
placeholder or when returning a function tensor as output. If we don't do this
|
||||
the element tensors will have unknown shapes, e.g., if a TensorList variant
|
||||
tensor is captured as a placeholder, elements popped from that list would have
|
||||
unknown shape.
|
||||
|
||||
Args:
|
||||
source_t: The tensor to copy HandleData from.
|
||||
target_t: The tensor to copy HandleData to.
|
||||
"""
|
||||
if (target_t.dtype == dtypes.resource or
|
||||
target_t.dtype == dtypes.variant):
|
||||
if isinstance(source_t, ops.EagerTensor):
|
||||
handle_data = source_t._handle_data # pylint: disable=protected-access
|
||||
else:
|
||||
handle_data = get_resource_handle_data(source_t)
|
||||
if (handle_data is not None
|
||||
and handle_data.is_set
|
||||
and handle_data.shape_and_type):
|
||||
set_handle_data(target_t, handle_data)
|
||||
|
||||
|
||||
def set_handle_data(target_t, handle_data):
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(target_t, ops.EagerTensor):
|
||||
target_t._handle_data = handle_data
|
||||
return
|
||||
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
|
||||
target_t._as_tf_output(),
|
||||
handle_data.SerializeToString())
|
||||
# pylint: enable=protected-access
|
@ -19,11 +19,15 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_list_ops
|
||||
from tensorflow.python.ops import handle_data_util
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_list_ops import *
|
||||
@ -56,19 +60,45 @@ def empty_tensor_list(element_shape,
|
||||
name=name)
|
||||
|
||||
|
||||
def _set_handle_data(list_handle, element_shape, element_dtype):
|
||||
"""Sets type information on `list_handle` for consistency with graphs."""
|
||||
# TODO(b/169968286): It would be better if we had a consistent story for
|
||||
# creating handle data from eager operations (shared with VarHandleOp).
|
||||
if isinstance(list_handle, ops.EagerTensor):
|
||||
if tensor_util.is_tensor(element_shape):
|
||||
element_shape = tensor_shape.TensorShape(None)
|
||||
elif not isinstance(element_shape, tensor_shape.TensorShape):
|
||||
element_shape = tensor_shape.TensorShape(element_shape)
|
||||
handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
|
||||
handle_data.is_set = True
|
||||
handle_data.shape_and_type.append(
|
||||
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
|
||||
shape=element_shape.as_proto(),
|
||||
dtype=element_dtype.as_datatype_enum,
|
||||
specialized_type=types_pb2.ST_TENSOR_LIST))
|
||||
list_handle._handle_data = handle_data # pylint: disable=protected-access
|
||||
|
||||
|
||||
def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
|
||||
return gen_list_ops.tensor_list_reserve(
|
||||
result = gen_list_ops.tensor_list_reserve(
|
||||
element_shape=_build_element_shape(element_shape),
|
||||
num_elements=num_elements,
|
||||
element_dtype=element_dtype,
|
||||
name=name)
|
||||
# TODO(b/169968286): gen_ops needs to ensure the metadata is properly
|
||||
# populated for eager operations.
|
||||
_set_handle_data(result, element_shape, element_dtype)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_list_from_tensor(tensor, element_shape, name=None):
|
||||
return gen_list_ops.tensor_list_from_tensor(
|
||||
tensor = ops.convert_to_tensor(tensor)
|
||||
result = gen_list_ops.tensor_list_from_tensor(
|
||||
tensor=tensor,
|
||||
element_shape=_build_element_shape(element_shape),
|
||||
name=name)
|
||||
_set_handle_data(result, tensor.shape, tensor.dtype)
|
||||
return result
|
||||
|
||||
|
||||
def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
|
||||
@ -107,16 +137,22 @@ def tensor_list_scatter(tensor,
|
||||
element_shape=None,
|
||||
input_handle=None,
|
||||
name=None):
|
||||
"""Returns a TensorList created or updated by scattering `tensor`."""
|
||||
tensor = ops.convert_to_tensor(tensor)
|
||||
if input_handle is not None:
|
||||
return gen_list_ops.tensor_list_scatter_into_existing_list(
|
||||
output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
|
||||
input_handle=input_handle, tensor=tensor, indices=indices, name=name)
|
||||
handle_data_util.copy_handle_data(input_handle, output_handle)
|
||||
return output_handle
|
||||
else:
|
||||
return gen_list_ops.tensor_list_scatter_v2(
|
||||
output_handle = gen_list_ops.tensor_list_scatter_v2(
|
||||
tensor=tensor,
|
||||
indices=indices,
|
||||
element_shape=_build_element_shape(element_shape),
|
||||
num_elements=-1,
|
||||
name=name)
|
||||
_set_handle_data(output_handle, element_shape, tensor.dtype)
|
||||
return output_handle
|
||||
|
||||
|
||||
def tensor_list_stack(input_handle,
|
||||
@ -167,8 +203,10 @@ def tensor_list_set_item(input_handle,
|
||||
lambda: gen_list_ops.tensor_list_resize( # pylint: disable=g-long-lambda
|
||||
input_handle, index + 1),
|
||||
lambda: input_handle)
|
||||
return gen_list_ops.tensor_list_set_item(
|
||||
output_handle = gen_list_ops.tensor_list_set_item(
|
||||
input_handle=input_handle, index=index, item=item, name=name)
|
||||
handle_data_util.copy_handle_data(input_handle, output_handle)
|
||||
return output_handle
|
||||
|
||||
|
||||
@ops.RegisterGradient("TensorListPushBack")
|
||||
|
@ -450,6 +450,13 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
|
||||
results of applying fn to tensors unpacked from elems along the first
|
||||
dimension, from first to last.
|
||||
|
||||
Although they are less common as user-visible inputs and outputs, note that
|
||||
tensors of type `tf.variant` which represent tensor lists (for example from
|
||||
`tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list
|
||||
contents rather than the variant itself, and so the container tensor will
|
||||
have a scalar shape when returned rather than the usual stacked shape. This
|
||||
improves the performance of control flow gradient vectorization.
|
||||
|
||||
Raises:
|
||||
ValueError: If vectorization fails and fallback_to_while_loop is False.
|
||||
"""
|
||||
|
@ -43,10 +43,11 @@ from tensorflow.python.ops import bitwise_ops
|
||||
from tensorflow.python.ops import cond_v2
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gen_list_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import gradients as gradient_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import list_ops
|
||||
@ -1092,13 +1093,8 @@ class TensorListTest(PForTestCase):
|
||||
def loop_fn(i):
|
||||
l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32)
|
||||
l2 = list_ops.tensor_list_set_item(l2, 1, i)
|
||||
l1_graph = array_ops.identity(l1)
|
||||
# TODO(b/169968286): Typically TensorLists are both created and used in a
|
||||
# graph; creating TensorLists eagerly with handle data doesn't work at the
|
||||
# moment. Copying the handle data manually reproduces the expected case.
|
||||
custom_gradient.copy_handle_data(l2, l1_graph)
|
||||
return list_ops.tensor_list_stack(
|
||||
math_ops.add_n([l1_graph, l2]), dtypes.int32)
|
||||
math_ops.add_n([l1, l2]), dtypes.int32)
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
@ -1571,6 +1567,23 @@ class WhileV2Test(PForTestCase):
|
||||
y = constant_op.constant(np.random.uniform(size=(3, 3)))
|
||||
self.assertAllClose(_f(x, y, True), _f(x, y, False))
|
||||
|
||||
def test_scan(self):
|
||||
np.random.seed(seed=42)
|
||||
data = np.random.randn(3).astype(np.float32)
|
||||
|
||||
def log_prob(x):
|
||||
return math_ops.reduce_sum(functional_ops.scan_v2(
|
||||
lambda _, yi: (x - yi)**2,
|
||||
elems=data,
|
||||
initializer=constant_op.constant(0.)))
|
||||
|
||||
x = variables.Variable(array_ops.ones([2]))
|
||||
self.evaluate(x.initializer)
|
||||
v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x)
|
||||
theoretical, numerical = gradient_checker_v2.compute_gradient(
|
||||
v_log_prob, (x,), delta=1e-3)
|
||||
self.assertAllClose(theoretical, numerical, rtol=1e-2)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class NestedControlFlowTest(PForTestCase):
|
||||
|
@ -78,6 +78,29 @@ flags.DEFINE_bool(
|
||||
"DEPRECATED: Flag is ignored.")
|
||||
|
||||
|
||||
def _variant_handle_data(t):
|
||||
"""Fetches handle data for a variant tensor `t`, or None if unavailable."""
|
||||
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
|
||||
if not handle_data.is_set:
|
||||
return None
|
||||
if len(handle_data.shape_and_type) != 1:
|
||||
raise ValueError("Expected handle data of length 1, got {!r} of length {}"
|
||||
.format(handle_data, len(handle_data.shape_and_type)))
|
||||
return handle_data.shape_and_type[0]
|
||||
|
||||
|
||||
def _is_tensor_list(t):
|
||||
"""True if `t` is a TensorList, False if it isn't, None if unknown."""
|
||||
if t.dtype != dtypes.variant:
|
||||
return False
|
||||
shape_and_type = _variant_handle_data(t)
|
||||
if shape_and_type is None:
|
||||
# TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we
|
||||
# can make this an error instead of assuming TensorLists have handle data.
|
||||
return None # Presumed not a TensorList
|
||||
return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST
|
||||
|
||||
|
||||
def _stack(t, length):
|
||||
"""stacks `t` `length` times."""
|
||||
# Note that this stacking may currently be triggered, for example, when a
|
||||
@ -86,13 +109,9 @@ def _stack(t, length):
|
||||
# suitable since operations on stacked handles may expect a vectorized version
|
||||
# of the variant.
|
||||
if t.dtype == dtypes.variant:
|
||||
handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
|
||||
if not handle_data.is_set:
|
||||
shape_and_type = _variant_handle_data(t)
|
||||
if shape_and_type is None:
|
||||
raise ValueError("Required handle data not set for {!r}".format(t))
|
||||
if len(handle_data.shape_and_type) != 1:
|
||||
raise ValueError("Expected handle data of length 1, got {!r} of length {}"
|
||||
.format(handle_data, len(handle_data.shape_and_type)))
|
||||
shape_and_type = handle_data.shape_and_type[0]
|
||||
if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST:
|
||||
return wrap(
|
||||
_stack_tensor_list(t, shape_and_type.dtype, length),
|
||||
@ -1606,7 +1625,10 @@ class PFor(object):
|
||||
else:
|
||||
batch_dim = tensor_shape.TensorShape(loop_len)
|
||||
output_shape = batch_dim.concatenate(output_shape)
|
||||
new_output.t.set_shape(output_shape)
|
||||
if _is_tensor_list(new_output.t):
|
||||
new_output.t.set_shape([])
|
||||
else:
|
||||
new_output.t.set_shape(output_shape)
|
||||
self._add_conversion(old_output, new_output)
|
||||
stack.pop(0)
|
||||
|
||||
@ -3576,6 +3598,9 @@ def _stack_tensor_list_shape(shape, first_dim):
|
||||
|
||||
def _tile_variant_with_length(t, length):
|
||||
"""stacks `t` `length` times."""
|
||||
if _is_tensor_list(t):
|
||||
# The content of TensorLists is vectorized, not the variant itself.
|
||||
return t
|
||||
original_tensor = t
|
||||
t.set_shape([])
|
||||
t = array_ops.reshape(t, [-1])
|
||||
@ -3593,6 +3618,13 @@ def _tile_variant(t, pfor_input):
|
||||
|
||||
|
||||
def _untile_variant(t):
|
||||
if _is_tensor_list(t):
|
||||
# The content of TensorLists is vectorized, not the variant itself.
|
||||
if not t.shape.is_compatible_with([]):
|
||||
raise AssertionError(
|
||||
"Unexpectedly saw a TensorList with non-scalar shape: {!r}"
|
||||
.format(t))
|
||||
return t
|
||||
return array_ops.gather(t, 0)
|
||||
|
||||
|
||||
@ -4201,8 +4233,12 @@ class WhileV2(object):
|
||||
shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
|
||||
for i, shape in enumerate(shapes):
|
||||
shape = shape.merge_with(output_shapes[i])
|
||||
if self._pfor_input.input(i).is_stacked:
|
||||
shape = tensor_shape.TensorShape([None]).concatenate(shape)
|
||||
pfor_input = self._pfor_input.input(i)
|
||||
if pfor_input.is_stacked:
|
||||
if _is_tensor_list(pfor_input.t):
|
||||
shape = tensor_shape.TensorShape([]).concatenate(shape)
|
||||
else:
|
||||
shape = tensor_shape.TensorShape([None]).concatenate(shape)
|
||||
output_shapes[i] = shape
|
||||
assert len(output_shapes) == self._pfor_input.num_inputs
|
||||
return output_shapes
|
||||
|
@ -43,6 +43,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import gen_state_ops
|
||||
from tensorflow.python.ops import handle_data_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -62,14 +63,8 @@ acd.register_read_only_resource_op("ResourceGatherNd")
|
||||
acd.register_read_only_resource_op("_ReadVariablesOp")
|
||||
|
||||
|
||||
def get_resource_handle_data(graph_op):
|
||||
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
|
||||
|
||||
handle_data = pywrap_tf_session.GetHandleShapeAndType(
|
||||
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
|
||||
|
||||
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
|
||||
compat.as_bytes(handle_data))
|
||||
# TODO(allenl): Remove this alias and migrate callers.
|
||||
get_resource_handle_data = handle_data_util.get_resource_handle_data
|
||||
|
||||
|
||||
def get_eager_safe_handle_data(handle):
|
||||
|
Loading…
Reference in New Issue
Block a user