Record the type of a variant tensor with its handle data
Only used for TensorList at the moment, but we'll likely need to add similar information for other variants as they'll also need special cases in tf.vectorized_map. PiperOrigin-RevId: 334708004 Change-Id: I484845f855c5c9cfeea13b78a520a1d7c60c9fc5
This commit is contained in:
parent
0fa896962f
commit
8310eda05e
tensorflow
c
core
python
eager
framework
kernel_tests
ops
@ -136,6 +136,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
|
|||||||
auto* out_shape_and_type = handle_data.add_shape_and_type();
|
auto* out_shape_and_type = handle_data.add_shape_and_type();
|
||||||
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
|
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
|
||||||
out_shape_and_type->set_dtype(p.dtype);
|
out_shape_and_type->set_dtype(p.dtype);
|
||||||
|
out_shape_and_type->set_specialized_type(p.specialized_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
string result;
|
string result;
|
||||||
@ -163,7 +164,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
|
|||||||
status->status =
|
status->status =
|
||||||
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
|
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
|
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
|
||||||
|
shape_and_type_proto.specialized_type());
|
||||||
}
|
}
|
||||||
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
||||||
}
|
}
|
||||||
|
@ -133,9 +133,14 @@ struct DimensionOrConstant {
|
|||||||
struct ShapeAndType {
|
struct ShapeAndType {
|
||||||
ShapeAndType() {}
|
ShapeAndType() {}
|
||||||
ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
|
ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
|
||||||
|
ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t)
|
||||||
|
: shape(s), dtype(t), specialized_type(specialized_t) {}
|
||||||
|
|
||||||
ShapeHandle shape;
|
ShapeHandle shape;
|
||||||
DataType dtype = DT_INVALID;
|
DataType dtype = DT_INVALID;
|
||||||
|
// The type of a variant-dtype tensor sometimes affects graph building
|
||||||
|
// (e.g. for vectorization), and needs to be know statically in such cases.
|
||||||
|
SpecializedType specialized_type = ST_INVALID;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Shape inference functions registered on ops in REGISTER_OP implement
|
// Shape inference functions registered on ops in REGISTER_OP implement
|
||||||
|
@ -74,3 +74,14 @@ enum DataType {
|
|||||||
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
|
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
|
||||||
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
|
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
|
||||||
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
|
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
|
||||||
|
|
||||||
|
// For identifying the underlying type of a variant. For variants, the types
|
||||||
|
// listed here are a subset of the types in the variant type registry,
|
||||||
|
// corresponding to commonly used variants which must occasionally be
|
||||||
|
// special-cased.
|
||||||
|
enum SpecializedType {
|
||||||
|
// Invalid/unknown specialized type.
|
||||||
|
ST_INVALID = 0;
|
||||||
|
// "tensorflow::TensorList" in the variant type registry.
|
||||||
|
ST_TENSOR_LIST = 1;
|
||||||
|
}
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -369,7 +370,7 @@ REGISTER_OP("TensorListFromTensor")
|
|||||||
&tensor_shape_except_first_dim));
|
&tensor_shape_except_first_dim));
|
||||||
c->set_output_handle_shapes_and_types(
|
c->set_output_handle_shapes_and_types(
|
||||||
0, std::vector<shape_inference::ShapeAndType>{
|
0, std::vector<shape_inference::ShapeAndType>{
|
||||||
{element_shape, element_dtype}});
|
{element_shape, element_dtype, ST_TENSOR_LIST}});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -409,7 +410,7 @@ REGISTER_OP("TensorListReserve")
|
|||||||
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
|
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
|
||||||
c->set_output_handle_shapes_and_types(
|
c->set_output_handle_shapes_and_types(
|
||||||
0, std::vector<shape_inference::ShapeAndType>{
|
0, std::vector<shape_inference::ShapeAndType>{
|
||||||
{element_shape, element_dtype}});
|
{element_shape, element_dtype, ST_TENSOR_LIST}});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -481,7 +482,7 @@ REGISTER_OP("TensorListSetItem")
|
|||||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||||
} else {
|
} else {
|
||||||
c->set_output_handle_shapes_and_types(
|
c->set_output_handle_shapes_and_types(
|
||||||
0, {{c->UnknownShape(), element_dtype}});
|
0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
@ -532,8 +533,8 @@ REGISTER_OP("TensorListScatter")
|
|||||||
shape_inference::ShapeHandle element_shape;
|
shape_inference::ShapeHandle element_shape;
|
||||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
|
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
|
||||||
2, &element_shape));
|
2, &element_shape));
|
||||||
c->set_output_handle_shapes_and_types(0,
|
c->set_output_handle_shapes_and_types(
|
||||||
{{element_shape, element_dtype}});
|
0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
|
||||||
c->set_output(0, c->Scalar());
|
c->set_output(0, c->Scalar());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
@ -552,8 +553,8 @@ REGISTER_OP("TensorListScatterV2")
|
|||||||
shape_inference::ShapeHandle element_shape;
|
shape_inference::ShapeHandle element_shape;
|
||||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
|
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
|
||||||
2, &element_shape));
|
2, &element_shape));
|
||||||
c->set_output_handle_shapes_and_types(0,
|
c->set_output_handle_shapes_and_types(
|
||||||
{{element_shape, element_dtype}});
|
0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
|
||||||
c->set_output(0, c->Scalar());
|
c->set_output(0, c->Scalar());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
@ -580,8 +581,8 @@ REGISTER_OP("TensorListScatterIntoExistingList")
|
|||||||
TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
|
TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
|
||||||
element_shape = GetElementShapeFromHandleData(*handle_data);
|
element_shape = GetElementShapeFromHandleData(*handle_data);
|
||||||
}
|
}
|
||||||
c->set_output_handle_shapes_and_types(0,
|
c->set_output_handle_shapes_and_types(
|
||||||
{{element_shape, element_dtype}});
|
0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
|
||||||
c->set_output(0, c->Scalar());
|
c->set_output(0, c->Scalar());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
@ -606,7 +607,7 @@ REGISTER_OP("TensorListConcatLists")
|
|||||||
bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
|
bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
|
||||||
if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
|
if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
|
||||||
c->set_output_handle_shapes_and_types(
|
c->set_output_handle_shapes_and_types(
|
||||||
0, {{c->UnknownShape(), element_dtype}});
|
0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
shape_inference::ShapeAndType list_shape_type_a =
|
shape_inference::ShapeAndType list_shape_type_a =
|
||||||
|
@ -2174,6 +2174,7 @@ class ConcreteFunction(object):
|
|||||||
j = 0
|
j = 0
|
||||||
for i, o in enumerate(outputs_list):
|
for i, o in enumerate(outputs_list):
|
||||||
if o is not None:
|
if o is not None:
|
||||||
|
custom_gradient.copy_handle_data(self.outputs[j], result[j])
|
||||||
outputs_list[i] = result[j]
|
outputs_list[i] = result[j]
|
||||||
j += 1
|
j += 1
|
||||||
ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
|
ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
|
||||||
|
@ -11,6 +11,10 @@ message CppShapeInferenceResult {
|
|||||||
message HandleShapeAndType {
|
message HandleShapeAndType {
|
||||||
TensorShapeProto shape = 1;
|
TensorShapeProto shape = 1;
|
||||||
DataType dtype = 2;
|
DataType dtype = 2;
|
||||||
|
// For dtype==DT_VARIANT, specialized_type may indicate a more specific
|
||||||
|
// type. For other dtypes or when the information is unavailable it is set
|
||||||
|
// to ST_INVALID.
|
||||||
|
SpecializedType specialized_type = 3;
|
||||||
}
|
}
|
||||||
message HandleData {
|
message HandleData {
|
||||||
bool is_set = 1;
|
bool is_set = 1;
|
||||||
|
@ -121,6 +121,7 @@ cuda_py_test(
|
|||||||
"noasan", # TODO(b/155406705): flaky
|
"noasan", # TODO(b/155406705): flaky
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
@ -22,6 +22,7 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np # pylint: disable=unused-import
|
import numpy as np # pylint: disable=unused-import
|
||||||
|
|
||||||
|
from tensorflow.core.framework import types_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -40,6 +41,7 @@ from tensorflow.python.ops import gradients_impl
|
|||||||
from tensorflow.python.ops import list_ops
|
from tensorflow.python.ops import list_ops
|
||||||
from tensorflow.python.ops import map_fn
|
from tensorflow.python.ops import map_fn
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
from tensorflow.python.ops import variable_scope as vs
|
||||||
@ -1600,9 +1602,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
def func():
|
def func():
|
||||||
t = constant_op.constant([1., 2., 3.])
|
t = constant_op.constant([1., 2., 3.])
|
||||||
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
|
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
|
||||||
|
handle_data = resource_variable_ops.get_eager_safe_handle_data(l)
|
||||||
|
self.assertTrue(handle_data.is_set)
|
||||||
|
self.assertEqual(types_pb2.ST_TENSOR_LIST,
|
||||||
|
handle_data.shape_and_type[0].specialized_type)
|
||||||
return l
|
return l
|
||||||
|
|
||||||
tensor_list = func()
|
tensor_list = func()
|
||||||
|
handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list)
|
||||||
|
self.assertTrue(handle_data.is_set)
|
||||||
|
self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype)
|
||||||
|
self.assertEqual(types_pb2.ST_TENSOR_LIST,
|
||||||
|
handle_data.shape_and_type[0].specialized_type)
|
||||||
element = list_ops.tensor_list_get_item(
|
element = list_ops.tensor_list_get_item(
|
||||||
tensor_list, 0, element_dtype=dtypes.float32)
|
tensor_list, 0, element_dtype=dtypes.float32)
|
||||||
self.assertAllEqual(element.shape.as_list(), [])
|
self.assertAllEqual(element.shape.as_list(), [])
|
||||||
|
@ -67,22 +67,13 @@ def copy_handle_data(source_t, target_t):
|
|||||||
and handle_data.is_set
|
and handle_data.is_set
|
||||||
and handle_data.shape_and_type):
|
and handle_data.shape_and_type):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
if isinstance(target_t, ops.EagerTensor):
|
||||||
|
target_t._handle_data = handle_data
|
||||||
|
return
|
||||||
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
|
pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
|
||||||
target_t._as_tf_output(),
|
target_t._as_tf_output(),
|
||||||
handle_data.SerializeToString())
|
handle_data.SerializeToString())
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
# Ensure that shapes and dtypes are propagated.
|
|
||||||
shapes, types = zip(*[(pair.shape, pair.dtype)
|
|
||||||
for pair in handle_data.shape_and_type])
|
|
||||||
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
|
|
||||||
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
|
|
||||||
if not s.unknown_rank else None for s in shapes]
|
|
||||||
pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
|
|
||||||
target_t._op._graph._c_graph, # pylint: disable=protected-access
|
|
||||||
target_t._as_tf_output(), # pylint: disable=protected-access
|
|
||||||
shapes,
|
|
||||||
ranks,
|
|
||||||
types)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("custom_gradient")
|
@tf_export("custom_gradient")
|
||||||
|
Loading…
Reference in New Issue
Block a user