diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 3bdaa866ee6..8d7a0cd3a18 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -136,6 +136,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { auto* out_shape_and_type = handle_data.add_shape_and_type(); ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); out_shape_and_type->set_dtype(p.dtype); + out_shape_and_type->set_specialized_type(p.specialized_type); } } string result; @@ -163,7 +164,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); if (TF_GetCode(status) != TF_OK) return; - shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), + shape_and_type_proto.specialized_type()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index bb79b278cb1..10b54476d18 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -133,9 +133,14 @@ struct DimensionOrConstant { struct ShapeAndType { ShapeAndType() {} ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} + ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t) + : shape(s), dtype(t), specialized_type(specialized_t) {} ShapeHandle shape; DataType dtype = DT_INVALID; + // The type of a variant-dtype tensor sometimes affects graph building + // (e.g. for vectorization), and needs to be know statically in such cases. + SpecializedType specialized_type = ST_INVALID; }; // Shape inference functions registered on ops in REGISTER_OP implement diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index 900132c0db9..61549ae08ce 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -74,3 +74,14 @@ enum DataType { // https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, // https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, // https://www.tensorflow.org/code/tensorflow/python/framework/function.py) + +// For identifying the underlying type of a variant. For variants, the types +// listed here are a subset of the types in the variant type registry, +// corresponding to commonly used variants which must occasionally be +// special-cased. +enum SpecializedType { + // Invalid/unknown specialized type. + ST_INVALID = 0; + // "tensorflow::TensorList" in the variant type registry. + ST_TENSOR_LIST = 1; +} \ No newline at end of file diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 4ad676c37ea..91bcc3be49a 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -369,7 +370,7 @@ REGISTER_OP("TensorListFromTensor") &tensor_shape_except_first_dim)); c->set_output_handle_shapes_and_types( 0, std::vector<shape_inference::ShapeAndType>{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -409,7 +410,7 @@ REGISTER_OP("TensorListReserve") TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype)); c->set_output_handle_shapes_and_types( 0, std::vector<shape_inference::ShapeAndType>{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -481,7 +482,7 @@ REGISTER_OP("TensorListSetItem") c->set_output_handle_shapes_and_types(0, *handle_data); } else { c->set_output_handle_shapes_and_types( - 0, {{c->UnknownShape(), element_dtype}}); + 0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}}); } return Status::OK(); }); @@ -532,8 +533,8 @@ REGISTER_OP("TensorListScatter") shape_inference::ShapeHandle element_shape; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( 2, &element_shape)); - c->set_output_handle_shapes_and_types(0, - {{element_shape, element_dtype}}); + c->set_output_handle_shapes_and_types( + 0, {{element_shape, element_dtype, ST_TENSOR_LIST}}); c->set_output(0, c->Scalar()); return Status::OK(); }); @@ -552,8 +553,8 @@ REGISTER_OP("TensorListScatterV2") shape_inference::ShapeHandle element_shape; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( 2, &element_shape)); - c->set_output_handle_shapes_and_types(0, - {{element_shape, element_dtype}}); + c->set_output_handle_shapes_and_types( + 0, {{element_shape, element_dtype, ST_TENSOR_LIST}}); c->set_output(0, c->Scalar()); return Status::OK(); }); @@ -580,8 +581,8 @@ REGISTER_OP("TensorListScatterIntoExistingList") TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype)); element_shape = GetElementShapeFromHandleData(*handle_data); } - c->set_output_handle_shapes_and_types(0, - {{element_shape, element_dtype}}); + c->set_output_handle_shapes_and_types( + 0, {{element_shape, element_dtype, ST_TENSOR_LIST}}); c->set_output(0, c->Scalar()); return Status::OK(); }); @@ -606,7 +607,7 @@ REGISTER_OP("TensorListConcatLists") bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty(); if (!(handle_data_a_nonempty || handle_data_b_nonempty)) { c->set_output_handle_shapes_and_types( - 0, {{c->UnknownShape(), element_dtype}}); + 0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}}); return Status::OK(); } shape_inference::ShapeAndType list_shape_type_a = diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 02f167b4688..ba424193532 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2174,6 +2174,7 @@ class ConcreteFunction(object): j = 0 for i, o in enumerate(outputs_list): if o is not None: + custom_gradient.copy_handle_data(self.outputs[j], result[j]) outputs_list[i] = result[j] j += 1 ret = nest.pack_sequence_as(self._func_graph.structured_outputs, diff --git a/tensorflow/python/framework/cpp_shape_inference.proto b/tensorflow/python/framework/cpp_shape_inference.proto index 1bf14570292..aa4df78c40b 100644 --- a/tensorflow/python/framework/cpp_shape_inference.proto +++ b/tensorflow/python/framework/cpp_shape_inference.proto @@ -11,6 +11,10 @@ message CppShapeInferenceResult { message HandleShapeAndType { TensorShapeProto shape = 1; DataType dtype = 2; + // For dtype==DT_VARIANT, specialized_type may indicate a more specific + // type. For other dtypes or when the information is unavailable it is set + // to ST_INVALID. + SpecializedType specialized_type = 3; } message HandleData { bool is_set = 1; diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 33838aa502e..d874f4f685c 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -121,6 +121,7 @@ cuda_py_test( "noasan", # TODO(b/155406705): flaky ], deps = [ + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index f792cda6ea1..f2f6dc33b84 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -22,6 +22,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np # pylint: disable=unused-import +from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -40,6 +41,7 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope as vs @@ -1600,9 +1602,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): def func(): t = constant_op.constant([1., 2., 3.]) l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + handle_data = resource_variable_ops.get_eager_safe_handle_data(l) + self.assertTrue(handle_data.is_set) + self.assertEqual(types_pb2.ST_TENSOR_LIST, + handle_data.shape_and_type[0].specialized_type) return l tensor_list = func() + handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list) + self.assertTrue(handle_data.is_set) + self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype) + self.assertEqual(types_pb2.ST_TENSOR_LIST, + handle_data.shape_and_type[0].specialized_type) element = list_ops.tensor_list_get_item( tensor_list, 0, element_dtype=dtypes.float32) self.assertAllEqual(element.shape.as_list(), []) diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index f081f036b58..33156f7c9c7 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -67,22 +67,13 @@ def copy_handle_data(source_t, target_t): and handle_data.is_set and handle_data.shape_and_type): # pylint: disable=protected-access + if isinstance(target_t, ops.EagerTensor): + target_t._handle_data = handle_data + return pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph, target_t._as_tf_output(), handle_data.SerializeToString()) # pylint: enable=protected-access - # Ensure that shapes and dtypes are propagated. - shapes, types = zip(*[(pair.shape, pair.dtype) - for pair in handle_data.shape_and_type]) - ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] - shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension - if not s.unknown_rank else None for s in shapes] - pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - target_t._op._graph._c_graph, # pylint: disable=protected-access - target_t._as_tf_output(), # pylint: disable=protected-access - shapes, - ranks, - types) @tf_export("custom_gradient")