Change shape inference so that a single resource tensor can carry
shape and type information for multiple tensors. Apply this to QueueDequeueV2 handled by grappler. PiperOrigin-RevId: 157163757
This commit is contained in:
parent
8f0d0bdca8
commit
eae7e833f1
@ -71,8 +71,7 @@ class _ExperimentalFuncGraph(function._FuncGraph):
|
||||
self.extra_inputs.append(x)
|
||||
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
|
||||
# pylint: disable=protected-access
|
||||
ph._handle_shape = x._handle_shape
|
||||
ph._handle_dtype = x._handle_dtype
|
||||
ph._handle_data = x._handle_data
|
||||
# pylint: enable=protected-access
|
||||
inputs[i] = ph
|
||||
self._captured[x] = ph
|
||||
|
@ -34,8 +34,6 @@ REGISTER_OP("ImageProjectiveTransform")
|
||||
.Output("transformed_images: dtype")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
|
@ -29,6 +29,7 @@ namespace tensorflow {
|
||||
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeAndType;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
ShapeRefiner::ShapeRefiner(int graph_def_version,
|
||||
@ -49,8 +50,8 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
// indexed by 'node's input.
|
||||
std::vector<Node*> input_nodes(node->num_inputs());
|
||||
std::vector<ShapeHandle> input_shapes(node->num_inputs());
|
||||
std::vector<DataType> input_handle_dtypes(node->num_inputs());
|
||||
std::vector<ShapeHandle> input_handle_shapes(node->num_inputs());
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
input_handle_shapes_and_types(node->num_inputs());
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
|
||||
@ -67,13 +68,13 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
input_nodes[e->dst_input()] = input;
|
||||
input_shapes[e->dst_input()] = c->output(e->src_output());
|
||||
|
||||
// Only propagate handle xshape and dtype of edges which are carrying
|
||||
// resource handles.
|
||||
// Only propagate handle data of edges which are carrying resource handles.
|
||||
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
|
||||
input_handle_dtypes[e->dst_input()] =
|
||||
c->output_handle_dtype(e->src_output());
|
||||
input_handle_shapes[e->dst_input()] =
|
||||
c->output_handle_shape(e->src_output());
|
||||
const auto* in_v = c->output_handle_shapes_and_types(e->src_output());
|
||||
if (in_v != nullptr) {
|
||||
input_handle_shapes_and_types[e->dst_input()].reset(
|
||||
new std::vector<ShapeAndType>(*in_v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,7 +96,7 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
std::unique_ptr<InferenceContext> c(
|
||||
new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
|
||||
input_shapes, input_tensors, input_tensors_as_shapes,
|
||||
input_handle_shapes, input_handle_dtypes));
|
||||
std::move(input_handle_shapes_and_types)));
|
||||
if (!c->construction_status().ok()) {
|
||||
return c->construction_status();
|
||||
}
|
||||
@ -170,12 +171,11 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) {
|
||||
// Also propagate handle shape and dtype of edges which are carrying
|
||||
// resource handles.
|
||||
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
|
||||
if (node_context->set_input_handle_dtype(
|
||||
e->dst_input(), c->output_handle_dtype(e->src_output()))) {
|
||||
*refined = true;
|
||||
}
|
||||
if (node_context->MergeInputHandleShape(
|
||||
e->dst_input(), c->output_handle_shape(e->src_output()))) {
|
||||
auto* shapes_and_types =
|
||||
c->output_handle_shapes_and_types(e->src_output());
|
||||
if (shapes_and_types != nullptr &&
|
||||
node_context->MergeInputHandleShapesAndTypes(e->dst_input(),
|
||||
*shapes_and_types)) {
|
||||
*refined = true;
|
||||
}
|
||||
}
|
||||
|
@ -791,8 +791,8 @@ TEST(ShapeRefinerTest, IncrementalUpdates) {
|
||||
// Inject a shape, and incrementally propagate it to the dequeue op.
|
||||
ctx = m.GetContext(queue);
|
||||
shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7});
|
||||
ctx->set_output_handle_shape(0, shp);
|
||||
ctx->set_output_handle_dtype(0, DT_FLOAT);
|
||||
ctx->set_output_handle_shapes_and_types(
|
||||
0, std::vector<shape_inference::ShapeAndType>{{shp, DT_FLOAT}});
|
||||
|
||||
bool refined = false;
|
||||
TF_ASSERT_OK(m.UpdateNode(dequeue, &refined));
|
||||
|
@ -70,7 +70,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
TF_EXPECT_OK(NoOutputs(&c));
|
||||
EXPECT_EQ(0, c.num_outputs());
|
||||
}
|
||||
@ -88,8 +88,7 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
|
||||
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {},
|
||||
{});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
@ -97,7 +96,7 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
|
||||
{S({1, 23, 4, 4, 2})}, {}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
@ -125,7 +124,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 3}), S({3, 4})}, {}, {}, {}, {});
|
||||
{S({2, 3}), S({3, 4})}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -135,7 +134,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
{
|
||||
// Unknown inner dimension for one
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, -1}), S({3, 4})}, {}, {}, {}, {});
|
||||
{S({2, -1}), S({3, 4})}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -145,7 +144,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
{
|
||||
// Invalid rank.
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -156,7 +155,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
{
|
||||
// Unknown outer dimension
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 3}), S({3, -1})}, {}, {}, {}, {});
|
||||
{S({2, 3}), S({3, -1})}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -166,7 +165,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 5}), S({3, 4})}, {}, {}, {}, {});
|
||||
{S({2, 5}), S({3, 4})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -178,7 +177,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {}, {});
|
||||
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -197,7 +196,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({3, 2}), S({3, 4})}, {}, {}, {}, {});
|
||||
{S({3, 2}), S({3, 4})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -215,7 +214,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 3}), S({4, 3})}, {}, {}, {}, {});
|
||||
{S({2, 3}), S({4, 3})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -240,7 +239,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 10}), S({10})}, {}, {}, {}, {});
|
||||
{S({2, 10}), S({10})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -250,7 +249,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
{
|
||||
// Unknown ranks.
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{Unknown(), Unknown()}, {}, {}, {}, {});
|
||||
{Unknown(), Unknown()}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_FALSE(c.RankKnown(output));
|
||||
@ -259,7 +258,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {}, {});
|
||||
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
|
||||
@ -273,7 +272,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
|
||||
{S({2, 3, 4, 5}), S({3})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
|
||||
@ -287,7 +286,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
|
||||
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
|
||||
@ -301,7 +300,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({10, 11, 12}), S({10})}, {}, {}, {}, {});
|
||||
{S({10, 11, 12}), S({10})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[10,11,12]", c.DebugString(output));
|
||||
@ -310,7 +309,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
|
||||
@ -323,7 +322,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
}
|
||||
@ -343,7 +342,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
|
||||
{}, {});
|
||||
{});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -352,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -365,7 +364,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -378,7 +377,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {}, {});
|
||||
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -391,7 +390,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -399,7 +398,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {}, {},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
|
||||
{});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
@ -412,7 +411,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
|
||||
{}, {});
|
||||
{});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
}
|
||||
@ -832,7 +831,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{Unknown(), Unknown(), Unknown()}, {}, {}, {}, {});
|
||||
{Unknown(), Unknown(), Unknown()}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -845,7 +844,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {}, {});
|
||||
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -858,7 +857,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({-1}), S({-1}), S({-1})}, {}, {}, {}, {});
|
||||
{S({-1}), S({-1}), S({-1})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -872,7 +871,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({4}), S({3})}, {}, {}, {}, {});
|
||||
{S({5, 3}), S({4}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -886,7 +885,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({5}), S({4})}, {}, {}, {}, {});
|
||||
{S({5, 3}), S({5}), S({4})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -900,7 +899,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({-1, 3}), S({5}), S({3})}, {}, {}, {}, {});
|
||||
{S({-1, 3}), S({5}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -913,7 +912,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({-1}), S({3})}, {}, {}, {}, {});
|
||||
{S({5, 3}), S({-1}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -926,7 +925,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({5, -1}), S({5}), S({3})}, {}, {}, {}, {});
|
||||
{S({5, -1}), S({5}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -939,7 +938,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({5}), S({-1})}, {}, {}, {}, {});
|
||||
{S({5, 3}), S({5}), S({-1})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -952,7 +951,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({5}), S({3})}, {}, {}, {}, {});
|
||||
{S({5, 3}), S({5}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
|
@ -33,8 +33,9 @@ InferenceContext::InferenceContext(
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
|
||||
const std::vector<TensorShapeProto>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes)
|
||||
const std::vector<
|
||||
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
|
||||
input_handle_shapes_and_types)
|
||||
: graph_def_version_(graph_def_version),
|
||||
node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
std::vector<ShapeHandle> input_tensors_as_shape_handles;
|
||||
@ -56,16 +57,26 @@ InferenceContext::InferenceContext(
|
||||
}
|
||||
inputs_.push_back(shape);
|
||||
}
|
||||
std::vector<ShapeHandle> handle_shapes;
|
||||
for (const auto& p : input_handle_shapes) {
|
||||
ShapeHandle shape;
|
||||
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
|
||||
if (!construction_status_.ok()) {
|
||||
return;
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
|
||||
input_shapes.size());
|
||||
for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
|
||||
const auto& v = input_handle_shapes_and_types[i];
|
||||
if (v == nullptr) {
|
||||
continue;
|
||||
}
|
||||
handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
|
||||
auto& new_v = *handle_data[i];
|
||||
for (int j = 0; j < v->size(); ++j) {
|
||||
const auto& p = (*v)[j];
|
||||
construction_status_.Update(
|
||||
MakeShapeFromShapeProto(p.first, &new_v[j].shape));
|
||||
if (!construction_status_.ok()) {
|
||||
return;
|
||||
}
|
||||
new_v[j].dtype = p.second;
|
||||
}
|
||||
handle_shapes.push_back(shape);
|
||||
}
|
||||
PostInputInit(handle_shapes, input_handle_dtypes);
|
||||
PostInputInit(std::move(handle_data));
|
||||
}
|
||||
|
||||
InferenceContext::InferenceContext(
|
||||
@ -73,14 +84,15 @@ InferenceContext::InferenceContext(
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes)
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
input_handle_shapes_and_types)
|
||||
: graph_def_version_(graph_def_version),
|
||||
node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
|
||||
if (!construction_status_.ok()) return;
|
||||
inputs_ = input_shapes;
|
||||
PostInputInit(input_handle_shapes, input_handle_dtypes);
|
||||
|
||||
PostInputInit(std::move(input_handle_shapes_and_types));
|
||||
}
|
||||
|
||||
InferenceContext::~InferenceContext() {}
|
||||
@ -149,16 +161,11 @@ void InferenceContext::PreInputInit(
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
outputs_.push_back(nullptr);
|
||||
}
|
||||
output_handle_shape_.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
output_handle_shape_.push_back(UnknownShape());
|
||||
}
|
||||
output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
|
||||
output_handle_shapes_and_types_.resize(num_outputs);
|
||||
}
|
||||
|
||||
void InferenceContext::PostInputInit(
|
||||
const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes) {
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
|
||||
int num_inputs_from_node_def = 0;
|
||||
for (const auto& e : input_name_map_) {
|
||||
num_inputs_from_node_def =
|
||||
@ -166,25 +173,16 @@ void InferenceContext::PostInputInit(
|
||||
}
|
||||
|
||||
// Allow passing empty shapes/dtypes to avoid changing every single test.
|
||||
if (input_handle_shapes.empty()) {
|
||||
input_handle_shape_.resize(inputs_.size());
|
||||
if (input_handle_data.empty()) {
|
||||
input_handle_shapes_and_types_.resize(inputs_.size());
|
||||
} else {
|
||||
input_handle_shape_ = input_handle_shapes;
|
||||
if (input_handle_shape_.size() != inputs_.size()) {
|
||||
if (input_handle_data.size() != inputs_.size()) {
|
||||
construction_status_ = errors::InvalidArgument(
|
||||
"Wrong number of handle shapes passed; expected ", inputs_.size(),
|
||||
" got ", input_handle_shape_.size());
|
||||
}
|
||||
}
|
||||
if (input_handle_dtypes.empty()) {
|
||||
input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
|
||||
} else {
|
||||
input_handle_dtype_ = input_handle_dtypes;
|
||||
if (input_handle_dtype_.size() != inputs_.size()) {
|
||||
construction_status_ = errors::InvalidArgument(
|
||||
"Wrong number of handle dtypes passed; expected ", inputs_.size(),
|
||||
" got ", input_handle_dtype_.size());
|
||||
" got ", input_handle_data.size());
|
||||
return;
|
||||
}
|
||||
input_handle_shapes_and_types_ = std::move(input_handle_data);
|
||||
}
|
||||
|
||||
if (inputs_.size() != num_inputs_from_node_def) {
|
||||
@ -756,7 +754,7 @@ Status InferenceContext::Multiply(DimensionHandle first,
|
||||
} else if (first_value == kUnknownDim || second_value == kUnknownDim) {
|
||||
*out = UnknownDim();
|
||||
} else {
|
||||
// Invariant: Both values are known and and greater than 1.
|
||||
// Invariant: Both values are known and greater than 1.
|
||||
const int64 product = first_value * second_value;
|
||||
if (product < 0) {
|
||||
return errors::InvalidArgument(
|
||||
@ -847,10 +845,65 @@ Status InferenceContext::AttachContext(const Status& status) {
|
||||
}
|
||||
|
||||
ShapeHandle InferenceContext::input_handle_shape(int idx) {
|
||||
if (!input_handle_shape_[idx].IsSet()) {
|
||||
input_handle_shape_[idx] = UnknownShape();
|
||||
if (input_handle_shapes_and_types_[idx] == nullptr) {
|
||||
input_handle_shapes_and_types_[idx].reset(
|
||||
new std::vector<ShapeAndType>{{UnknownShape(), DT_INVALID}});
|
||||
}
|
||||
return input_handle_shape_[idx];
|
||||
return (*input_handle_shapes_and_types_[idx])[0].shape;
|
||||
}
|
||||
|
||||
bool InferenceContext::MergeHandleShapesAndTypes(
|
||||
const std::vector<ShapeAndType>& shapes_and_types,
|
||||
std::vector<ShapeAndType>* to_update) {
|
||||
if (shapes_and_types.size() != to_update->size()) {
|
||||
return false;
|
||||
}
|
||||
std::vector<ShapeAndType> new_values(shapes_and_types.size());
|
||||
bool refined = false;
|
||||
for (int i = 0; i < shapes_and_types.size(); ++i) {
|
||||
const ShapeAndType& existing = (*to_update)[i];
|
||||
new_values[i].dtype = shapes_and_types[i].dtype;
|
||||
if (new_values[i].dtype != existing.dtype && existing.dtype == DT_INVALID) {
|
||||
refined = true;
|
||||
}
|
||||
if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
|
||||
.ok()) {
|
||||
// merge failed, ignore the new value.
|
||||
new_values[i].shape = existing.shape;
|
||||
}
|
||||
if (!existing.shape.SameHandle(new_values[i].shape)) {
|
||||
refined = true;
|
||||
}
|
||||
}
|
||||
if (!refined) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < new_values.size(); ++i) {
|
||||
(*to_update)[i] = new_values[i];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InferenceContext::MergeOutputHandleShapesAndTypes(
|
||||
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
|
||||
if (output_handle_shapes_and_types_[idx] == nullptr) {
|
||||
output_handle_shapes_and_types_[idx].reset(
|
||||
new std::vector<ShapeAndType>(shapes_and_types));
|
||||
return true;
|
||||
}
|
||||
return MergeHandleShapesAndTypes(shapes_and_types,
|
||||
output_handle_shapes_and_types_[idx].get());
|
||||
}
|
||||
|
||||
bool InferenceContext::MergeInputHandleShapesAndTypes(
|
||||
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
|
||||
if (input_handle_shapes_and_types_[idx] == nullptr) {
|
||||
input_handle_shapes_and_types_[idx].reset(
|
||||
new std::vector<ShapeAndType>(shapes_and_types));
|
||||
return true;
|
||||
}
|
||||
return MergeHandleShapesAndTypes(shapes_and_types,
|
||||
input_handle_shapes_and_types_[idx].get());
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -121,6 +121,14 @@ struct DimensionOrConstant {
|
||||
DimensionOrConstant();
|
||||
};
|
||||
|
||||
struct ShapeAndType {
|
||||
ShapeAndType() {}
|
||||
ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
|
||||
|
||||
ShapeHandle shape;
|
||||
DataType dtype = DT_INVALID;
|
||||
};
|
||||
|
||||
// Shape inference functions registered on ops in REGISTER_OP implement
|
||||
// their shape functions in terms of this InferenceContext. An InferenceContext
|
||||
// is created by the framework and passed to a shape inference function. The
|
||||
@ -149,26 +157,28 @@ class InferenceContext {
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes);
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
input_handle_shapes_and_types);
|
||||
|
||||
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||
//
|
||||
// Elements of <input_tensors_as_shapes> are used for when a shape function
|
||||
// makes a call to MakeShapeFromShapeTensor; in particular, when the
|
||||
// input_tensors[i] is nullptr but the shape represented by it is partially
|
||||
// known from analysis of the graph.
|
||||
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
|
||||
// Values of <input_tensors_as_shapes> do not need to outlive the context.
|
||||
// Elements of <input_tensors_as_shapes> are used for when a shape
|
||||
// function makes a call to MakeShapeFromShapeTensor; in particular, when
|
||||
// the input_tensors[i] is nullptr but the shape represented by it is
|
||||
// partially known from analysis of the graph. <input_tensors_as_shapes>
|
||||
// can have fewer elements than <input_shapes>. Values of
|
||||
// <input_tensors_as_shapes> do not need to outlive the context.
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
|
||||
InferenceContext(int graph_def_version, const NodeDef* node_def,
|
||||
const OpDef& op_def,
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
|
||||
const std::vector<TensorShapeProto>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes);
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the
|
||||
// InferenceContext.
|
||||
InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<TensorShapeProto>& input_tensors_as_shapes,
|
||||
const std::vector<
|
||||
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
|
||||
input_handle_shapes_and_types);
|
||||
|
||||
~InferenceContext();
|
||||
|
||||
@ -441,70 +451,62 @@ class InferenceContext {
|
||||
// propagate that information. Output handle dtypes and shapes are ignored if
|
||||
// the output tensor is not of type DT_RESOURCE.
|
||||
|
||||
// Merge the stored shape corresponding to the input handle in position idx
|
||||
// with the specified shape. This requires idx to be in the [0, num_inputs)
|
||||
// range. If the merge is successful and the new shape differs from the old
|
||||
// one, store the new shape and return true. Return false otherwise.
|
||||
bool MergeInputHandleShape(int idx, ShapeHandle shape) {
|
||||
ShapeHandle new_shape;
|
||||
if (!Merge(input_handle_shape_[idx], shape, &new_shape).ok() ||
|
||||
input_handle_shape_[idx].SameHandle(new_shape)) {
|
||||
return false;
|
||||
}
|
||||
input_handle_shape_[idx] = shape;
|
||||
return true;
|
||||
// Merge the stored shapes and types corresponding to the input handle in
|
||||
// position idx with the specified shapes and types. This requires idx to be
|
||||
// in the [0, num_inputs) range.
|
||||
//
|
||||
// If the merge is successful and any of the new shapes differs from the old
|
||||
// one, or any of the old dtypes was DT_INVALID, store the new shapes and
|
||||
// return true. Return false otherwise.
|
||||
bool MergeInputHandleShapesAndTypes(
|
||||
int idx,
|
||||
const std::vector<ShapeAndType>& shapes_and_types) TF_MUST_USE_RESULT;
|
||||
|
||||
// As MergeInputHandleShapesAndTypes, but for an output.
|
||||
bool MergeOutputHandleShapesAndTypes(
|
||||
int idx, const std::vector<ShapeAndType>& shapes) TF_MUST_USE_RESULT;
|
||||
|
||||
// Returns the output handle shapes and types, for the resource tensor output
|
||||
// at index <idx>. Returns NULL if the shape and types were never set.
|
||||
const std::vector<ShapeAndType>* output_handle_shapes_and_types(int idx) {
|
||||
return output_handle_shapes_and_types_[idx].get();
|
||||
}
|
||||
|
||||
// Set the type corresponding to the resource in position idx. This requires
|
||||
// idx to be in the [0, num_inputs) range. Returns true iff the stored type
|
||||
// has been updated.
|
||||
bool set_input_handle_dtype(int idx, DataType dtype) {
|
||||
if (input_handle_dtype_[idx] != dtype) {
|
||||
input_handle_dtype_[idx] = dtype;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
// Returns the inputs handle shapes and types, for the resource tensor output
|
||||
// at index <idx>. Returns NULL if the shape and types were not available.
|
||||
const std::vector<ShapeAndType>* input_handle_shapes_and_types(int idx) {
|
||||
return input_handle_shapes_and_types_[idx].get();
|
||||
}
|
||||
|
||||
// DEPRECATED: use input_handle_shapes_and_types.
|
||||
ShapeHandle input_handle_shape(int idx);
|
||||
// DEPRECATED: use input_handle_shapes_and_types.
|
||||
DataType input_handle_dtype(int idx) const {
|
||||
return input_handle_dtype_[idx];
|
||||
}
|
||||
|
||||
// Merge the stored shape corresponding to the output handle in position idx
|
||||
// with the specified shape. This requires idx to be in the [0, num_outputs)
|
||||
// range. If the merge is successful and the new shape differs from the old
|
||||
// one, store the new shape and return true. Return false otherwise.
|
||||
|
||||
bool MergeOutputHandleShape(int idx, ShapeHandle shape) {
|
||||
ShapeHandle new_shape;
|
||||
if (!Merge(output_handle_shape_[idx], shape, &new_shape).ok() ||
|
||||
output_handle_shape_[idx].SameHandle(new_shape)) {
|
||||
return false;
|
||||
if (input_handle_shapes_and_types_[idx] == nullptr) {
|
||||
return DT_INVALID;
|
||||
} else {
|
||||
DCHECK_EQ(input_handle_shapes_and_types_[idx]->size(), 1);
|
||||
return (*input_handle_shapes_and_types_[idx])[0].dtype;
|
||||
}
|
||||
output_handle_shape_[idx] = shape;
|
||||
return true;
|
||||
}
|
||||
// Overwrite the shape corresponding to the output handle in position idx with
|
||||
// the specified shape.
|
||||
void set_output_handle_shape(int idx, ShapeHandle shape) {
|
||||
output_handle_shape_[idx] = shape;
|
||||
}
|
||||
|
||||
// Set the type corresponding to the resource in position idx. This requires
|
||||
// idx to be in the [0, num_outputs) range. Returns true iff the stored type
|
||||
// has been updated.
|
||||
bool set_output_handle_dtype(int idx, DataType dtype) {
|
||||
if (output_handle_dtype_[idx] != dtype) {
|
||||
output_handle_dtype_[idx] = dtype;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
void set_output_handle_shapes_and_types(
|
||||
int idx, const std::vector<ShapeAndType>& shapes_and_types) {
|
||||
output_handle_shapes_and_types_[idx].reset(
|
||||
new std::vector<ShapeAndType>(shapes_and_types));
|
||||
}
|
||||
ShapeHandle output_handle_shape(int idx) const {
|
||||
return output_handle_shape_[idx];
|
||||
|
||||
// DEPRECATED: use output_handle_shapes_and_types.
|
||||
ShapeHandle output_handle_shape(int idx) {
|
||||
return output_handle_shapes_and_types_[idx] == nullptr
|
||||
? UnknownShape()
|
||||
: (*output_handle_shapes_and_types_[idx])[0].shape;
|
||||
}
|
||||
// DEPRECATED: use output_handle_shapes_and_types.
|
||||
DataType output_handle_dtype(int idx) const {
|
||||
return output_handle_dtype_[idx];
|
||||
return output_handle_shapes_and_types_[idx] == nullptr
|
||||
? DT_INVALID
|
||||
: (*output_handle_shapes_and_types_[idx])[0].dtype;
|
||||
}
|
||||
|
||||
// Note that shape functions should usually call MakeShapeFromShapeTensor,
|
||||
@ -555,8 +557,8 @@ class InferenceContext {
|
||||
void PreInputInit(const OpDef& op_def,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes);
|
||||
void PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
input_handle_data);
|
||||
|
||||
DimensionHandle GetDimension(const DimensionOrConstant& d);
|
||||
|
||||
@ -573,6 +575,12 @@ class InferenceContext {
|
||||
// Adds additional context to the given status.
|
||||
Status AttachContext(const Status& status);
|
||||
|
||||
// Used to implement MergeInputHandleShapesAndTypes and
|
||||
// MergeOutputHandleShapesAndTypes.
|
||||
bool MergeHandleShapesAndTypes(
|
||||
const std::vector<ShapeAndType>& shapes_and_types,
|
||||
std::vector<ShapeAndType>* to_update) TF_MUST_USE_RESULT;
|
||||
|
||||
ShapeManager shape_manager_;
|
||||
|
||||
// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
|
||||
@ -585,10 +593,19 @@ class InferenceContext {
|
||||
std::vector<ShapeHandle> input_tensors_as_shapes_;
|
||||
std::vector<bool> requested_input_tensor_as_partial_shape_;
|
||||
|
||||
std::vector<ShapeHandle> input_handle_shape_;
|
||||
std::vector<DataType> input_handle_dtype_;
|
||||
std::vector<ShapeHandle> output_handle_shape_;
|
||||
std::vector<DataType> output_handle_dtype_;
|
||||
// input_handle_shapes_and_types_[i] is the list of shape/type pairs available
|
||||
// through the resource handle passed along input i of the node.
|
||||
//
|
||||
// Values may be NULL.
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
input_handle_shapes_and_types_;
|
||||
|
||||
// output_handle_shapes_and_types_[i] is the list of shape/type pairs
|
||||
// available through the resource handle passed along output i of the node.
|
||||
//
|
||||
// Values may be NULL.
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
output_handle_shapes_and_types_;
|
||||
|
||||
const int graph_def_version_;
|
||||
const NodeDef& node_def_;
|
||||
|
@ -61,6 +61,7 @@ class ShapeInferenceTest : public ::testing::Test {
|
||||
bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
|
||||
bool IsSet(DimensionHandle d) { return d.IsSet(); }
|
||||
bool IsSet(ShapeHandle s) { return s.IsSet(); }
|
||||
void TestMergeHandles(bool input_not_output);
|
||||
|
||||
static const int kVersion = 0; // used for graph-def version.
|
||||
};
|
||||
@ -74,7 +75,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(&def);
|
||||
InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
|
||||
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
|
||||
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
|
||||
@ -110,8 +111,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {},
|
||||
{});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
|
||||
EXPECT_EQ(InferenceContext::kUnknownDim,
|
||||
c.Value(InferenceContext::kUnknownDim));
|
||||
EXPECT_EQ(1, c.Value(1));
|
||||
@ -126,7 +126,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
||||
NodeDef def;
|
||||
def.set_name("foo");
|
||||
def.set_op("foo_op");
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
|
||||
{
|
||||
@ -166,7 +166,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
// Error when no constant tensors were requested.
|
||||
{
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
|
||||
{}, {});
|
||||
{});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
ShapeHandle h;
|
||||
@ -185,8 +185,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
Tensor input_t =
|
||||
::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {},
|
||||
{});
|
||||
{S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
c->input_tensor(0); // get this one, but it's null - won't be in error.
|
||||
@ -208,7 +207,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
{
|
||||
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
|
||||
{nullptr, &input_t}, {}, {}, {});
|
||||
{nullptr, &input_t}, {}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
ShapeHandle s;
|
||||
@ -231,8 +230,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
{
|
||||
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
|
||||
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {},
|
||||
{});
|
||||
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
ShapeHandle s;
|
||||
@ -255,7 +253,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
|
||||
{Unknown(), S({1, -1, 3}), S({})}, {}, {}, {}, {});
|
||||
{Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(2, c.num_outputs());
|
||||
|
||||
@ -296,8 +294,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
TEST_F(ShapeInferenceTest, NumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
|
||||
{});
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
|
||||
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
|
||||
@ -311,7 +308,7 @@ TEST_F(ShapeInferenceTest, NumElements) {
|
||||
TEST_F(ShapeInferenceTest, WithRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -350,7 +347,7 @@ TEST_F(ShapeInferenceTest, WithRank) {
|
||||
TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -389,7 +386,7 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {}, {});
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -427,8 +424,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithValue) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {},
|
||||
{});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
|
||||
|
||||
auto d0 = c.Dim(c.input(0), 0);
|
||||
auto d1 = c.Dim(c.input(0), 1);
|
||||
@ -470,7 +466,7 @@ TEST_F(ShapeInferenceTest, WithValue) {
|
||||
TEST_F(ShapeInferenceTest, MergeDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
|
||||
auto d2 = c.Dim(c.input(0), 0);
|
||||
auto d_unknown = c.Dim(c.input(0), 1);
|
||||
@ -519,7 +515,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
|
||||
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
|
||||
Unknown(), S({1})},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
auto s_1_2 = c.input(1);
|
||||
@ -595,7 +591,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
{
|
||||
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
|
||||
},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
auto s_u_2 = c.input(1);
|
||||
@ -648,7 +644,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
TEST_F(ShapeInferenceTest, Subshape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {}, {});
|
||||
{S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
|
||||
|
||||
ShapeHandle unknown = c.input(1);
|
||||
ShapeHandle out;
|
||||
@ -723,7 +719,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
|
||||
TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -750,7 +746,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
|
||||
{}, {}, {}, {});
|
||||
{}, {}, {});
|
||||
|
||||
auto in = c.input(0);
|
||||
auto unknown = c.input(1);
|
||||
@ -782,7 +778,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
|
||||
std::vector<DimensionHandle> dims;
|
||||
auto in0 = c.input(0);
|
||||
@ -807,7 +803,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto u0 = c.UnknownShape();
|
||||
auto u1 = c.UnknownShape();
|
||||
@ -819,7 +815,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
TEST_F(ShapeInferenceTest, Scalar) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s0 = c.Scalar();
|
||||
EXPECT_EQ("[]", c.DebugString(s0));
|
||||
@ -830,7 +826,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
|
||||
TEST_F(ShapeInferenceTest, Vector) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s0 = c.Vector(1);
|
||||
EXPECT_EQ("[1]", c.DebugString(s0));
|
||||
@ -846,7 +842,7 @@ TEST_F(ShapeInferenceTest, Vector) {
|
||||
TEST_F(ShapeInferenceTest, Matrix) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s0 = c.Matrix(1, 2);
|
||||
EXPECT_EQ("[1,2]", c.DebugString(s0));
|
||||
@ -869,7 +865,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
auto create = [&](Tensor* t) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
|
||||
{}, {});
|
||||
{});
|
||||
ShapeHandle out;
|
||||
Status s = c.MakeShapeFromShapeTensor(0, &out);
|
||||
if (s.ok()) {
|
||||
@ -922,7 +918,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
{
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
ShapeHandle out;
|
||||
EXPECT_EQ("Shape must be rank 1 but is rank 2",
|
||||
c.MakeShapeFromShapeTensor(0, &out).error_message());
|
||||
@ -932,7 +928,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
// With an unknown rank.
|
||||
ShapeHandle out;
|
||||
@ -951,7 +947,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
ShapeHandle out;
|
||||
TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
|
||||
@ -965,7 +961,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
TensorShapeProto proto;
|
||||
|
||||
// With a set unknown rank.
|
||||
@ -1001,7 +997,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto d0 = c.MakeDim(1);
|
||||
auto d1 = c.MakeDim(1);
|
||||
@ -1015,7 +1011,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto d0 = c.UnknownDim();
|
||||
auto d1 = c.UnknownDim();
|
||||
@ -1027,7 +1023,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
|
||||
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
|
||||
@ -1041,7 +1037,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
|
||||
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
|
||||
{&t1, &t2}, {}, {}, {});
|
||||
{&t1, &t2}, {}, {});
|
||||
|
||||
EXPECT_TRUE(c.input_tensor(0) == &t1);
|
||||
EXPECT_TRUE(c.input_tensor(1) == &t2);
|
||||
@ -1053,7 +1049,7 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
|
||||
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
|
||||
{&t1, &t2}, {}, {}, {});
|
||||
{&t1, &t2}, {}, {});
|
||||
|
||||
DimensionHandle d;
|
||||
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
|
||||
@ -1084,7 +1080,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
.ok());
|
||||
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
|
||||
string value;
|
||||
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
|
||||
EXPECT_EQ("bar", value);
|
||||
@ -1093,7 +1089,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
TEST_F(ShapeInferenceTest, Divide) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1156,7 +1152,7 @@ TEST_F(ShapeInferenceTest, Divide) {
|
||||
TEST_F(ShapeInferenceTest, Add) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
|
||||
{}, {});
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1208,7 +1204,7 @@ TEST_F(ShapeInferenceTest, Add) {
|
||||
TEST_F(ShapeInferenceTest, Subtract) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1258,7 +1254,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
|
||||
TEST_F(ShapeInferenceTest, Multiply) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1311,7 +1307,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
|
||||
TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
// No rank or missing dimension information should return false.
|
||||
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
|
||||
@ -1325,7 +1321,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
TEST_F(ShapeInferenceTest, Min) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
|
||||
{}, {}, {});
|
||||
{}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1374,7 +1370,7 @@ TEST_F(ShapeInferenceTest, Min) {
|
||||
TEST_F(ShapeInferenceTest, Max) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
|
||||
{}, {});
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1409,5 +1405,116 @@ TEST_F(ShapeInferenceTest, Max) {
|
||||
EXPECT_TRUE(SameHandle(d_2, out));
|
||||
}
|
||||
|
||||
void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
|
||||
{});
|
||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||
ShapeHandle s;
|
||||
TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s));
|
||||
return s;
|
||||
};
|
||||
auto get_shapes_and_types_from_context = [&](int idx) {
|
||||
if (input_not_output) {
|
||||
return c.input_handle_shapes_and_types(idx);
|
||||
} else {
|
||||
return c.output_handle_shapes_and_types(idx);
|
||||
}
|
||||
};
|
||||
auto merge_shapes_and_types_to_context =
|
||||
[&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
|
||||
if (input_not_output) {
|
||||
return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
|
||||
} else {
|
||||
return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
|
||||
}
|
||||
};
|
||||
|
||||
EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
|
||||
EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
|
||||
|
||||
// First merge will take the input completely.
|
||||
std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
|
||||
{c.UnknownShape(), DT_INVALID},
|
||||
{make_shape({4, 3, 2, 1}), DT_INT32}};
|
||||
ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
|
||||
ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
|
||||
std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
|
||||
ASSERT_EQ(3, v.size());
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
|
||||
EXPECT_EQ(t[i].dtype, v[i].dtype);
|
||||
}
|
||||
|
||||
// Merge that fails because wrong number of values passed.
|
||||
// Fails, and no changes made.
|
||||
ASSERT_FALSE(merge_shapes_and_types_to_context(
|
||||
0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
|
||||
v = *get_shapes_and_types_from_context(0);
|
||||
ASSERT_EQ(3, v.size());
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
|
||||
EXPECT_EQ(t[i].dtype, v[i].dtype);
|
||||
}
|
||||
|
||||
// Only difference is in a mismatched shape. That is ignored,
|
||||
// and there are no other changes, so nothing is done.
|
||||
//
|
||||
// TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
|
||||
// return an error (separate error from 'refined' output)?
|
||||
auto t2 = t;
|
||||
t2[2].shape = make_shape({4, 3, 4, 1});
|
||||
ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
|
||||
v = *get_shapes_and_types_from_context(0);
|
||||
ASSERT_EQ(3, v.size());
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
|
||||
EXPECT_EQ(t[i].dtype, v[i].dtype);
|
||||
}
|
||||
|
||||
// Only difference is in a mismatched dtype. That is ignored,
|
||||
// and there are no other changes, so nothing is done.
|
||||
t2 = t;
|
||||
t2[2].dtype = DT_FLOAT;
|
||||
ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
|
||||
v = *get_shapes_and_types_from_context(0);
|
||||
ASSERT_EQ(3, v.size());
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
|
||||
EXPECT_EQ(t[i].dtype, v[i].dtype);
|
||||
}
|
||||
|
||||
// Difference is mergeable (new shape).
|
||||
t[1].shape = make_shape({1, 10});
|
||||
ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
|
||||
v = *get_shapes_and_types_from_context(0);
|
||||
ASSERT_EQ(3, v.size());
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
|
||||
EXPECT_EQ(t[i].dtype, v[i].dtype);
|
||||
}
|
||||
|
||||
// Difference is mergeable (new type).
|
||||
t[1].dtype = DT_DOUBLE;
|
||||
ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
|
||||
v = *get_shapes_and_types_from_context(0);
|
||||
ASSERT_EQ(3, v.size());
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
|
||||
EXPECT_EQ(t[i].dtype, v[i].dtype);
|
||||
}
|
||||
|
||||
// No difference.
|
||||
ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) {
|
||||
TestMergeHandles(true);
|
||||
}
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
|
||||
TestMergeHandles(false);
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
@ -45,7 +45,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
||||
|
||||
shape_inference::InferenceContext c(op.graph_def_version, &op.node_def,
|
||||
op_reg_data->op_def, in_shapes,
|
||||
op.input_tensors, {}, {}, {});
|
||||
op.input_tensors, {}, {});
|
||||
TF_RETURN_IF_ERROR(c.construction_status());
|
||||
if (op_reg_data->shape_inference_fn == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
|
@ -26,6 +26,38 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeAndType;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
namespace {
|
||||
|
||||
// Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
|
||||
// <*queue_shapes_and_types>.
|
||||
Status MergeEnqueueShapesAndTypes(
|
||||
const std::vector<ShapeAndType>& shapes_and_types, InferenceContext* qctx,
|
||||
std::vector<ShapeAndType>* queue_shapes_and_types) {
|
||||
if (shapes_and_types.size() != queue_shapes_and_types->size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
|
||||
" vs ", queue_shapes_and_types->size());
|
||||
}
|
||||
for (int i = 0; i < shapes_and_types.size(); ++i) {
|
||||
const ShapeAndType& a = shapes_and_types[i];
|
||||
ShapeAndType& b = (*queue_shapes_and_types)[i];
|
||||
if (a.dtype != b.dtype) {
|
||||
return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
|
||||
i, ": ", DataTypeString(a.dtype), " vs ",
|
||||
DataTypeString(b.dtype));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(qctx->Merge(a.shape, b.shape, &b.shape));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status GraphProperties::InferStatically() {
|
||||
Graph graph(OpRegistry::Global());
|
||||
ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
|
||||
@ -60,32 +92,49 @@ Status GraphProperties::InferStatically() {
|
||||
if (!qctx) {
|
||||
continue;
|
||||
}
|
||||
DataType queue_type = qctx->output_handle_dtype(0);
|
||||
shape_inference::ShapeHandle queue_shp = qctx->output_handle_shape(0);
|
||||
if (qctx->FullyDefined(queue_shp) && queue_type != DT_INVALID) {
|
||||
continue;
|
||||
|
||||
// Check to see if the shape is fully defined.
|
||||
auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
|
||||
if (queue_handle_data != nullptr) {
|
||||
bool fully_defined = true;
|
||||
for (const auto& shape_and_type : *queue_handle_data) {
|
||||
if (!qctx->FullyDefined(shape_and_type.shape) ||
|
||||
shape_and_type.dtype == DT_INVALID) {
|
||||
fully_defined = false;
|
||||
}
|
||||
}
|
||||
if (fully_defined) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ShapeAndType> queue_shapes_and_types;
|
||||
if (queue_handle_data != nullptr) {
|
||||
queue_shapes_and_types = *queue_handle_data;
|
||||
}
|
||||
for (const auto& node : resource_data.second) {
|
||||
auto ctx = shape_refiner.GetContext(node);
|
||||
if (!ctx) {
|
||||
continue;
|
||||
}
|
||||
if (node->type_string().find("Enqueue") != std::string::npos) {
|
||||
if (ctx->num_inputs() == 2) {
|
||||
const DataType dtype = node->input_type(1);
|
||||
if (queue_type == DT_INVALID) {
|
||||
queue_type = dtype;
|
||||
} else {
|
||||
CHECK_EQ(queue_type, dtype);
|
||||
}
|
||||
shape_inference::ShapeHandle shp = ctx->input(1);
|
||||
TF_RETURN_IF_ERROR(qctx->Merge(queue_shp, shp, &queue_shp));
|
||||
// TODO(bsteiner): handle EnqueueMany as well.
|
||||
if (node->type_string().find("Enqueue") != std::string::npos &&
|
||||
node->type_string().find("EnqueueMany") == std::string::npos) {
|
||||
std::vector<ShapeAndType> shapes_and_types;
|
||||
for (int i = 1; i < ctx->num_inputs(); ++i) {
|
||||
shapes_and_types.push_back({ctx->input(i), node->input_type(i)});
|
||||
}
|
||||
|
||||
if (queue_shapes_and_types.empty()) {
|
||||
queue_shapes_and_types = shapes_and_types;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes(
|
||||
shapes_and_types, qctx, &queue_shapes_and_types));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (qctx->set_output_handle_dtype(0, queue_type) |
|
||||
qctx->MergeOutputHandleShape(0, queue_shp)) {
|
||||
if (!queue_shapes_and_types.empty() &&
|
||||
qctx->MergeOutputHandleShapesAndTypes(0, queue_shapes_and_types)) {
|
||||
new_shapes.push(qnode);
|
||||
}
|
||||
}
|
||||
@ -115,7 +164,7 @@ Status GraphProperties::InferStatically() {
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
OpInfo::TensorProperties properties;
|
||||
properties.set_dtype(node->input_type(i));
|
||||
shape_inference::ShapeHandle shp = ctx->input(i);
|
||||
ShapeHandle shp = ctx->input(i);
|
||||
if (!ctx->RankKnown(shp)) {
|
||||
properties.mutable_shape()->set_unknown_rank(true);
|
||||
} else {
|
||||
@ -135,7 +184,7 @@ Status GraphProperties::InferStatically() {
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
OpInfo::TensorProperties properties;
|
||||
properties.set_dtype(node->output_type(i));
|
||||
shape_inference::ShapeHandle shp = ctx->output(i);
|
||||
ShapeHandle shp = ctx->output(i);
|
||||
if (!ctx->RankKnown(shp)) {
|
||||
properties.mutable_shape()->set_unknown_rank(true);
|
||||
} else {
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -38,6 +39,22 @@ class GraphPropertiesTest : public ::testing::Test {
|
||||
void TearDown() override { cluster_.reset(); }
|
||||
|
||||
protected:
|
||||
// Returns a string form of <p>, suitable for comparing type and shape.
|
||||
// Example output for 4-d float tensor: "float: [10,2,30,4]"
|
||||
string PropToString(const OpInfo::TensorProperties& p) {
|
||||
string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
|
||||
if (p.shape().unknown_rank()) {
|
||||
strings::StrAppend(&s, "?");
|
||||
} else {
|
||||
strings::StrAppend(&s, "[");
|
||||
for (int i = 0; i < p.shape().dim_size(); ++i) {
|
||||
strings::StrAppend(&s, i == 0 ? "" : ",", p.shape().dim(i).size());
|
||||
}
|
||||
strings::StrAppend(&s, "]");
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
std::unique_ptr<SingleMachine> cluster_;
|
||||
};
|
||||
|
||||
@ -194,6 +211,20 @@ TEST_F(GraphPropertiesTest, Queues) {
|
||||
auto dequeue4 =
|
||||
ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
|
||||
|
||||
// Create a queue that takes in three tensors.
|
||||
auto q5 = ops::RandomShuffleQueue(
|
||||
root.WithOpName("Queue5"),
|
||||
{DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
|
||||
Output rnd2 =
|
||||
ops::RandomNormal(root.WithOpName("rnd"), {10}, DataType::DT_DOUBLE);
|
||||
Output rnd3 =
|
||||
ops::RandomNormal(root.WithOpName("rnd"), {1, 2, 3}, DataType::DT_FLOAT);
|
||||
auto enqueue5 =
|
||||
ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
|
||||
auto dequeue5 = ops::QueueDequeue(
|
||||
root.WithOpName("Dequeue5"), q5,
|
||||
{DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(root.ToGraphDef(&item.graph));
|
||||
|
||||
@ -201,34 +232,31 @@ TEST_F(GraphPropertiesTest, Queues) {
|
||||
TF_CHECK_OK(properties.InferStatically());
|
||||
|
||||
const auto props1 = properties.GetOutputProperties("Dequeue1");
|
||||
EXPECT_EQ(1, props1.size());
|
||||
const OpInfo::TensorProperties& prop1 = props1[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop1.dtype());
|
||||
EXPECT_FALSE(prop1.shape().unknown_rank());
|
||||
EXPECT_EQ(2, prop1.shape().dim_size());
|
||||
EXPECT_EQ(3, prop1.shape().dim(0).size());
|
||||
EXPECT_EQ(7, prop1.shape().dim(1).size());
|
||||
ASSERT_EQ(1, props1.size());
|
||||
EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
|
||||
|
||||
const auto props2 = properties.GetOutputProperties("Dequeue2");
|
||||
EXPECT_EQ(1, props2.size());
|
||||
const OpInfo::TensorProperties& prop2 = props2[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop2.dtype());
|
||||
EXPECT_FALSE(prop2.shape().unknown_rank());
|
||||
EXPECT_EQ(2, prop2.shape().dim_size());
|
||||
EXPECT_EQ(3, prop2.shape().dim(0).size());
|
||||
EXPECT_EQ(7, prop2.shape().dim(1).size());
|
||||
ASSERT_EQ(1, props2.size());
|
||||
EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
|
||||
|
||||
// The dequeue3 op shape is unknown.
|
||||
const auto props3 = properties.GetOutputProperties("Dequeue3");
|
||||
ASSERT_EQ(1, props3.size());
|
||||
EXPECT_EQ("float: ?", PropToString(props3[0]));
|
||||
|
||||
// The dequeue3 op shape is unknown. The square2 op shape is known. Verify
|
||||
// that we merge the 2 properly to determine the shape of the data coming out
|
||||
// of the queue.
|
||||
const auto props4 = properties.GetOutputProperties("Dequeue4");
|
||||
EXPECT_EQ(1, props4.size());
|
||||
const OpInfo::TensorProperties& prop4 = props4[0];
|
||||
EXPECT_EQ(DT_FLOAT, prop4.dtype());
|
||||
EXPECT_FALSE(prop4.shape().unknown_rank());
|
||||
EXPECT_EQ(2, prop4.shape().dim_size());
|
||||
EXPECT_EQ(3, prop4.shape().dim(0).size());
|
||||
EXPECT_EQ(7, prop4.shape().dim(1).size());
|
||||
ASSERT_EQ(1, props4.size());
|
||||
EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
|
||||
|
||||
// The dequeue5 op shape is known.
|
||||
const auto props5 = properties.GetOutputProperties("Dequeue5");
|
||||
ASSERT_EQ(3, props5.size());
|
||||
EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
|
||||
EXPECT_EQ("double: [10]", PropToString(props5[1]));
|
||||
EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, Loops) {
|
||||
|
@ -1534,8 +1534,10 @@ REGISTER_OP("Identity")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"Doc(
|
||||
@ -1551,8 +1553,10 @@ REGISTER_OP("_MklIdentity")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"Doc( Mkl implementation of IdentityOp
|
||||
|
@ -162,9 +162,15 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
|
||||
// Check that handle dtypes are preserved.
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
|
||||
std::vector<
|
||||
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
|
||||
handle_data;
|
||||
handle_data.emplace_back(
|
||||
new std::vector<std::pair<TensorShapeProto, DataType>>{
|
||||
{TensorShapeProto(), DT_BOOL}});
|
||||
shape_inference::InferenceContext c(TF_GRAPH_DEF_VERSION, &op.node_def,
|
||||
op_reg_data->op_def, {TensorShapeProto()},
|
||||
{}, {}, {}, {DT_BOOL});
|
||||
{}, {}, handle_data);
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
|
||||
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
|
||||
|
@ -32,10 +32,11 @@ Status SwitchShape(InferenceContext* c) {
|
||||
c->set_output(1, out);
|
||||
|
||||
// Handle resource shape / dtype.
|
||||
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||
c->set_output_handle_shape(1, c->input_handle_shape(0));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||
c->set_output_handle_dtype(1, c->input_handle_dtype(0));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
c->set_output_handle_shapes_and_types(1, *handle_data);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
@ -200,8 +201,10 @@ REGISTER_OP("Enter")
|
||||
c->set_output(0, c->UnknownShape());
|
||||
|
||||
// Handle resource shape / dtype, if present.
|
||||
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data != nullptr) {
|
||||
c->set_output_handle_shapes_and_types(0, *handle_data);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
})
|
||||
|
@ -644,15 +644,15 @@ REGISTER_OP("QueueDequeueV2")
|
||||
.Attr("component_types: list(type) >= 1")
|
||||
.Attr("timeout_ms: int = -1")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
if (c->num_outputs() == 1) {
|
||||
c->set_output(0, c->input_handle_shape(0));
|
||||
} else {
|
||||
// TODO(vrv): handle the case of multiple outputs.
|
||||
auto* t = c->input_handle_shapes_and_types(0);
|
||||
if (t != nullptr && t->size() == c->num_outputs()) {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->UnknownShape());
|
||||
c->set_output(i, (*t)[i].shape);
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
return shape_inference::UnknownShape(c);
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Dequeues a tuple of one or more tensors from the given queue.
|
||||
|
@ -927,18 +927,33 @@ REGISTER_OP("Select")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
auto* handle_data_1 = c->input_handle_shapes_and_types(1);
|
||||
auto* handle_data_2 = c->input_handle_shapes_and_types(2);
|
||||
// Merge handle shape and dtype if applicable.
|
||||
if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) {
|
||||
// TODO(apassos) resolve this in the manner of b/32476923
|
||||
return errors::InvalidArgument(
|
||||
"Trying to merge handles pointing to different dtypes.");
|
||||
if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
|
||||
const auto size = handle_data_1->size();
|
||||
std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
|
||||
if (size != handle_data_2->size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to merge handles pointing to different numbers of "
|
||||
"tensors.");
|
||||
}
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
|
||||
const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
|
||||
if (s1.dtype != s2.dtype) {
|
||||
// TODO(apassos) resolve this in the manner of b/32476923
|
||||
return errors::InvalidArgument(
|
||||
"Trying to merge handles pointing to different dtypes.");
|
||||
}
|
||||
merged_handle_data[i].dtype = s1.dtype;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
|
||||
}
|
||||
|
||||
c->set_output_handle_shapes_and_types(0, merged_handle_data);
|
||||
}
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(1));
|
||||
ShapeHandle output_handle_shape;
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1),
|
||||
c->input_handle_shape(2),
|
||||
&output_handle_shape));
|
||||
c->set_output_handle_shape(0, output_handle_shape);
|
||||
|
||||
// The inputs 'then' and 'else' must have the same shape.
|
||||
ShapeHandle data = c->input(1);
|
||||
@ -2074,7 +2089,6 @@ REGISTER_OP("Bincount")
|
||||
.Output("bins: T")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->UnknownShapeOfRank(1));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(2));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
|
@ -188,36 +188,70 @@ TEST(MathOpsTest, Select_ShapeFn) {
|
||||
INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
|
||||
"[2,?,5];[?,?,3];[?,2,?]");
|
||||
|
||||
// Test that handle shapes were merged.
|
||||
// Test that handles were merged.
|
||||
//
|
||||
// Tests below will modify handle_data and call run_inference_for_handles to
|
||||
// rerun shape inference, updating the context <c>.
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
|
||||
TensorShapeProto i0;
|
||||
i0.add_dim()->set_size(1);
|
||||
i0.add_dim()->set_size(-1);
|
||||
TensorShapeProto i1;
|
||||
i1.add_dim()->set_size(-1);
|
||||
i1.add_dim()->set_size(2);
|
||||
typedef std::vector<std::pair<TensorShapeProto, DataType>> ShapeDtypeV;
|
||||
std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
|
||||
std::unique_ptr<shape_inference::InferenceContext> c;
|
||||
Status run_status;
|
||||
auto run_inference_for_handles = [&]() -> Status {
|
||||
CHECK(op_reg_data->shape_inference_fn != nullptr);
|
||||
c.reset(new shape_inference::InferenceContext(
|
||||
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
|
||||
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
|
||||
handle_data));
|
||||
TF_CHECK_OK(c->construction_status());
|
||||
Status s = c->Run(op_reg_data->shape_inference_fn);
|
||||
LOG(INFO) << "Inference got " << s;
|
||||
return s;
|
||||
};
|
||||
auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
|
||||
TensorShapeProto p;
|
||||
for (auto i : dim_sizes) p.add_dim()->set_size(i);
|
||||
return p;
|
||||
};
|
||||
|
||||
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
|
||||
shape_inference::InferenceContext c(
|
||||
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
|
||||
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
|
||||
{TensorShapeProto(), i0, i1}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
|
||||
EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0)));
|
||||
EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0)));
|
||||
TensorShapeProto i0 = shape_proto({1, -1});
|
||||
TensorShapeProto i1 = shape_proto({-1, 2});
|
||||
TensorShapeProto unknown_shape;
|
||||
unknown_shape.set_unknown_rank(true);
|
||||
TensorShapeProto scalar;
|
||||
|
||||
handle_data.emplace_back(
|
||||
new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
|
||||
handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}});
|
||||
handle_data.emplace_back(
|
||||
new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});
|
||||
|
||||
TF_ASSERT_OK(run_inference_for_handles());
|
||||
auto* out = c->output_handle_shapes_and_types(0);
|
||||
ASSERT_EQ(2, out->size());
|
||||
EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
|
||||
EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
|
||||
EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
|
||||
EXPECT_EQ(DT_INT32, out->at(1).dtype);
|
||||
|
||||
// Expect an error when the shapes can't be merged.
|
||||
TensorShapeProto i2;
|
||||
i1.add_dim()->set_size(2);
|
||||
i1.add_dim()->set_size(2);
|
||||
shape_inference::InferenceContext c2(
|
||||
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
|
||||
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
|
||||
{TensorShapeProto(), i0, i2}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok());
|
||||
handle_data[2]->at(0).first = shape_proto({2, 2});
|
||||
EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
|
||||
.contains("must be equal, but are 1 and 2"));
|
||||
handle_data[2]->at(0).first = i1; // restore to valid
|
||||
|
||||
// Expect an error when the types can't be merged.
|
||||
handle_data[2]->at(1).second = DT_INT64;
|
||||
EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
|
||||
.contains("pointing to different dtypes"));
|
||||
handle_data[2]->at(1).second = DT_INT32; // restore to valid
|
||||
|
||||
// Expect an error when different numbers of tensors are merged.
|
||||
handle_data[2]->push_back({i1, DT_FLOAT});
|
||||
EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
|
||||
.contains("pointing to different numbers of tensors"));
|
||||
handle_data[2]->pop_back(); // restore to valid.
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, Range_ShapeFn) {
|
||||
|
@ -20,10 +20,43 @@
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
using ::tensorflow::shape_inference::InferenceContext;
|
||||
using ::tensorflow::shape_inference::ShapeAndType;
|
||||
using ::tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
Status ValidateVariableResourceHandle(InferenceContext* c,
|
||||
ShapeAndType* shape_and_type) {
|
||||
auto* handle_data = c->input_handle_shapes_and_types(0);
|
||||
if (handle_data == nullptr || handle_data->empty()) {
|
||||
shape_and_type->shape = c->UnknownShape();
|
||||
shape_and_type->dtype = DT_INVALID;
|
||||
} else {
|
||||
*shape_and_type = (*handle_data)[0];
|
||||
DataType value_dtype;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
|
||||
if (shape_and_type->dtype != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to read variable with wrong dtype. "
|
||||
"Expected ",
|
||||
DataTypeString(shape_and_type->dtype), " got ",
|
||||
DataTypeString(value_dtype));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadVariableShapeFn(InferenceContext* c) {
|
||||
ShapeAndType shape_and_type;
|
||||
TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &shape_and_type));
|
||||
c->set_output(0, shape_and_type.shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("VarHandleOp")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
@ -31,16 +64,17 @@ REGISTER_OP("VarHandleOp")
|
||||
.Attr("shape: shape")
|
||||
.Output("resource: resource")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Scalar());
|
||||
DataType t;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
|
||||
c->set_output_handle_dtype(0, t);
|
||||
TensorShapeProto p;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
|
||||
shape_inference::ShapeHandle s;
|
||||
ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s));
|
||||
c->set_output_handle_shape(0, s);
|
||||
c->set_output_handle_shapes_and_types(0,
|
||||
std::vector<ShapeAndType>{{s, t}});
|
||||
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"(
|
||||
@ -57,19 +91,7 @@ REGISTER_OP("ReadVariableOp")
|
||||
.Input("resource: resource")
|
||||
.Output("value: dtype")
|
||||
.Attr("dtype: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
DataType handle_dtype = c->input_handle_dtype(0);
|
||||
DataType value_dtype;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
|
||||
if (handle_dtype != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to read variable with wrong dtype. "
|
||||
"Expected ",
|
||||
DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
|
||||
}
|
||||
c->set_output(0, c->input_handle_shape(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.SetShapeFn(ReadVariableShapeFn)
|
||||
.Doc(R"(
|
||||
Reads the value of a variable.
|
||||
|
||||
@ -88,19 +110,7 @@ REGISTER_OP("_UnsafeReadVariable")
|
||||
.Input("resource: resource")
|
||||
.Output("value: dtype")
|
||||
.Attr("dtype: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
DataType handle_dtype = c->input_handle_dtype(0);
|
||||
DataType value_dtype;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
|
||||
if (handle_dtype != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to read variable with wrong dtype. "
|
||||
"Expected ",
|
||||
DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
|
||||
}
|
||||
c->set_output(0, c->input_handle_shape(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.SetShapeFn(ReadVariableShapeFn)
|
||||
.Doc(R"(
|
||||
Reads the value of a variable without any memory model.
|
||||
|
||||
@ -130,19 +140,13 @@ ignore_lookup_error: whether to ignore the error when the resource
|
||||
)");
|
||||
|
||||
Status CreateAssignShapeFn(InferenceContext* c) {
|
||||
DataType handle_dtype = c->input_handle_dtype(0);
|
||||
DataType value_dtype;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
|
||||
if (handle_dtype != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to initialize handle for variable with wrong dtype. "
|
||||
"Expected ",
|
||||
DataTypeString(handle_dtype), " got ", DataTypeString(value_dtype));
|
||||
}
|
||||
ShapeHandle s = c->input_handle_shape(0);
|
||||
ShapeAndType handle_shape_and_type;
|
||||
TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
|
||||
|
||||
ShapeHandle value_shape = c->input(1);
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(handle_shape_and_type.shape, value_shape, &unused));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -220,18 +224,16 @@ REGISTER_OP("ResourceGather")
|
||||
.Attr("dtype: type")
|
||||
.Attr("Tindices: {int32,int64}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &dtype));
|
||||
if (c->input_handle_dtype(0) != dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to gather from a variable with the wrong dtype.");
|
||||
}
|
||||
ShapeAndType handle_shape_and_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateVariableResourceHandle(c, &handle_shape_and_type));
|
||||
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRankAtLeast(c->input_handle_shape(0), 1, &unused));
|
||||
c->WithRankAtLeast(handle_shape_and_type.shape, 1, &unused));
|
||||
ShapeHandle params_subshape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Subshape(c->input_handle_shape(0), 1, ¶ms_subshape));
|
||||
c->Subshape(handle_shape_and_type.shape, 1, ¶ms_subshape));
|
||||
ShapeHandle indices_shape = c->input(1);
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
|
||||
@ -264,7 +266,10 @@ REGISTER_OP("ResourceScatterAdd")
|
||||
.Attr("dtype: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle var_shape = c->input_handle_shape(0);
|
||||
ShapeAndType handle_shape_and_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateVariableResourceHandle(c, &handle_shape_and_type));
|
||||
ShapeHandle var_shape = handle_shape_and_type.shape;
|
||||
ShapeHandle indices_shape = c->input(1);
|
||||
|
||||
ShapeHandle unused_updates_shape;
|
||||
|
@ -23,11 +23,12 @@ using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
|
||||
auto h_dtype = c->input_handle_dtype(input);
|
||||
if (h_dtype == DT_INVALID) {
|
||||
return c->input(input);
|
||||
auto* handle_data = c->input_handle_shapes_and_types(input);
|
||||
if (handle_data != nullptr && !handle_data->empty() &&
|
||||
(*handle_data)[0].dtype != DT_INVALID) {
|
||||
return (*handle_data)[0].shape;
|
||||
}
|
||||
return c->input_handle_shape(input);
|
||||
return c->input(input);
|
||||
}
|
||||
|
||||
// Handle the gradient and, if <sparse>, indices inputs.
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import six.moves
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||
from tensorflow.python.framework import errors
|
||||
@ -597,8 +596,7 @@ def call_cpp_shape_fn(op,
|
||||
# calls the C / C-API directly, we should be able to remove this.
|
||||
return {
|
||||
"shapes": [tensor_shape.TensorShape(op.get_attr("value").tensor_shape)],
|
||||
"handle_shapes": [tensor_shape.TensorShape(None).as_proto()],
|
||||
"handle_dtypes": [types_pb2.DT_INVALID]
|
||||
"handle_data": [None]
|
||||
}
|
||||
|
||||
input_tensors_needed = input_tensors_needed or []
|
||||
@ -642,8 +640,8 @@ def _call_cpp_shape_fn_impl(
|
||||
r = cpp_shape_inference_pb2.CppShapeInferenceResult()
|
||||
r.shape.CopyFrom(t.get_shape().as_proto())
|
||||
# pylint: disable=protected-access
|
||||
r.handle_shape.CopyFrom(t._handle_shape)
|
||||
r.handle_dtype = t._handle_dtype
|
||||
if t._handle_data is not None:
|
||||
r.handle_data.CopyFrom(t._handle_data)
|
||||
# pylint: enable=protected-access
|
||||
return r.SerializeToString()
|
||||
input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
|
||||
@ -689,8 +687,9 @@ def _call_cpp_shape_fn_impl(
|
||||
for s in output_shapes
|
||||
]
|
||||
result = [r.shape for r in result_protos]
|
||||
result_handle_shapes = [r.handle_shape for r in result_protos]
|
||||
result_handle_dtypes = [r.handle_dtype for r in result_protos]
|
||||
result_handle_data = [
|
||||
r.handle_data if r.handle_data.is_set else None for r in result_protos
|
||||
]
|
||||
|
||||
if debug_python_shape_fn:
|
||||
try:
|
||||
@ -708,10 +707,11 @@ def _call_cpp_shape_fn_impl(
|
||||
str(result_as_shapes), str(python_result), str(op.node_def),
|
||||
",".join([str(i.get_shape()) for i in op.inputs])))
|
||||
|
||||
return {"shapes": result,
|
||||
"handle_shapes": result_handle_shapes,
|
||||
"handle_dtypes": result_handle_dtypes,
|
||||
"inputs_needed": output[-1]}
|
||||
return {
|
||||
"shapes": result,
|
||||
"handle_data": result_handle_data,
|
||||
"inputs_needed": output[-1]
|
||||
}
|
||||
|
||||
# pylint: disable=protected-access
|
||||
ops._set_call_cpp_shape_fn(call_cpp_shape_fn)
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace swig {
|
||||
namespace {
|
||||
|
||||
@ -71,11 +72,11 @@ Status RunCppShapeInferenceImpl(
|
||||
|
||||
// Convert input shapes.
|
||||
std::vector<TensorShapeProto> input_shapes;
|
||||
std::vector<TensorShapeProto> input_handle_shapes;
|
||||
std::vector<DataType> input_handle_dtypes;
|
||||
std::vector<
|
||||
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
|
||||
input_handle_shapes_and_types;
|
||||
input_shapes.resize(input_serialized_shapes.size());
|
||||
input_handle_shapes.resize(input_serialized_shapes.size());
|
||||
input_handle_dtypes.resize(input_serialized_shapes.size());
|
||||
input_handle_shapes_and_types.resize(input_serialized_shapes.size());
|
||||
CppShapeInferenceResult tmp;
|
||||
for (int i = 0; i < input_serialized_shapes.size(); ++i) {
|
||||
tmp.Clear();
|
||||
@ -83,9 +84,17 @@ Status RunCppShapeInferenceImpl(
|
||||
return errors::InvalidArgument(
|
||||
"Error parsing shape proto during cpp shape inference");
|
||||
}
|
||||
|
||||
input_shapes[i].Swap(tmp.mutable_shape());
|
||||
input_handle_dtypes[i] = tmp.handle_dtype();
|
||||
input_handle_shapes[i].Swap(tmp.mutable_handle_shape());
|
||||
|
||||
if (tmp.handle_data().is_set()) {
|
||||
input_handle_shapes_and_types[i].reset(
|
||||
new std::vector<std::pair<TensorShapeProto, DataType>>);
|
||||
auto& v = *input_handle_shapes_and_types[i];
|
||||
for (const auto& x : tmp.handle_data().shape_and_type()) {
|
||||
v.emplace_back(x.shape(), x.dtype());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert input tensor values;
|
||||
@ -116,8 +125,8 @@ Status RunCppShapeInferenceImpl(
|
||||
// Run shape inference.
|
||||
tensorflow::shape_inference::InferenceContext c(
|
||||
graph_def_version, &node, op_reg_data->op_def, input_shapes,
|
||||
input_tensors, input_tensor_as_shapes_protos, input_handle_shapes,
|
||||
input_handle_dtypes);
|
||||
input_tensors, input_tensor_as_shapes_protos,
|
||||
input_handle_shapes_and_types);
|
||||
TF_RETURN_IF_ERROR(c.construction_status());
|
||||
|
||||
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
|
||||
@ -128,9 +137,18 @@ Status RunCppShapeInferenceImpl(
|
||||
for (int i = 0; i < c.num_outputs(); ++i) {
|
||||
out.Clear();
|
||||
ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
|
||||
ProtoFromShapeHandle(c.output_handle_shape(i), &c,
|
||||
out.mutable_handle_shape());
|
||||
out.set_handle_dtype(c.output_handle_dtype(i));
|
||||
|
||||
const auto* shapes_and_types = c.output_handle_shapes_and_types(i);
|
||||
if (shapes_and_types != nullptr) {
|
||||
auto* out_handle_data = out.mutable_handle_data();
|
||||
out_handle_data->set_is_set(true);
|
||||
for (const auto& p : *shapes_and_types) {
|
||||
auto* out_shape_and_type = out_handle_data->add_shape_and_type();
|
||||
ProtoFromShapeHandle(p.shape, &c, out_shape_and_type->mutable_shape());
|
||||
out_shape_and_type->set_dtype(p.dtype);
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
|
||||
}
|
||||
|
||||
|
@ -7,9 +7,21 @@ import "tensorflow/core/framework/types.proto";
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
|
||||
message CppShapeInferenceResult {
|
||||
message HandleShapeAndType {
|
||||
TensorShapeProto shape = 1;
|
||||
DataType dtype = 2;
|
||||
}
|
||||
message HandleData {
|
||||
bool is_set = 1;
|
||||
|
||||
// Only valid if <is_set>.
|
||||
repeated HandleShapeAndType shape_and_type = 2;
|
||||
}
|
||||
TensorShapeProto shape = 1;
|
||||
TensorShapeProto handle_shape = 2;
|
||||
DataType handle_dtype = 3;
|
||||
|
||||
reserved 2; // was handle_shape
|
||||
reserved 3; // was handle_dtype
|
||||
HandleData handle_data = 4;
|
||||
}
|
||||
|
||||
message CppShapeInferenceInputsNeeded {
|
||||
|
@ -351,8 +351,7 @@ class _FuncGraph(ops.Graph):
|
||||
self.extra_inputs.append(x)
|
||||
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
|
||||
# pylint: disable=protected-access
|
||||
ph._handle_shape = x._handle_shape
|
||||
ph._handle_dtype = x._handle_dtype
|
||||
ph._handle_data = x._handle_data
|
||||
# pylint: enable=protected-access
|
||||
inputs[i] = ph
|
||||
self._captured[x] = ph
|
||||
|
@ -31,8 +31,6 @@ from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.framework import versions_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as c_api
|
||||
from tensorflow.python.framework import device as pydev
|
||||
@ -323,8 +321,8 @@ class Tensor(_TensorLike):
|
||||
self._consumers = []
|
||||
|
||||
# Attributes used for C++ shape inference. Not inspected, only forwarded.
|
||||
self._handle_shape = tensor_shape_pb2.TensorShapeProto()
|
||||
self._handle_dtype = types_pb2.DT_INVALID
|
||||
# If set, will be a HandleData object from cpp_shape_inference.proto.
|
||||
self._handle_data = None
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
@ -1877,12 +1875,10 @@ def set_shapes_for_outputs(op):
|
||||
# Returned by call_cpp_shape_fn
|
||||
shapes_dict = shapes
|
||||
shapes = shapes_dict["shapes"]
|
||||
handle_shapes = shapes_dict["handle_shapes"]
|
||||
handle_dtypes = shapes_dict["handle_dtypes"]
|
||||
for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes):
|
||||
handle_datas = shapes_dict["handle_data"]
|
||||
for output, handle_data in zip(op.outputs, handle_datas):
|
||||
# pylint: disable=protected-access
|
||||
output._handle_shape = handle_shape
|
||||
output._handle_dtype = handle_dtype
|
||||
output._handle_data = handle_data
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if len(op.outputs) != len(shapes):
|
||||
|
@ -237,6 +237,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
|
||||
_ = x_read.eval()
|
||||
|
||||
def testShape(self):
|
||||
with self.test_session():
|
||||
v = resource_variable_ops.ResourceVariable(
|
||||
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
|
||||
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
|
||||
self.assertEqual("(10, 20, 35)", str(v.value().shape))
|
||||
self.assertEqual("(3, 20, 35)", str(v.sparse_read([0, 1, 2]).shape))
|
||||
self.assertEqual(
|
||||
"<unknown>",
|
||||
str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
|
||||
|
||||
def testSetInitialValue(self):
|
||||
with self.test_session():
|
||||
# Initialize variable with a value different from the initial value passed
|
||||
|
Loading…
Reference in New Issue
Block a user