Record the type of a variant tensor with its handle data

Only used for TensorList at the moment, but we'll likely need to add similar information for other variants as they'll also need special cases in tf.vectorized_map.

PiperOrigin-RevId: 334708004
Change-Id: I484845f855c5c9cfeea13b78a520a1d7c60c9fc5
This commit is contained in:
Allen Lavoie 2020-09-30 16:54:40 -07:00 committed by TensorFlower Gardener
parent 0fa896962f
commit 8310eda05e
9 changed files with 50 additions and 23 deletions

View File

@ -136,6 +136,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
auto* out_shape_and_type = handle_data.add_shape_and_type();
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
out_shape_and_type->set_dtype(p.dtype);
out_shape_and_type->set_specialized_type(p.specialized_type);
}
}
string result;
@ -163,7 +164,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
if (TF_GetCode(status) != TF_OK) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
shape_and_type_proto.specialized_type());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}

View File

@ -133,9 +133,14 @@ struct DimensionOrConstant {
struct ShapeAndType {
ShapeAndType() {}
ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t)
: shape(s), dtype(t), specialized_type(specialized_t) {}
ShapeHandle shape;
DataType dtype = DT_INVALID;
// The type of a variant-dtype tensor sometimes affects graph building
// (e.g. for vectorization), and needs to be know statically in such cases.
SpecializedType specialized_type = ST_INVALID;
};
// Shape inference functions registered on ops in REGISTER_OP implement

View File

@ -74,3 +74,14 @@ enum DataType {
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
// For identifying the underlying type of a variant. For variants, the types
// listed here are a subset of the types in the variant type registry,
// corresponding to commonly used variants which must occasionally be
// special-cased.
enum SpecializedType {
// Invalid/unknown specialized type.
ST_INVALID = 0;
// "tensorflow::TensorList" in the variant type registry.
ST_TENSOR_LIST = 1;
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@ -369,7 +370,7 @@ REGISTER_OP("TensorListFromTensor")
&tensor_shape_except_first_dim));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{
{element_shape, element_dtype}});
{element_shape, element_dtype, ST_TENSOR_LIST}});
return Status::OK();
});
@ -409,7 +410,7 @@ REGISTER_OP("TensorListReserve")
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{
{element_shape, element_dtype}});
{element_shape, element_dtype, ST_TENSOR_LIST}});
return Status::OK();
});
@ -481,7 +482,7 @@ REGISTER_OP("TensorListSetItem")
c->set_output_handle_shapes_and_types(0, *handle_data);
} else {
c->set_output_handle_shapes_and_types(
0, {{c->UnknownShape(), element_dtype}});
0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
}
return Status::OK();
});
@ -532,8 +533,8 @@ REGISTER_OP("TensorListScatter")
shape_inference::ShapeHandle element_shape;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
2, &element_shape));
c->set_output_handle_shapes_and_types(0,
{{element_shape, element_dtype}});
c->set_output_handle_shapes_and_types(
0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
c->set_output(0, c->Scalar());
return Status::OK();
});
@ -552,8 +553,8 @@ REGISTER_OP("TensorListScatterV2")
shape_inference::ShapeHandle element_shape;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
2, &element_shape));
c->set_output_handle_shapes_and_types(0,
{{element_shape, element_dtype}});
c->set_output_handle_shapes_and_types(
0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
c->set_output(0, c->Scalar());
return Status::OK();
});
@ -580,8 +581,8 @@ REGISTER_OP("TensorListScatterIntoExistingList")
TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
element_shape = GetElementShapeFromHandleData(*handle_data);
}
c->set_output_handle_shapes_and_types(0,
{{element_shape, element_dtype}});
c->set_output_handle_shapes_and_types(
0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
c->set_output(0, c->Scalar());
return Status::OK();
});
@ -606,7 +607,7 @@ REGISTER_OP("TensorListConcatLists")
bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
c->set_output_handle_shapes_and_types(
0, {{c->UnknownShape(), element_dtype}});
0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
return Status::OK();
}
shape_inference::ShapeAndType list_shape_type_a =

View File

@ -2174,6 +2174,7 @@ class ConcreteFunction(object):
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
custom_gradient.copy_handle_data(self.outputs[j], result[j])
outputs_list[i] = result[j]
j += 1
ret = nest.pack_sequence_as(self._func_graph.structured_outputs,

View File

@ -11,6 +11,10 @@ message CppShapeInferenceResult {
message HandleShapeAndType {
TensorShapeProto shape = 1;
DataType dtype = 2;
// For dtype==DT_VARIANT, specialized_type may indicate a more specific
// type. For other dtypes or when the information is unavailable it is set
// to ST_INVALID.
SpecializedType specialized_type = 3;
}
message HandleData {
bool is_set = 1;

View File

@ -121,6 +121,7 @@ cuda_py_test(
"noasan", # TODO(b/155406705): flaky
],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",

View File

@ -22,6 +22,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np # pylint: disable=unused-import
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@ -40,6 +41,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope as vs
@ -1600,9 +1602,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def func():
t = constant_op.constant([1., 2., 3.])
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
handle_data = resource_variable_ops.get_eager_safe_handle_data(l)
self.assertTrue(handle_data.is_set)
self.assertEqual(types_pb2.ST_TENSOR_LIST,
handle_data.shape_and_type[0].specialized_type)
return l
tensor_list = func()
handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list)
self.assertTrue(handle_data.is_set)
self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype)
self.assertEqual(types_pb2.ST_TENSOR_LIST,
handle_data.shape_and_type[0].specialized_type)
element = list_ops.tensor_list_get_item(
tensor_list, 0, element_dtype=dtypes.float32)
self.assertAllEqual(element.shape.as_list(), [])

View File

@ -67,22 +67,13 @@ def copy_handle_data(source_t, target_t):
and handle_data.is_set
and handle_data.shape_and_type):
# pylint: disable=protected-access
if isinstance(target_t, ops.EagerTensor):
target_t._handle_data = handle_data
return
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
target_t._as_tf_output(),
handle_data.SerializeToString())
# pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
for pair in handle_data.shape_and_type])
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
if not s.unknown_rank else None for s in shapes]
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
target_t._op._graph._c_graph, # pylint: disable=protected-access
target_t._as_tf_output(), # pylint: disable=protected-access
shapes,
ranks,
types)
@tf_export("custom_gradient")