Record the type of a variant tensor with its handle data

Only used for TensorList at the moment, but we'll likely need to add similar
information for other variants as they'll also need special cases in
tf.vectorized_map.

PiperOrigin-RevId: 334708004
Change-Id: I484845f855c5c9cfeea13b78a520a1d7c60c9fc5
parent 0fa896962f
commit 8310eda05e
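Before the diff itself, a motivating sketch (illustrative only, not part of this commit, and assuming a recent TF 2.x where pfor can convert the TensorList ops behind TensorArray): tf.vectorized_map lowers the mapped function through pfor, which must special-case TensorList-backed values, and it can only do that if the variant tensor's underlying type is known statically, which is exactly what this change records.

import tensorflow as tf

def body(x):
  # In TF2 a TensorArray is backed by a DT_VARIANT TensorList; pfor must
  # recognize that variant as a TensorList to vectorize these ops.
  ta = tf.TensorArray(tf.float32, size=2)
  ta = ta.write(0, x).write(1, 2.0 * x)
  return ta.stack()

# Vectorizes body over the leading axis; result has shape (3, 2).
print(tf.vectorized_map(body, tf.constant([1.0, 2.0, 3.0])))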
tensorflow/c/python_api.cc

@@ -136,6 +136,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
       auto* out_shape_and_type = handle_data.add_shape_and_type();
       ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
       out_shape_and_type->set_dtype(p.dtype);
+      out_shape_and_type->set_specialized_type(p.specialized_type);
     }
   }
   string result;
@@ -163,7 +164,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
     status->status =
         ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
     if (TF_GetCode(status) != TF_OK) return;
-    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
+                                  shape_and_type_proto.specialized_type());
   }
   ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
 }
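These two functions move handle data across the C API as a serialized CppShapeInferenceResult.HandleData proto (extended further down in this diff), so the new field rides along automatically. A hedged sketch, using only protos touched in this change, of building the message they exchange:

# Construct the HandleData message that GetHandleShapeAndType emits and
# SetHandleShapeAndType parses, with the new specialized_type field set.
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2

handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
handle_data.is_set = True
pair = handle_data.shape_and_type.add()
pair.dtype = types_pb2.DT_VARIANT
pair.specialized_type = types_pb2.ST_TENSOR_LIST  # the new field

# These are the bytes SetHandleShapeAndType receives and parses.
serialized = handle_data.SerializeToString()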
tensorflow/core/framework/shape_inference.h

@@ -133,9 +133,14 @@ struct DimensionOrConstant {
 struct ShapeAndType {
   ShapeAndType() {}
   ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
+  ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t)
+      : shape(s), dtype(t), specialized_type(specialized_t) {}

   ShapeHandle shape;
   DataType dtype = DT_INVALID;
+  // The type of a variant-dtype tensor sometimes affects graph building
+  // (e.g. for vectorization), and needs to be known statically in such cases.
+  SpecializedType specialized_type = ST_INVALID;
 };

 // Shape inference functions registered on ops in REGISTER_OP implement
tensorflow/core/framework/types.proto

@@ -74,3 +74,14 @@ enum DataType {
 // https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
 // https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
 // https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
+
+// For identifying the underlying type of a variant. For variants, the types
+// listed here are a subset of the types in the variant type registry,
+// corresponding to commonly used variants which must occasionally be
+// special-cased.
+enum SpecializedType {
+  // Invalid/unknown specialized type.
+  ST_INVALID = 0;
+  // "tensorflow::TensorList" in the variant type registry.
+  ST_TENSOR_LIST = 1;
+}
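Once the proto is regenerated, the new enum is reachable from Python through types_pb2; a small sketch of what the generated module exposes:

from tensorflow.core.framework import types_pb2

# Module-level constants plus an enum wrapper with name/value lookups.
assert types_pb2.ST_INVALID == 0
assert types_pb2.ST_TENSOR_LIST == 1
assert types_pb2.SpecializedType.Name(1) == "ST_TENSOR_LIST"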
tensorflow/core/ops/list_ops.cc

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.pb.h"

 namespace tensorflow {
 namespace {
@@ -369,7 +370,7 @@ REGISTER_OP("TensorListFromTensor")
                                   &tensor_shape_except_first_dim));
       c->set_output_handle_shapes_and_types(
           0, std::vector<shape_inference::ShapeAndType>{
-                 {element_shape, element_dtype}});
+                 {element_shape, element_dtype, ST_TENSOR_LIST}});
       return Status::OK();
     });
@@ -409,7 +410,7 @@ REGISTER_OP("TensorListReserve")
       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
       c->set_output_handle_shapes_and_types(
           0, std::vector<shape_inference::ShapeAndType>{
-                 {element_shape, element_dtype}});
+                 {element_shape, element_dtype, ST_TENSOR_LIST}});
       return Status::OK();
     });
@@ -481,7 +482,7 @@ REGISTER_OP("TensorListSetItem")
         c->set_output_handle_shapes_and_types(0, *handle_data);
       } else {
         c->set_output_handle_shapes_and_types(
-            0, {{c->UnknownShape(), element_dtype}});
+            0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
       }
       return Status::OK();
     });
@@ -532,8 +533,8 @@ REGISTER_OP("TensorListScatter")
       shape_inference::ShapeHandle element_shape;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
           2, &element_shape));
-      c->set_output_handle_shapes_and_types(0,
-                                            {{element_shape, element_dtype}});
+      c->set_output_handle_shapes_and_types(
+          0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
       c->set_output(0, c->Scalar());
       return Status::OK();
     });
@@ -552,8 +553,8 @@ REGISTER_OP("TensorListScatterV2")
       shape_inference::ShapeHandle element_shape;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
           2, &element_shape));
-      c->set_output_handle_shapes_and_types(0,
-                                            {{element_shape, element_dtype}});
+      c->set_output_handle_shapes_and_types(
+          0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
       c->set_output(0, c->Scalar());
       return Status::OK();
     });
@@ -580,8 +581,8 @@ REGISTER_OP("TensorListScatterIntoExistingList")
         TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
         element_shape = GetElementShapeFromHandleData(*handle_data);
       }
-      c->set_output_handle_shapes_and_types(0,
-                                            {{element_shape, element_dtype}});
+      c->set_output_handle_shapes_and_types(
+          0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
       c->set_output(0, c->Scalar());
       return Status::OK();
     });
@@ -606,7 +607,7 @@ REGISTER_OP("TensorListConcatLists")
       bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
       if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
         c->set_output_handle_shapes_and_types(
-            0, {{c->UnknownShape(), element_dtype}});
+            0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
         return Status::OK();
       }
       shape_inference::ShapeAndType list_shape_type_a =
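All seven registrations above make the same one-line change: the ShapeAndType they record for output 0 now carries ST_TENSOR_LIST. A hedged way to observe the effect from Python, reusing the private helper the new test below relies on:

# Sketch: trace a graph so the shape functions above run, then read the
# recorded handle data back off the symbolic tensor.
import tensorflow as tf
from tensorflow.core.framework import types_pb2
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import resource_variable_ops

@tf.function
def make_list():
  l = list_ops.tensor_list_reserve(
      element_shape=[], num_elements=2, element_dtype=tf.float32)
  handle_data = resource_variable_ops.get_eager_safe_handle_data(l)
  assert (handle_data.shape_and_type[0].specialized_type ==
          types_pb2.ST_TENSOR_LIST)
  return l

make_list()  # tracing runs the assertion at graph-construction time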
tensorflow/python/eager/function.py

@@ -2174,6 +2174,7 @@ class ConcreteFunction(object):
       j = 0
       for i, o in enumerate(outputs_list):
         if o is not None:
+          custom_gradient.copy_handle_data(self.outputs[j], result[j])
           outputs_list[i] = result[j]
           j += 1
       ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
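The added copy_handle_data call propagates handle data, now including specialized_type, from the function graph's symbolic outputs onto the tensors a ConcreteFunction actually returns, so eager callers see it as well. A self-contained sketch of the behavior this enables, mirroring the new test below:

import tensorflow as tf
from tensorflow.core.framework import types_pb2
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import resource_variable_ops

@tf.function
def f():
  return list_ops.tensor_list_from_tensor(
      tf.constant([1., 2.]), element_shape=[])

# Without the copy above, the eager output would come back without the
# handle data recorded during tracing.
handle_data = resource_variable_ops.get_eager_safe_handle_data(f())
assert handle_data.is_set
assert (handle_data.shape_and_type[0].specialized_type ==
        types_pb2.ST_TENSOR_LIST)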
tensorflow/python/framework/cpp_shape_inference.proto

@@ -11,6 +11,10 @@ message CppShapeInferenceResult {
   message HandleShapeAndType {
     TensorShapeProto shape = 1;
     DataType dtype = 2;
+    // For dtype==DT_VARIANT, specialized_type may indicate a more specific
+    // type. For other dtypes or when the information is unavailable it is set
+    // to ST_INVALID.
+    SpecializedType specialized_type = 3;
   }
   message HandleData {
     bool is_set = 1;
tensorflow/python/kernel_tests/BUILD

@@ -121,6 +121,7 @@ cuda_py_test(
         "noasan",  # TODO(b/155406705): flaky
     ],
     deps = [
+        "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
tensorflow/python/kernel_tests/list_ops_test.py

@@ -22,6 +22,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np  # pylint: disable=unused-import

+from tensorflow.core.framework import types_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -40,6 +41,7 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope as vs
@@ -1600,9 +1602,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     def func():
       t = constant_op.constant([1., 2., 3.])
       l = list_ops.tensor_list_from_tensor(t, element_shape=[])
+      handle_data = resource_variable_ops.get_eager_safe_handle_data(l)
+      self.assertTrue(handle_data.is_set)
+      self.assertEqual(types_pb2.ST_TENSOR_LIST,
+                       handle_data.shape_and_type[0].specialized_type)
       return l

     tensor_list = func()
+    handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list)
+    self.assertTrue(handle_data.is_set)
+    self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype)
+    self.assertEqual(types_pb2.ST_TENSOR_LIST,
+                     handle_data.shape_and_type[0].specialized_type)
     element = list_ops.tensor_list_get_item(
         tensor_list, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(element.shape.as_list(), [])
tensorflow/python/ops/custom_gradient.py

@@ -67,22 +67,13 @@ def copy_handle_data(source_t, target_t):
         and handle_data.is_set
         and handle_data.shape_and_type):
       # pylint: disable=protected-access
       if isinstance(target_t, ops.EagerTensor):
         target_t._handle_data = handle_data
         return
+      pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
+                                              target_t._as_tf_output(),
+                                              handle_data.SerializeToString())
       # pylint: enable=protected-access
-      # Ensure that shapes and dtypes are propagated.
-      shapes, types = zip(*[(pair.shape, pair.dtype)
-                            for pair in handle_data.shape_and_type])
-      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
-      shapes = [[d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
-                if not s.unknown_rank else None for s in shapes]
-      pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
-          target_t._op._graph._c_graph,  # pylint: disable=protected-access
-          target_t._as_tf_output(),  # pylint: disable=protected-access
-          shapes,
-          ranks,
-          types)


 @tf_export("custom_gradient")
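This rewrite is what keeps the new field alive on the Python side: the old TF_GraphSetOutputHandleShapesAndTypes_wrapper path flattened handle data to parallel lists of shapes, ranks and dtypes before crossing into C, leaving specialized_type nowhere to go, whereas SetHandleShapeAndType (extended at the top of this diff) receives the serialized proto wholesale. A contrast sketch under that reading:

# Field names taken from cpp_shape_inference.proto above.
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2

hd = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
hd.is_set = True
hd.shape_and_type.add(dtype=types_pb2.DT_VARIANT,
                      specialized_type=types_pb2.ST_TENSOR_LIST)

# Old path: only shape and dtype survived the crossing.
flattened = [(p.shape, p.dtype) for p in hd.shape_and_type]

# New path: the whole proto survives, specialized_type included.
restored = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
    hd.SerializeToString())
assert restored.shape_and_type[0].specialized_type == types_pb2.ST_TENSOR_LIST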