Change shape inference so that a single resource tensor can carry shape and type information for multiple tensors.

Apply this to QueueDequeueV2 handled by grappler.

PiperOrigin-RevId: 157163757
Authored by A. Unique TensorFlower on 2017-05-25 15:45:54 -07:00; committed by TensorFlower Gardener
parent 8f0d0bdca8
commit eae7e833f1
25 changed files with 754 additions and 401 deletions
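
At its core, the change swaps InferenceContext's single per-handle shape and dtype for a vector of (shape, dtype) entries per resource tensor, so one handle (e.g. a queue) can describe several tensors at once. The sketch below is illustrative only: the op name and component shapes are invented, REGISTER_OP, MakeShape, and Scalar are pre-existing TensorFlow APIs, and only ShapeAndType and set_output_handle_shapes_and_types() come from this commit.

// Hypothetical queue-like op whose shape function records per-component
// shape/type data on its single resource output (assumes the usual
// TensorFlow op-registration headers and the tensorflow namespace).
REGISTER_OP("ExampleQueue")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());  // the resource handle itself is a scalar
      std::vector<shape_inference::ShapeAndType> components;
      components.push_back({c->MakeShape({3, 7}), DT_FLOAT});   // component 0
      components.push_back({c->MakeShape({10}), DT_DOUBLE});    // component 1
      c->set_output_handle_shapes_and_types(0, components);
      return Status::OK();
    });

A dequeue-like consumer can then recover both entries from the handle instead of a single merged shape.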

View File

@@ -71,8 +71,7 @@ class _ExperimentalFuncGraph(function._FuncGraph):
self.extra_inputs.append(x)
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
# pylint: disable=protected-access
ph._handle_shape = x._handle_shape
ph._handle_dtype = x._handle_dtype
ph._handle_data = x._handle_data
# pylint: enable=protected-access
inputs[i] = ph
self._captured[x] = ph

View File

@@ -34,8 +34,6 @@ REGISTER_OP("ImageProjectiveTransform")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
c->set_output_handle_shape(0, c->input_handle_shape(0));
return Status::OK();
})
.Doc(R"doc(

View File

@@ -29,6 +29,7 @@ namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
ShapeRefiner::ShapeRefiner(int graph_def_version,
@@ -49,8 +50,8 @@ Status ShapeRefiner::AddNode(const Node* node) {
// indexed by 'node's input.
std::vector<Node*> input_nodes(node->num_inputs());
std::vector<ShapeHandle> input_shapes(node->num_inputs());
std::vector<DataType> input_handle_dtypes(node->num_inputs());
std::vector<ShapeHandle> input_handle_shapes(node->num_inputs());
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types(node->num_inputs());
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
@@ -67,13 +68,13 @@ Status ShapeRefiner::AddNode(const Node* node) {
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
// Only propagate handle shape and dtype of edges which are carrying
// resource handles.
// Only propagate handle data of edges which are carrying resource handles.
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
input_handle_dtypes[e->dst_input()] =
c->output_handle_dtype(e->src_output());
input_handle_shapes[e->dst_input()] =
c->output_handle_shape(e->src_output());
const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
if (in_v != nullptr) {
input_handle_shapes_and_types[e->dst_input()].reset(
new std::vector<ShapeAndType>(*in_v));
}
}
}
@@ -95,7 +96,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::unique_ptr<InferenceContext> c(
new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
input_shapes, input_tensors, input_tensors_as_shapes,
input_handle_shapes, input_handle_dtypes));
std::move(input_handle_shapes_and_types)));
if (!c->construction_status().ok()) {
return c->construction_status();
}
@@ -170,12 +171,11 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
// Also propagate handle shape and dtype of edges which are carrying
// resource handles.
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
if (node_context->set_input_handle_dtype(
e->dst_input(), c->output_handle_dtype(e->src_output()))) {
*refined = true;
}
if (node_context->MergeInputHandleShape(
e->dst_input(), c->output_handle_shape(e->src_output()))) {
auto* shapes_and_types =
c->output_handle_shapes_and_types(e->src_output());
if (shapes_and_types != nullptr &&
node_context->MergeInputHandleShapesAndTypes(e->dst_input(),
*shapes_and_types)) {
*refined = true;
}
}
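
On the consuming side, a shape function can read the per-component data that the refiner propagated onto its resource input. A minimal sketch, assuming the op declares one output per recorded component; the function name is ours, not part of this change, and shape_inference::UnknownShape is the pre-existing common shape function:

Status DequeueLikeShape(shape_inference::InferenceContext* c) {
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr) {
    // Nothing recorded on the handle yet; leave every output unknown.
    return shape_inference::UnknownShape(c);
  }
  // One recorded (shape, dtype) entry per output.
  for (int i = 0;
       i < c->num_outputs() && i < static_cast<int>(handle_data->size()); ++i) {
    c->set_output(i, (*handle_data)[i].shape);
  }
  return Status::OK();
}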

View File

@@ -791,8 +791,8 @@ TEST(ShapeRefinerTest, IncrementalUpdates) {
// Inject a shape, and incrementally propagate it to the dequeue op.
ctx = m.GetContext(queue);
shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7});
ctx->set_output_handle_shape(0, shp);
ctx->set_output_handle_dtype(0, DT_FLOAT);
ctx->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
bool refined = false;
TF_ASSERT_OK(m.UpdateNode(dequeue, &refined));

View File

@@ -70,7 +70,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
{}, {}, {});
{}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@@ -88,8 +88,7 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {},
{});
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@@ -97,7 +96,7 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
{S({1, 23, 4, 4, 2})}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@@ -125,7 +124,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 3}), S({3, 4})}, {}, {}, {}, {});
{S({2, 3}), S({3, 4})}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -135,7 +134,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, -1}), S({3, 4})}, {}, {}, {}, {});
{S({2, -1}), S({3, 4})}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -145,7 +144,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
{}, {}, {}, {});
{}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -156,7 +155,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 3}), S({3, -1})}, {}, {}, {}, {});
{S({2, 3}), S({3, -1})}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -166,7 +165,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 5}), S({3, 4})}, {}, {}, {}, {});
{S({2, 5}), S({3, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -178,7 +177,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}, {});
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@@ -197,7 +196,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({3, 2}), S({3, 4})}, {}, {}, {}, {});
{S({3, 2}), S({3, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -215,7 +214,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 3}), S({4, 3})}, {}, {}, {}, {});
{S({2, 3}), S({4, 3})}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -240,7 +239,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 10}), S({10})}, {}, {}, {}, {});
{S({2, 10}), S({10})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@@ -250,7 +249,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{Unknown(), Unknown()}, {}, {}, {}, {});
{Unknown(), Unknown()}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@@ -259,7 +258,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}, {});
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@@ -273,7 +272,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
{S({2, 3, 4, 5}), S({3})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@@ -287,7 +286,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@@ -301,7 +300,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({10, 11, 12}), S({10})}, {}, {}, {}, {});
{S({10, 11, 12}), S({10})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@@ -310,7 +309,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
{}, {}, {});
{}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@@ -323,7 +322,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
// NCHW format
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
{}, {}, {}, {});
{}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -343,7 +342,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
{}, {});
{});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -352,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
{}, {}, {}, {});
{}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -365,7 +364,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
{}, {}, {}, {});
{}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -378,7 +377,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}, {});
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@@ -391,7 +390,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
{}, {}, {}, {});
{}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@@ -399,7 +398,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {}, {},
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
{});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@@ -412,7 +411,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
// NCHW format
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
{}, {});
{});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}
@@ -832,7 +831,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{Unknown(), Unknown(), Unknown()}, {}, {}, {}, {});
{Unknown(), Unknown(), Unknown()}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -845,7 +844,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {}, {});
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -858,7 +857,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({-1}), S({-1}), S({-1})}, {}, {}, {}, {});
{S({-1}), S({-1}), S({-1})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -872,7 +871,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({5, 3}), S({4}), S({3})}, {}, {}, {}, {});
{S({5, 3}), S({4}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -886,7 +885,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({5, 3}), S({5}), S({4})}, {}, {}, {}, {});
{S({5, 3}), S({5}), S({4})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -900,7 +899,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({-1, 3}), S({5}), S({3})}, {}, {}, {}, {});
{S({-1, 3}), S({5}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -913,7 +912,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({5, 3}), S({-1}), S({3})}, {}, {}, {}, {});
{S({5, 3}), S({-1}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -926,7 +925,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({5, -1}), S({5}), S({3})}, {}, {}, {}, {});
{S({5, -1}), S({5}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -939,7 +938,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({5, 3}), S({5}), S({-1})}, {}, {}, {}, {});
{S({5, 3}), S({5}), S({-1})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@@ -952,7 +951,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
{S({5, 3}), S({5}), S({3})}, {}, {}, {}, {});
{S({5, 3}), S({5}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());

View File

@@ -33,8 +33,9 @@ InferenceContext::InferenceContext(
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
const std::vector<
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
node_def_(*CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
@@ -56,16 +57,26 @@ InferenceContext::InferenceContext(
}
inputs_.push_back(shape);
}
std::vector<ShapeHandle> handle_shapes;
for (const auto& p : input_handle_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
if (!construction_status_.ok()) {
return;
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
input_shapes.size());
for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
const auto& v = input_handle_shapes_and_types[i];
if (v == nullptr) {
continue;
}
handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
auto& new_v = *handle_data[i];
for (int j = 0; j < v->size(); ++j) {
const auto& p = (*v)[j];
construction_status_.Update(
MakeShapeFromShapeProto(p.first, &new_v[j].shape));
if (!construction_status_.ok()) {
return;
}
new_v[j].dtype = p.second;
}
handle_shapes.push_back(shape);
}
PostInputInit(handle_shapes, input_handle_dtypes);
PostInputInit(std::move(handle_data));
}
InferenceContext::InferenceContext(
@@ -73,14 +84,15 @@ InferenceContext::InferenceContext(
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
PostInputInit(input_handle_shapes, input_handle_dtypes);
PostInputInit(std::move(input_handle_shapes_and_types));
}
InferenceContext::~InferenceContext() {}
@@ -149,16 +161,11 @@ void InferenceContext::PreInputInit(
for (int i = 0; i < num_outputs; ++i) {
outputs_.push_back(nullptr);
}
output_handle_shape_.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
output_handle_shape_.push_back(UnknownShape());
}
output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
output_handle_shapes_and_types_.resize(num_outputs);
}
void InferenceContext::PostInputInit(
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes) {
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
for (const auto& e : input_name_map_) {
num_inputs_from_node_def =
@@ -166,25 +173,16 @@ void InferenceContext::PostInputInit(
}
// Allow passing empty shapes/dtypes to avoid changing every single test.
if (input_handle_shapes.empty()) {
input_handle_shape_.resize(inputs_.size());
if (input_handle_data.empty()) {
input_handle_shapes_and_types_.resize(inputs_.size());
} else {
input_handle_shape_ = input_handle_shapes;
if (input_handle_shape_.size() != inputs_.size()) {
if (input_handle_data.size() != inputs_.size()) {
construction_status_ = errors::InvalidArgument(
"Wrong number of handle shapes passed; expected ", inputs_.size(),
" got ", input_handle_shape_.size());
}
}
if (input_handle_dtypes.empty()) {
input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
} else {
input_handle_dtype_ = input_handle_dtypes;
if (input_handle_dtype_.size() != inputs_.size()) {
construction_status_ = errors::InvalidArgument(
"Wrong number of handle dtypes passed; expected ", inputs_.size(),
" got ", input_handle_dtype_.size());
" got ", input_handle_data.size());
return;
}
input_handle_shapes_and_types_ = std::move(input_handle_data);
}
if (inputs_.size() != num_inputs_from_node_def) {
@@ -756,7 +754,7 @@ Status InferenceContext::Multiply(DimensionHandle first,
} else if (first_value == kUnknownDim || second_value == kUnknownDim) {
*out = UnknownDim();
} else {
// Invariant: Both values are known and and greater than 1.
// Invariant: Both values are known and greater than 1.
const int64 product = first_value * second_value;
if (product < 0) {
return errors::InvalidArgument(
@@ -847,10 +845,65 @@ Status InferenceContext::AttachContext(const Status& status) {
}
ShapeHandle InferenceContext::input_handle_shape(int idx) {
if (!input_handle_shape_[idx].IsSet()) {
input_handle_shape_[idx] = UnknownShape();
if (input_handle_shapes_and_types_[idx] == nullptr) {
input_handle_shapes_and_types_[idx].reset(
new std::vector<ShapeAndType>{{UnknownShape(), DT_INVALID}});
}
return input_handle_shape_[idx];
return (*input_handle_shapes_and_types_[idx])[0].shape;
}
bool InferenceContext::MergeHandleShapesAndTypes(
const std::vector<ShapeAndType>& shapes_and_types,
std::vector<ShapeAndType>* to_update) {
if (shapes_and_types.size() != to_update->size()) {
return false;
}
std::vector<ShapeAndType> new_values(shapes_and_types.size());
bool refined = false;
for (int i = 0; i < shapes_and_types.size(); ++i) {
const ShapeAndType& existing = (*to_update)[i];
new_values[i].dtype = shapes_and_types[i].dtype;
if (new_values[i].dtype != existing.dtype && existing.dtype == DT_INVALID) {
refined = true;
}
if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
.ok()) {
// merge failed, ignore the new value.
new_values[i].shape = existing.shape;
}
if (!existing.shape.SameHandle(new_values[i].shape)) {
refined = true;
}
}
if (!refined) {
return false;
}
for (int i = 0; i < new_values.size(); ++i) {
(*to_update)[i] = new_values[i];
}
return true;
}
bool InferenceContext::MergeOutputHandleShapesAndTypes(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
if (output_handle_shapes_and_types_[idx] == nullptr) {
output_handle_shapes_and_types_[idx].reset(
new std::vector<ShapeAndType>(shapes_and_types));
return true;
}
return MergeHandleShapesAndTypes(shapes_and_types,
output_handle_shapes_and_types_[idx].get());
}
bool InferenceContext::MergeInputHandleShapesAndTypes(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
if (input_handle_shapes_and_types_[idx] == nullptr) {
input_handle_shapes_and_types_[idx].reset(
new std::vector<ShapeAndType>(shapes_and_types));
return true;
}
return MergeHandleShapesAndTypes(shapes_and_types,
input_handle_shapes_and_types_[idx].get());
}
// -----------------------------------------------------------------------------

View File

@@ -121,6 +121,14 @@ struct DimensionOrConstant {
DimensionOrConstant();
};
struct ShapeAndType {
ShapeAndType() {}
ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
ShapeHandle shape;
DataType dtype = DT_INVALID;
};
// Shape inference functions registered on ops in REGISTER_OP implement
// their shape functions in terms of this InferenceContext. An InferenceContext
// is created by the framework and passed to a shape inference function. The
@@ -149,26 +157,28 @@ class InferenceContext {
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
// Elements of <input_tensors_as_shapes> are used for when a shape function
// makes a call to MakeShapeFromShapeTensor; in particular, when the
// input_tensors[i] is nullptr but the shape represented by it is partially
// known from analysis of the graph.
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
// Values of <input_tensors_as_shapes> do not need to outlive the context.
// Elements of <input_tensors_as_shapes> are used for when a shape
// function makes a call to MakeShapeFromShapeTensor; in particular, when
// the input_tensors[i] is nullptr but the shape represented by it is
// partially known from analysis of the graph. <input_tensors_as_shapes>
// can have fewer elements than <input_shapes>. Values of
// <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
InferenceContext(int graph_def_version, const NodeDef* node_def,
const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
// REQUIRES: <node_def> is not NULL, and must outlive the
// InferenceContext.
InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
const std::vector<
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
input_handle_shapes_and_types);
~InferenceContext();
@@ -441,70 +451,62 @@ class InferenceContext {
// propagate that information. Output handle dtypes and shapes are ignored if
// the output tensor is not of type DT_RESOURCE.
// Merge the stored shape corresponding to the input handle in position idx
// with the specified shape. This requires idx to be in the [0, num_inputs)
// range. If the merge is successful and the new shape differs from the old
// one, store the new shape and return true. Return false otherwise.
bool MergeInputHandleShape(int idx, ShapeHandle shape) {
ShapeHandle new_shape;
if (!Merge(input_handle_shape_[idx], shape, &new_shape).ok() ||
input_handle_shape_[idx].SameHandle(new_shape)) {
return false;
}
input_handle_shape_[idx] = shape;
return true;
// Merge the stored shapes and types corresponding to the input handle in
// position idx with the specified shapes and types. This requires idx to be
// in the [0, num_inputs) range.
//
// If the merge is successful and any of the new shapes differs from the old
// one, or any of the old dtypes was DT_INVALID, store the new shapes and
// return true. Return false otherwise.
bool MergeInputHandleShapesAndTypes(
int idx,
const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
// As MergeInputHandleShapesAndTypes, but for an output.
bool MergeOutputHandleShapesAndTypes(
int idx, const std::vector<ShapeAndType>& shapes) TF_MUST_USE_RESULT;
// Returns the output handle shapes and types, for the resource tensor output
// at index <idx>. Returns NULL if the shapes and types were never set.
const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
return output_handle_shapes_and_types_[idx].get();
}
// Set the type corresponding to the resource in position idx. This requires
// idx to be in the [0, num_inputs) range. Returns true iff the stored type
// has been updated.
bool set_input_handle_dtype(int idx, DataType dtype) {
if (input_handle_dtype_[idx] != dtype) {
input_handle_dtype_[idx] = dtype;
return true;
}
return false;
// Returns the input handle shapes and types, for the resource tensor input
// at index <idx>. Returns NULL if the shapes and types were not available.
const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
return input_handle_shapes_and_types_[idx].get();
}
// DEPRECATED: use input_handle_shapes_and_types.
ShapeHandle input_handle_shape(int idx);
// DEPRECATED: use input_handle_shapes_and_types.
DataType input_handle_dtype(int idx) const {
return input_handle_dtype_[idx];
}
// Merge the stored shape corresponding to the output handle in position idx
// with the specified shape. This requires idx to be in the [0, num_outputs)
// range. If the merge is successful and the new shape differs from the old
// one, store the new shape and return true. Return false otherwise.
bool MergeOutputHandleShape(int idx, ShapeHandle shape) {
ShapeHandle new_shape;
if (!Merge(output_handle_shape_[idx], shape, &new_shape).ok() ||
output_handle_shape_[idx].SameHandle(new_shape)) {
return false;
if (input_handle_shapes_and_types_[idx] == nullptr) {
return DT_INVALID;
} else {
DCHECK_EQ(input_handle_shapes_and_types_[idx]->size(), 1);
return (*input_handle_shapes_and_types_[idx])[0].dtype;
}
output_handle_shape_[idx] = shape;
return true;
}
// Overwrite the shape corresponding to the output handle in position idx with
// the specified shape.
void set_output_handle_shape(int idx, ShapeHandle shape) {
output_handle_shape_[idx] = shape;
}
// Set the type corresponding to the resource in position idx. This requires
// idx to be in the [0, num_outputs) range. Returns true iff the stored type
// has been updated.
bool set_output_handle_dtype(int idx, DataType dtype) {
if (output_handle_dtype_[idx] != dtype) {
output_handle_dtype_[idx] = dtype;
return true;
}
return false;
void set_output_handle_shapes_and_types(
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
output_handle_shapes_and_types_[idx].reset(
new std::vector<ShapeAndType>(shapes_and_types));
}
ShapeHandle output_handle_shape(int idx) const {
return output_handle_shape_[idx];
// DEPRECATED: use output_handle_shapes_and_types.
ShapeHandle output_handle_shape(int idx) {
return output_handle_shapes_and_types_[idx] == nullptr
? UnknownShape()
: (*output_handle_shapes_and_types_[idx])[0].shape;
}
// DEPRECATED: use output_handle_shapes_and_types.
DataType output_handle_dtype(int idx) const {
return output_handle_dtype_[idx];
return output_handle_shapes_and_types_[idx] == nullptr
? DT_INVALID
: (*output_handle_shapes_and_types_[idx])[0].dtype;
}
// Note that shape functions should usually call MakeShapeFromShapeTensor,
@@ -555,8 +557,8 @@ class InferenceContext {
void PreInputInit(const OpDef& op_def,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_data);
DimensionHandle GetDimension(const DimensionOrConstant& d);
@@ -573,6 +575,12 @@
// Adds additional context to the given status.
Status AttachContext(const Status& status);
// Used to implement MergeInputHandleShapesAndTypes and
// MergeOutputHandleShapesAndTypes.
bool MergeHandleShapesAndTypes(
const std::vector<ShapeAndType>& shapes_and_types,
std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
ShapeManager shape_manager_;
// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
@@ -585,10 +593,19 @@
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
std::vector<ShapeHandle> input_handle_shape_;
std::vector<DataType> input_handle_dtype_;
std::vector<ShapeHandle> output_handle_shape_;
std::vector<DataType> output_handle_dtype_;
// input_handle_shapes_and_types_[i] is the list of shape/type pairs available
// through the resource handle passed along input i of the node.
//
// Values may be NULL.
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types_;
// output_handle_shapes_and_types_[i] is the list of shape/type pairs
// available through the resource handle passed along output i of the node.
//
// Values may be NULL.
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
output_handle_shapes_and_types_;
const int graph_def_version_;
const NodeDef& node_def_;
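
To make the merge semantics above concrete: a merge reports a refinement only when a DT_INVALID dtype becomes known or a stored shape becomes more specific, and a mismatched shape is silently kept rather than surfaced as an error. A small usage sketch against an InferenceContext c that has a resource input 0; the shapes are illustrative:

std::vector<shape_inference::ShapeAndType> v{{c.UnknownShape(), DT_INVALID}};
CHECK(c.MergeInputHandleShapesAndTypes(0, v));   // first call stores v as-is
v[0] = {c.MakeShape({3, 7}), DT_FLOAT};
CHECK(c.MergeInputHandleShapesAndTypes(0, v));   // shape and dtype both refined
CHECK(!c.MergeInputHandleShapesAndTypes(0, v));  // identical data: not refined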

View File

@@ -61,6 +61,7 @@ class ShapeInferenceTest : public ::testing::Test {
bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
bool IsSet(DimensionHandle d) { return d.IsSet(); }
bool IsSet(ShapeHandle s) { return s.IsSet(); }
void TestMergeHandles(bool input_not_output);
static const int kVersion = 0; // used for graph-def version.
};
@@ -74,7 +75,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
{}, {}, {}, {});
{}, {}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
@@ -110,8 +111,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {},
{});
InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@@ -126,7 +126,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
TF_ASSERT_OK(c.construction_status());
{
@@ -166,7 +166,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
// Error when no constant tensors were requested.
{
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
{}, {});
{});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
ShapeHandle h;
@@ -185,8 +185,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
Tensor input_t =
::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {},
{});
{S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
c->input_tensor(0); // get this one, but it's null - won't be in error.
@@ -208,7 +207,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
{
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
{nullptr, &input_t}, {}, {}, {});
{nullptr, &input_t}, {}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
ShapeHandle s;
@@ -231,8 +230,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
{
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {},
{});
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
ShapeHandle s;
@@ -255,7 +253,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}, {});
{Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@@ -296,8 +294,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
{});
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
@@ -311,7 +308,7 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
{Unknown(), S({1, -1, 3})}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -350,7 +347,7 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
{Unknown(), S({1, -1, 3})}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -389,7 +386,7 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
{Unknown(), S({1, -1, 3})}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -427,8 +424,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {},
{});
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@@ -470,7 +466,7 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
{}, {}, {}, {});
{}, {}, {});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@@ -519,7 +515,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
{}, {}, {}, {});
{}, {}, {});
auto s_unknown = c.input(0);
auto s_1_2 = c.input(1);
@@ -595,7 +591,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
{
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
},
{}, {}, {}, {});
{}, {}, {});
auto s_unknown = c.input(0);
auto s_u_2 = c.input(1);
@@ -648,7 +644,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}, {});
{S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
ShapeHandle unknown = c.input(1);
ShapeHandle out;
@@ -723,7 +719,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -750,7 +746,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
{}, {}, {}, {});
{}, {}, {});
auto in = c.input(0);
auto unknown = c.input(1);
@@ -782,7 +778,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
{}, {}, {});
{}, {});
std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
@@ -807,7 +803,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@@ -819,7 +815,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@@ -830,7 +826,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@@ -846,7 +842,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@@ -869,7 +865,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
{}, {});
{});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
@@ -922,7 +918,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
{
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
{}, {}, {});
{}, {});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
@@ -932,7 +928,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
// With an unknown rank.
ShapeHandle out;
@@ -951,7 +947,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
ShapeHandle out;
TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
@@ -965,7 +961,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@@ -1001,7 +997,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@@ -1015,7 +1011,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@@ -1027,7 +1023,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@@ -1041,7 +1037,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
{&t1, &t2}, {}, {}, {});
{&t1, &t2}, {}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
@@ -1053,7 +1049,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
{&t1, &t2}, {}, {}, {});
{&t1, &t2}, {}, {});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@@ -1084,7 +1080,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@@ -1093,7 +1089,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
{}, {}, {});
{}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1156,7 +1152,7 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
{}, {});
{});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1208,7 +1204,7 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
{}, {}, {});
{}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1258,7 +1254,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
{}, {}, {});
{}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@@ -1311,7 +1307,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@@ -1325,7 +1321,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
{}, {}, {});
{}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1374,7 +1370,7 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
{}, {});
{});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@@ -1409,5 +1405,116 @@ TEST_F(ShapeInferenceTest, Max) {
EXPECT_TRUE(SameHandle(d_2, out));
}
void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
{});
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
ShapeHandle s;
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
return s;
};
auto get_shapes_and_types_from_context = [&](int idx) {
if (input_not_output) {
return c.input_handle_shapes_and_types(idx);
} else {
return c.output_handle_shapes_and_types(idx);
}
};
auto merge_shapes_and_types_to_context =
[&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
if (input_not_output) {
return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
} else {
return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
}
};
EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
// First merge will take the input completely.
std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
{c.UnknownShape(), DT_INVALID},
{make_shape({4, 3, 2, 1}), DT_INT32}};
ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
ASSERT_EQ(3, v.size());
for (int i = 0; i < v.size(); ++i) {
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
EXPECT_EQ(t[i].dtype, v[i].dtype);
}
// Merge that fails because wrong number of values passed.
// Fails, and no changes made.
ASSERT_FALSE(merge_shapes_and_types_to_context(
0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
v = *get_shapes_and_types_from_context(0);
ASSERT_EQ(3, v.size());
for (int i = 0; i < v.size(); ++i) {
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
EXPECT_EQ(t[i].dtype, v[i].dtype);
}
// Only difference is in a mismatched shape. That is ignored,
// and there are no other changes, so nothing is done.
//
// TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
// return an error (separate error from 'refined' output)?
auto t2 = t;
t2[2].shape = make_shape({4, 3, 4, 1});
ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
v = *get_shapes_and_types_from_context(0);
ASSERT_EQ(3, v.size());
for (int i = 0; i < v.size(); ++i) {
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
EXPECT_EQ(t[i].dtype, v[i].dtype);
}
// Only difference is in a mismatched dtype. That is ignored,
// and there are no other changes, so nothing is done.
t2 = t;
t2[2].dtype = DT_FLOAT;
ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
v = *get_shapes_and_types_from_context(0);
ASSERT_EQ(3, v.size());
for (int i = 0; i < v.size(); ++i) {
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
EXPECT_EQ(t[i].dtype, v[i].dtype);
}
// Difference is mergeable (new shape).
t[1].shape = make_shape({1, 10});
ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
v = *get_shapes_and_types_from_context(0);
ASSERT_EQ(3, v.size());
for (int i = 0; i < v.size(); ++i) {
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
EXPECT_EQ(t[i].dtype, v[i].dtype);
}
// Difference is mergeable (new type).
t[1].dtype = DT_DOUBLE;
ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
v = *get_shapes_and_types_from_context(0);
ASSERT_EQ(3, v.size());
for (int i = 0; i < v.size(); ++i) {
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
EXPECT_EQ(t[i].dtype, v[i].dtype);
}
// No difference.
ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
}
TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
TestMergeHandles(true);
}
TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
TestMergeHandles(false);
}
} // namespace shape_inference
} // namespace tensorflow

View File

@@ -45,7 +45,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
shape_inference::InferenceContext c(op.graph_def_version, &op.node_def,
op_reg_data->op_def, in_shapes,
op.input_tensors, {}, {}, {});
op.input_tensors, {}, {});
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(

View File

@@ -26,6 +26,38 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
namespace {
// Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
Status MergeEnqueueShapesAndTypes(
const std::vector<ShapeAndType>& shapes_and_types, InferenceContext* qctx,
std::vector<ShapeAndType>* queue_shapes_and_types) {
if (shapes_and_types.size() != queue_shapes_and_types->size()) {
return errors::InvalidArgument(
"Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
" vs ", queue_shapes_and_types->size());
}
for (int i = 0; i < shapes_and_types.size(); ++i) {
const ShapeAndType& a = shapes_and_types[i];
ShapeAndType& b = (*queue_shapes_and_types)[i];
if (a.dtype != b.dtype) {
return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
i, ": ", DataTypeString(a.dtype), " vs ",
DataTypeString(b.dtype));
}
TF_RETURN_IF_ERROR(qctx->Merge(a.shape, b.shape, &b.shape));
}
return Status::OK();
}
} // namespace
Status GraphProperties::InferStatically() {
Graph graph(OpRegistry::Global());
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
@@ -60,32 +92,49 @@ Status GraphProperties::InferStatically() {
if (!qctx) {
continue;
}
DataType queue_type = qctx->output_handle_dtype(0);
shape_inference::ShapeHandle queue_shp = qctx->output_handle_shape(0);
if (qctx->FullyDefined(queue_shp) && queue_type != DT_INVALID) {
continue;
// Check to see if the shape is fully defined.
auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
if (queue_handle_data != nullptr) {
bool fully_defined = true;
for (const auto& shape_and_type : *queue_handle_data) {
if (!qctx->FullyDefined(shape_and_type.shape) ||
shape_and_type.dtype == DT_INVALID) {
fully_defined = false;
}
}
if (fully_defined) {
continue;
}
}
std::vector<ShapeAndType> queue_shapes_and_types;
if (queue_handle_data != nullptr) {
queue_shapes_and_types = *queue_handle_data;
}
for (const auto& node : resource_data.second) {
auto ctx = shape_refiner.GetContext(node);
if (!ctx) {
continue;
}
if (node->type_string().find("Enqueue") != std::string::npos) {
if (ctx->num_inputs() == 2) {
const DataType dtype = node->input_type(1);
if (queue_type == DT_INVALID) {
queue_type = dtype;
} else {
CHECK_EQ(queue_type, dtype);
}
shape_inference::ShapeHandle shp = ctx->input(1);
TF_RETURN_IF_ERROR(qctx->Merge(queue_shp, shp, &queue_shp));
// TODO(bsteiner): handle EnqueueMany as well.
if (node->type_string().find("Enqueue") != std::string::npos &&
node->type_string().find("EnqueueMany") == std::string::npos) {
std::vector<ShapeAndType> shapes_and_types;
for (int i = 1; i < ctx->num_inputs(); ++i) {
shapes_and_types.push_back({ctx->input(i), node->input_type(i)});
}
if (queue_shapes_and_types.empty()) {
queue_shapes_and_types = shapes_and_types;
} else {
TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes(
shapes_and_types, qctx, &queue_shapes_and_types));
}
}
}
if (qctx->set_output_handle_dtype(0, queue_type) |
qctx->MergeOutputHandleShape(0, queue_shp)) {
if (!queue_shapes_and_types.empty() &&
qctx->MergeOutputHandleShapesAndTypes(0, queue_shapes_and_types)) {
new_shapes.push(qnode);
}
}
@@ -115,7 +164,7 @@ Status GraphProperties::InferStatically() {
for (int i = 0; i < ctx->num_inputs(); ++i) {
OpInfo::TensorProperties properties;
properties.set_dtype(node->input_type(i));
shape_inference::ShapeHandle shp = ctx->input(i);
ShapeHandle shp = ctx->input(i);
if (!ctx->RankKnown(shp)) {
properties.mutable_shape()->set_unknown_rank(true);
} else {
@@ -135,7 +184,7 @@ Status GraphProperties::InferStatically() {
for (int i = 0; i < ctx->num_outputs(); ++i) {
OpInfo::TensorProperties properties;
properties.set_dtype(node->output_type(i));
shape_inference::ShapeHandle shp = ctx->output(i);
ShapeHandle shp = ctx->output(i);
if (!ctx->RankKnown(shp)) {
properties.mutable_shape()->set_unknown_rank(true);
} else {
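
The net effect of the grappler changes above: every tensor fed to a queue's Enqueue nodes (inputs 1..n; input 0 is the queue handle) is folded into one ShapeAndType vector and merged onto the queue's resource output, so a multi-component dequeue can recover all of the shapes. Schematically, per enqueue node, a paraphrase of the loop above rather than a new API:

std::vector<ShapeAndType> shapes_and_types;
for (int i = 1; i < ctx->num_inputs(); ++i) {  // skip input 0, the queue handle
  shapes_and_types.push_back({ctx->input(i), node->input_type(i)});
}
// Merged across all enqueue nodes and stored on the queue's output handle;
// a refinement re-queues the queue node for another inference pass.
if (qctx->MergeOutputHandleShapesAndTypes(0, shapes_and_types)) {
  new_shapes.push(qnode);
}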

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -38,6 +39,22 @@
void TearDown() override { cluster_.reset(); }
protected:
// Returns a string form of <p>, suitable for comparing type and shape.
// Example output for 4-d float tensor: "float: [10,2,30,4]"
string PropToString(const OpInfo::TensorProperties& p) {
string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
if (p.shape().unknown_rank()) {
strings::StrAppend(&s, "?");
} else {
strings::StrAppend(&s, "[");
for (int i = 0; i < p.shape().dim_size(); ++i) {
strings::StrAppend(&s, i == 0 ? "" : ",", p.shape().dim(i).size());
}
strings::StrAppend(&s, "]");
}
return s;
}
std::unique_ptr<SingleMachine> cluster_;
};
@@ -194,6 +211,20 @@ TEST_F(GraphPropertiesTest, Queues) {
auto dequeue4 =
ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
// Create a queue that takes in three tensors.
auto q5 = ops::RandomShuffleQueue(
root.WithOpName("Queue5"),
{DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
Output rnd2 =
ops::RandomNormal(root.WithOpName("rnd"), {10}, DataType::DT_DOUBLE);
Output rnd3 =
ops::RandomNormal(root.WithOpName("rnd"), {1, 2, 3}, DataType::DT_FLOAT);
auto enqueue5 =
ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
auto dequeue5 = ops::QueueDequeue(
root.WithOpName("Dequeue5"), q5,
{DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
GrapplerItem item;
TF_CHECK_OK(root.ToGraphDef(&item.graph));
@@ -201,34 +232,31 @@
TF_CHECK_OK(properties.InferStatically());
const auto props1 = properties.GetOutputProperties("Dequeue1");
EXPECT_EQ(1, props1.size());
const OpInfo::TensorProperties& prop1 = props1[0];
EXPECT_EQ(DT_FLOAT, prop1.dtype());
EXPECT_FALSE(prop1.shape().unknown_rank());
EXPECT_EQ(2, prop1.shape().dim_size());
EXPECT_EQ(3, prop1.shape().dim(0).size());
EXPECT_EQ(7, prop1.shape().dim(1).size());
ASSERT_EQ(1, props1.size());
EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
const auto props2 = properties.GetOutputProperties("Dequeue2");
EXPECT_EQ(1, props2.size());
const OpInfo::TensorProperties& prop2 = props2[0];
EXPECT_EQ(DT_FLOAT, prop2.dtype());
EXPECT_FALSE(prop2.shape().unknown_rank());
EXPECT_EQ(2, prop2.shape().dim_size());
EXPECT_EQ(3, prop2.shape().dim(0).size());
EXPECT_EQ(7, prop2.shape().dim(1).size());
ASSERT_EQ(1, props2.size());
EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
// The dequeue3 op shape is unknown.
const auto props3 = properties.GetOutputProperties("Dequeue3");
ASSERT_EQ(1, props3.size());
EXPECT_EQ("float: ?", PropToString(props3[0]));
// The dequeue3 op shape is unknown. The square2 op shape is known. Verify
// that we merge the two properly to determine the shape of the data coming
// out of the queue.
const auto props4 = properties.GetOutputProperties("Dequeue4");
EXPECT_EQ(1, props4.size());
const OpInfo::TensorProperties& prop4 = props4[0];
EXPECT_EQ(DT_FLOAT, prop4.dtype());
EXPECT_FALSE(prop4.shape().unknown_rank());
EXPECT_EQ(2, prop4.shape().dim_size());
EXPECT_EQ(3, prop4.shape().dim(0).size());
EXPECT_EQ(7, prop4.shape().dim(1).size());
ASSERT_EQ(1, props4.size());
EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
// The dequeue5 op shape is known.
const auto props5 = properties.GetOutputProperties("Dequeue5");
ASSERT_EQ(3, props5.size());
EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
EXPECT_EQ("double: [10]", PropToString(props5[1]));
EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
}
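Concretely, the Dequeue5 checks only pass if the queue's resource handle carries one recorded entry per component. In the model notation from the earlier sketch (illustrative only, not grappler's actual data structure), the expected record is:

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// One (shape, dtype) entry per Queue5 component, matching the Dequeue5
// expectations in the test above.
const std::vector<std::pair<std::vector<int64_t>, std::string>>
    kExpectedQueue5HandleData = {
        {{3, 7}, "float"}, {{10}, "double"}, {{1, 2, 3}, "float"}};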
TEST_F(GraphPropertiesTest, Loops) {

View File

@@ -1534,8 +1534,10 @@ REGISTER_OP("Identity")
.Attr("T: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
c->set_output_handle_shape(0, c->input_handle_shape(0));
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
}
return Status::OK();
})
.Doc(R"Doc(
@@ -1551,8 +1553,10 @@ REGISTER_OP("_MklIdentity")
.Attr("T: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
c->set_output_handle_shape(0, c->input_handle_shape(0));
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
}
return Status::OK();
})
.Doc(R"Doc( Mkl implementation of IdentityOp

View File

@@ -162,9 +162,15 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
// Check that handle dtypes are preserved.
const OpRegistrationData* op_reg_data;
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
std::vector<
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
handle_data;
handle_data.emplace_back(
new std::vector<std::pair<TensorShapeProto, DataType>>{
{TensorShapeProto(), DT_BOOL}});
shape_inference::InferenceContext c(TF_GRAPH_DEF_VERSION, &op.node_def,
op_reg_data->op_def, {TensorShapeProto()},
{}, {}, {}, {DT_BOOL});
{}, {}, handle_data);
TF_ASSERT_OK(c.construction_status());
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));

View File

@@ -32,10 +32,11 @@ Status SwitchShape(InferenceContext* c) {
c->set_output(1, out);
// Handle resource shape / dtype.
c->set_output_handle_shape(0, c->input_handle_shape(0));
c->set_output_handle_shape(1, c->input_handle_shape(0));
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
c->set_output_handle_dtype(1, c->input_handle_dtype(0));
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
c->set_output_handle_shapes_and_types(1, *handle_data);
}
return Status::OK();
}
} // namespace
@@ -200,8 +201,10 @@ REGISTER_OP("Enter")
c->set_output(0, c->UnknownShape());
// Handle resource shape / dtype, if present.
c->set_output_handle_shape(0, c->input_handle_shape(0));
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
}
return Status::OK();
})

View File

@@ -644,15 +644,15 @@ REGISTER_OP("QueueDequeueV2")
.Attr("component_types: list(type) >= 1")
.Attr("timeout_ms: int = -1")
.SetShapeFn([](InferenceContext* c) {
if (c->num_outputs() == 1) {
c->set_output(0, c->input_handle_shape(0));
} else {
// TODO(vrv): handle the case of multiple outputs.
auto* t = c->input_handle_shapes_and_types(0);
if (t != nullptr && t->size() == c->num_outputs()) {
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->UnknownShape());
c->set_output(i, (*t)[i].shape);
}
return Status::OK();
} else {
return shape_inference::UnknownShape(c);
}
return Status::OK();
})
.Doc(R"doc(
Dequeues a tuple of one or more tensors from the given queue.
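The branch above only trusts the handle data when the number of recorded components matches the op's output count; otherwise every output stays unknown. A stand-alone model of that decision (plain types, not the real API):

#include <optional>
#include <vector>

// std::nullopt plays the role of an unknown shape.
using ShapeModel = std::optional<std::vector<long long>>;

// Use the per-component shapes recorded on the queue's handle only when the
// count matches this dequeue op's output count.
std::vector<ShapeModel> DequeueOutputShapes(
    const std::vector<ShapeModel>* recorded, int num_outputs) {
  std::vector<ShapeModel> out(num_outputs, std::nullopt);
  if (recorded != nullptr &&
      static_cast<int>(recorded->size()) == num_outputs) {
    for (int i = 0; i < num_outputs; ++i) out[i] = (*recorded)[i];
  }
  return out;
}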

View File

@@ -927,18 +927,33 @@ REGISTER_OP("Select")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
auto* handle_data_1 = c->input_handle_shapes_and_types(1);
auto* handle_data_2 = c->input_handle_shapes_and_types(2);
// Merge handle shape and dtype if applicable.
if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) {
// TODO(apassos) resolve this in the manner of b/32476923
return errors::InvalidArgument(
"Trying to merge handles pointing to different dtypes.");
if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
const auto size = handle_data_1->size();
std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
if (size != handle_data_2->size()) {
return errors::InvalidArgument(
"Trying to merge handles pointing to different numbers of "
"tensors.");
}
for (int i = 0; i < size; ++i) {
const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
if (s1.dtype != s2.dtype) {
// TODO(apassos) resolve this in the manner of b/32476923
return errors::InvalidArgument(
"Trying to merge handles pointing to different dtypes.");
}
merged_handle_data[i].dtype = s1.dtype;
TF_RETURN_IF_ERROR(
c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
}
c->set_output_handle_shapes_and_types(0, merged_handle_data);
}
c->set_output_handle_dtype(0, c->input_handle_dtype(1));
ShapeHandle output_handle_shape;
TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1),
c->input_handle_shape(2),
&output_handle_shape));
c->set_output_handle_shape(0, output_handle_shape);
// The inputs 'then' and 'else' must have the same shape.
ShapeHandle data = c->input(1);
@@ -2074,7 +2089,6 @@ REGISTER_OP("Bincount")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->UnknownShapeOfRank(1));
c->set_output_handle_dtype(0, c->input_handle_dtype(2));
return Status::OK();
})
.Doc(R"doc(

View File

@@ -188,36 +188,70 @@ TEST(MathOpsTest, Select_ShapeFn) {
INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
"[2,?,5];[?,?,3];[?,2,?]");
// Test that handle shapes were merged.
// Test that handles were merged.
//
// Tests below will modify handle_data and call run_inference_for_handles to
// rerun shape inference, updating the context <c>.
const OpRegistrationData* op_reg_data;
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
TensorShapeProto i0;
i0.add_dim()->set_size(1);
i0.add_dim()->set_size(-1);
TensorShapeProto i1;
i1.add_dim()->set_size(-1);
i1.add_dim()->set_size(2);
typedef std::vector<std::pair<TensorShapeProto, DataType>> ShapeDtypeV;
std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
std::unique_ptr<shape_inference::InferenceContext> c;
Status run_status;
auto run_inference_for_handles = [&]() -> Status {
CHECK(op_reg_data->shape_inference_fn != nullptr);
c.reset(new shape_inference::InferenceContext(
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
handle_data));
TF_CHECK_OK(c->construction_status());
Status s = c->Run(op_reg_data->shape_inference_fn);
LOG(INFO) << "Inference got " << s;
return s;
};
auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
TensorShapeProto p;
for (auto i : dim_sizes) p.add_dim()->set_size(i);
return p;
};
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
shape_inference::InferenceContext c(
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
{TensorShapeProto(), i0, i1}, {});
TF_ASSERT_OK(c.construction_status());
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0)));
EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0)));
TensorShapeProto i0 = shape_proto({1, -1});
TensorShapeProto i1 = shape_proto({-1, 2});
TensorShapeProto unknown_shape;
unknown_shape.set_unknown_rank(true);
TensorShapeProto scalar;
handle_data.emplace_back(
new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}});
handle_data.emplace_back(
new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});
TF_ASSERT_OK(run_inference_for_handles());
auto* out = c->output_handle_shapes_and_types(0);
ASSERT_EQ(2, out->size());
EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
EXPECT_EQ(DT_INT32, out->at(1).dtype);
// Expect an error when the shapes can't be merged.
TensorShapeProto i2;
i1.add_dim()->set_size(2);
i1.add_dim()->set_size(2);
shape_inference::InferenceContext c2(
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
{TensorShapeProto(), i0, i2}, {});
TF_ASSERT_OK(c.construction_status());
EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok());
handle_data[2]->at(0).first = shape_proto({2, 2});
EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
.contains("must be equal, but are 1 and 2"));
handle_data[2]->at(0).first = i1; // restore to valid
// Expect an error when the types can't be merged.
handle_data[2]->at(1).second = DT_INT64;
EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
.contains("pointing to different dtypes"));
handle_data[2]->at(1).second = DT_INT32; // restore to valid
// Expect an error when different numbers of tensors are merged.
handle_data[2]->push_back({i1, DT_FLOAT});
EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
.contains("pointing to different numbers of tensors"));
handle_data[2]->pop_back(); // restore to valid.
}
TEST(MathOpsTest, Range_ShapeFn) {

View File

@@ -20,10 +20,43 @@
#include "tensorflow/core/framework/shape_inference.h"
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
using ::tensorflow::shape_inference::ShapeHandle;
namespace tensorflow {
namespace {
Status ValidateVariableResourceHandle(InferenceContext* c,
ShapeAndType* shape_and_type) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
shape_and_type->shape = c->UnknownShape();
shape_and_type->dtype = DT_INVALID;
} else {
*shape_and_type = (*handle_data)[0];
DataType value_dtype;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
if (shape_and_type->dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to read variable with wrong dtype. "
"Expected ",
DataTypeString(shape_and_type->dtype), " got ",
DataTypeString(value_dtype));
}
}
return Status::OK();
}
Status ReadVariableShapeFn(InferenceContext* c) {
ShapeAndType shape_and_type;
TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &shape_and_type));
c->set_output(0, shape_and_type.shape);
return Status::OK();
}
} // namespace
REGISTER_OP("VarHandleOp")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
@@ -31,16 +64,17 @@ REGISTER_OP("VarHandleOp")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
c->set_output_handle_dtype(0, t);
TensorShapeProto p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s));
c->set_output_handle_shape(0, s);
c->set_output_handle_shapes_and_types(0,
std::vector<ShapeAndType>{{s, t}});
return Status::OK();
})
.Doc(R"(
@@ -57,19 +91,7 @@ REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn([](InferenceContext* c) {
DataType handle_dtype = c->input_handle_dtype(0);
DataType value_dtype;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
if (handle_dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to read variable with wrong dtype. "
"Expected ",
DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
}
c->set_output(0, c->input_handle_shape(0));
return Status::OK();
})
.SetShapeFn(ReadVariableShapeFn)
.Doc(R"(
Reads the value of a variable.
@@ -88,19 +110,7 @@ REGISTER_OP("_UnsafeReadVariable")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn([](InferenceContext* c) {
DataType handle_dtype = c->input_handle_dtype(0);
DataType value_dtype;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
if (handle_dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to read variable with wrong dtype. "
"Expected ",
DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
}
c->set_output(0, c->input_handle_shape(0));
return Status::OK();
})
.SetShapeFn(ReadVariableShapeFn)
.Doc(R"(
Reads the value of a variable without any memory model.
@@ -130,19 +140,13 @@ ignore_lookup_error: whether to ignore the error when the resource
)");
Status CreateAssignShapeFn(InferenceContext* c) {
DataType handle_dtype = c->input_handle_dtype(0);
DataType value_dtype;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
if (handle_dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to initialize handle for variable with wrong dtype. "
"Expected ",
DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
}
ShapeHandle s = c->input_handle_shape(0);
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
ShapeHandle value_shape = c->input(1);
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
TF_RETURN_IF_ERROR(
c->Merge(handle_shape_and_type.shape, value_shape, &unused));
return Status::OK();
}
@@ -220,18 +224,16 @@ REGISTER_OP("ResourceGather")
.Attr("dtype: type")
.Attr("Tindices: {int32,int64}")
.SetShapeFn([](InferenceContext* c) {
DataType dtype;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &dtype));
if (c->input_handle_dtype(0) != dtype) {
return errors::InvalidArgument(
"Trying to gather from a variable with the wrong dtype.");
}
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(
ValidateVariableResourceHandle(c, &handle_shape_and_type));
ShapeHandle unused;
TF_RETURN_IF_ERROR(
c->WithRankAtLeast(c->input_handle_shape(0), 1, &unused));
c->WithRankAtLeast(handle_shape_and_type.shape, 1, &unused));
ShapeHandle params_subshape;
TF_RETURN_IF_ERROR(
c->Subshape(c->input_handle_shape(0), 1, &params_subshape));
c->Subshape(handle_shape_and_type.shape, 1, &params_subshape));
ShapeHandle indices_shape = c->input(1);
ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
@@ -264,7 +266,10 @@ REGISTER_OP("ResourceScatterAdd")
.Attr("dtype: numbertype")
.Attr("Tindices: {int32, int64}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle var_shape = c->input_handle_shape(0);
ShapeAndType handle_shape_and_type;
TF_RETURN_IF_ERROR(
ValidateVariableResourceHandle(c, &handle_shape_and_type));
ShapeHandle var_shape = handle_shape_and_type.shape;
ShapeHandle indices_shape = c->input(1);
ShapeHandle unused_updates_shape;
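ResourceGather above now derives its output shape from the handle's recorded variable shape: the indices shape concatenated with the params shape minus its leading dimension. A runnable model of just that computation (plain vectors instead of the Concatenate/Subshape calls):

#include <cassert>
#include <cstdint>
#include <vector>

// output = indices_shape ++ params_shape[1:]; params must have rank >= 1,
// mirroring the WithRankAtLeast check above.
std::vector<int64_t> GatherOutputShape(const std::vector<int64_t>& params,
                                       const std::vector<int64_t>& indices) {
  std::vector<int64_t> out(indices);
  out.insert(out.end(), params.begin() + 1, params.end());
  return out;
}

int main() {
  // Matches the python test later in this commit: gathering 3 indices from a
  // [10, 20, 35] variable yields [3, 20, 35].
  assert((GatherOutputShape({10, 20, 35}, {3}) ==
          std::vector<int64_t>{3, 20, 35}));
  return 0;
}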

View File

@@ -23,11 +23,12 @@ using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
auto h_dtype = c->input_handle_dtype(input);
if (h_dtype == DT_INVALID) {
return c->input(input);
auto* handle_data = c->input_handle_shapes_and_types(input);
if (handle_data != nullptr && !handle_data->empty() &&
(*handle_data)[0].dtype != DT_INVALID) {
return (*handle_data)[0].shape;
}
return c->input_handle_shape(input);
return c->input(input);
}
// Handle the gradient and, if <sparse>, indices inputs.
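The updated ShapeOrHandleShape prefers the first recorded handle entry and falls back to the dense input's own shape only when no usable entry exists. A small model of that fallback (assumed stand-in types; "invalid" plays the role of DT_INVALID):

#include <string>
#include <vector>

struct HandleEntryModel {
  std::vector<long long> shape;
  std::string dtype;
};

// First valid handle entry wins; otherwise use the tensor's own shape.
std::vector<long long> ShapeOrHandleShapeModel(
    const std::vector<HandleEntryModel>* handle_data,
    const std::vector<long long>& dense_shape) {
  if (handle_data != nullptr && !handle_data->empty() &&
      (*handle_data)[0].dtype != "invalid") {
    return (*handle_data)[0].shape;
  }
  return dense_shape;
}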

View File

@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
import six.moves
from tensorflow.core.framework import types_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
@@ -597,8 +596,7 @@ def call_cpp_shape_fn(op,
# calls the C / C-API directly, we should be able to remove this.
return {
"shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)],
"handle_shapes": [tensor_shape.TensorShape(None).as_proto()],
"handle_dtypes": [types_pb2.DT_INVALID]
"handle_data": [None]
}
input_tensors_needed = input_tensors_needed or []
@@ -642,8 +640,8 @@ def _call_cpp_shape_fn_impl(
r = cpp_shape_inference_pb2.CppShapeInferenceResult()
r.shape.CopyFrom(t.get_shape().as_proto())
# pylint: disable=protected-access
r.handle_shape.CopyFrom(t._handle_shape)
r.handle_dtype = t._handle_dtype
if t._handle_data is not None:
r.handle_data.CopyFrom(t._handle_data)
# pylint: enable=protected-access
return r.SerializeToString()
input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
@@ -689,8 +687,9 @@ def _call_cpp_shape_fn_impl(
for s in output_shapes
]
result = [r.shape for r in result_protos]
result_handle_shapes = [r.handle_shape for r in result_protos]
result_handle_dtypes = [r.handle_dtype for r in result_protos]
result_handle_data = [
r.handle_data if r.handle_data.is_set else None for r in result_protos
]
if debug_python_shape_fn:
try:
@@ -708,10 +707,11 @@ def _call_cpp_shape_fn_impl(
str(result_as_shapes), str(python_result), str(op.node_def),
",".join([str(i.get_shape()) for i in op.inputs])))
return {"shapes": result,
"handle_shapes": result_handle_shapes,
"handle_dtypes": result_handle_dtypes,
"inputs_needed": output[-1]}
return {
"shapes": result,
"handle_data": result_handle_data,
"inputs_needed": output[-1]
}
# pylint: disable=protected-access
ops._set_call_cpp_shape_fn(call_cpp_shape_fn)

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_func.h"
namespace tensorflow {
namespace swig {
namespace {
@@ -71,11 +72,11 @@ Status RunCppShapeInferenceImpl(
// Convert input shapes.
std::vector<TensorShapeProto> input_shapes;
std::vector<TensorShapeProto> input_handle_shapes;
std::vector<DataType> input_handle_dtypes;
std::vector<
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
input_handle_shapes_and_types;
input_shapes.resize(input_serialized_shapes.size());
input_handle_shapes.resize(input_serialized_shapes.size());
input_handle_dtypes.resize(input_serialized_shapes.size());
input_handle_shapes_and_types.resize(input_serialized_shapes.size());
CppShapeInferenceResult tmp;
for (int i = 0; i < input_serialized_shapes.size(); ++i) {
tmp.Clear();
@@ -83,9 +84,17 @@ Status RunCppShapeInferenceImpl(
return errors::InvalidArgument(
"Error parsing shape proto during cpp shape inference");
}
input_shapes[i].Swap(tmp.mutable_shape());
input_handle_dtypes[i] = tmp.handle_dtype();
input_handle_shapes[i].Swap(tmp.mutable_handle_shape());
if (tmp.handle_data().is_set()) {
input_handle_shapes_and_types[i].reset(
new std::vector<std::pair<TensorShapeProto, DataType>>);
auto& v = *input_handle_shapes_and_types[i];
for (const auto& x : tmp.handle_data().shape_and_type()) {
v.emplace_back(x.shape(), x.dtype());
}
}
}
// Convert input tensor values;
@@ -116,8 +125,8 @@ Status RunCppShapeInferenceImpl(
// Run shape inference.
tensorflow::shape_inference::InferenceContext c(
graph_def_version, &node, op_reg_data->op_def, input_shapes,
input_tensors, input_tensor_as_shapes_protos, input_handle_shapes,
input_handle_dtypes);
input_tensors, input_tensor_as_shapes_protos,
input_handle_shapes_and_types);
TF_RETURN_IF_ERROR(c.construction_status());
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
@@ -128,9 +137,18 @@ Status RunCppShapeInferenceImpl(
for (int i = 0; i < c.num_outputs(); ++i) {
out.Clear();
ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
ProtoFromShapeHandle(c.output_handle_shape(i), &c,
out.mutable_handle_shape());
out.set_handle_dtype(c.output_handle_dtype(i));
const auto* shapes_and_types = c.output_handle_shapes_and_types(i);
if (shapes_and_types != nullptr) {
auto* out_handle_data = out.mutable_handle_data();
out_handle_data->set_is_set(true);
for (const auto& p : *shapes_and_types) {
auto* out_shape_and_type = out_handle_data->add_shape_and_type();
ProtoFromShapeHandle(p.shape, &c, out_shape_and_type->mutable_shape());
out_shape_and_type->set_dtype(p.dtype);
}
}
CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
}
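The loop above writes each output's handle data into the new HandleData message (defined in the proto file that follows). For reference, a producer fills it roughly like this; the generated header path and the tensorflow package are assumptions based on where the proto lives in this commit:

#include "tensorflow/python/framework/cpp_shape_inference.pb.h"

void FillExampleHandleData(tensorflow::CppShapeInferenceResult* r) {
  auto* handle_data = r->mutable_handle_data();
  handle_data->set_is_set(true);  // consumers treat it as absent otherwise
  auto* shape_and_type = handle_data->add_shape_and_type();
  shape_and_type->mutable_shape()->add_dim()->set_size(10);
  shape_and_type->set_dtype(tensorflow::DT_FLOAT);
}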

View File

@@ -7,9 +7,21 @@ import "tensorflow/core/framework/types.proto";
import "tensorflow/core/framework/tensor_shape.proto";
message CppShapeInferenceResult {
message HandleShapeAndType {
TensorShapeProto shape = 1;
DataType dtype = 2;
}
message HandleData {
bool is_set = 1;
// Only valid if <is_set>.
repeated HandleShapeAndType shape_and_type = 2;
}
TensorShapeProto shape = 1;
TensorShapeProto handle_shape = 2;
DataType handle_dtype = 3;
reserved 2; // was handle_shape
reserved 3; // was handle_dtype
HandleData handle_data = 4;
}
message CppShapeInferenceInputsNeeded {

View File

@@ -351,8 +351,7 @@ class _FuncGraph(ops.Graph):
self.extra_inputs.append(x)
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
# pylint: disable=protected-access
ph._handle_shape = x._handle_shape
ph._handle_dtype = x._handle_dtype
ph._handle_data = x._handle_data
# pylint: enable=protected-access
inputs[i] = ph
self._captured[x] = ph

View File

@@ -31,8 +31,6 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import device as pydev
@@ -323,8 +321,8 @@ class Tensor(_TensorLike):
self._consumers = []
# Attributes used for C++ shape inference. Not inspected, only forwarded.
self._handle_shape = tensor_shape_pb2.TensorShapeProto()
self._handle_dtype = types_pb2.DT_INVALID
# If set, will be a HandleData object from cpp_shape_inference.proto.
self._handle_data = None
@property
def op(self):
@@ -1877,12 +1875,10 @@ def set_shapes_for_outputs(op):
# Returned by call_cpp_shape_fn
shapes_dict = shapes
shapes = shapes_dict["shapes"]
handle_shapes = shapes_dict["handle_shapes"]
handle_dtypes = shapes_dict["handle_dtypes"]
for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes):
handle_datas = shapes_dict["handle_data"]
for output, handle_data in zip(op.outputs, handle_datas):
# pylint: disable=protected-access
output._handle_shape = handle_shape
output._handle_dtype = handle_dtype
output._handle_data = handle_data
# pylint: enable=protected-access
if len(op.outputs) != len(shapes):

View File

@@ -237,6 +237,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
_ = x_read.eval()
def testShape(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
self.assertEqual("(10, 20, 35)", str(v.value().shape))
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
self.assertEqual(
"<unknown>",
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
def testSetInitialValue(self):
with self.test_session():
# Initialize variable with a value different from the initial value passed