Changed InferenceContext ctor to accept NodeDef as a const ref

It has been a const ref once, but cl/170078811 made it a non-const pointer
to allow resetting it in ShapeRefiner::InferShapesForFunction. This code
path is no longer used.

PiperOrigin-RevId: 273381405
This commit is contained in:
Sergei Lebedev 2019-10-07 14:49:17 -07:00 committed by TensorFlower Gardener
parent 05a5da9097
commit a6ac9040dd
14 changed files with 126 additions and 134 deletions

View File

@ -1111,7 +1111,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
}
// Create an inference context with dummy values, which will be updated later.
InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
{},
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

View File

@ -114,7 +114,7 @@ TEST(BitcastOpTest, TestShapeInference_LargerShape) {
.Attr("T", DT_INT64)
.Input(FakeInput(DT_INT64))
.Finalize(&def));
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_CHECK_OK(c.input("input", &input_shapes));
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
@ -132,7 +132,7 @@ TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
.Attr("T", DT_INT8)
.Input(FakeInput(DT_INT8))
.Finalize(&def));
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
shape_inference::InferenceContext c(0, def, op_def, {S({3, 4, 8})}, {}, {},
{});
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_CHECK_OK(c.input("input", &input_shapes));
@ -151,7 +151,7 @@ TEST(BitcastOpTest, TestShapeInference_SameShape) {
.Attr("T", DT_FLOAT)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def));
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
std::vector<shape_inference::ShapeHandle> input_shapes;
TF_CHECK_OK(c.input("input", &input_shapes));
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));

View File

@ -196,7 +196,7 @@ PartialTensorShape Unknown() { return PartialTensorShape(); }
TEST(OpsTest, ShapeInferenceWithRank) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
{S({10, 20, 30})}, {}, {}, {});
shape_inference::ShapeHandle in0 = c.input(0);
@ -236,7 +236,7 @@ TEST(OpsTest, ShapeInferenceWithRank) {
TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2),
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {});
shape_inference::ShapeHandle in0 = c.input(0);
@ -260,7 +260,7 @@ TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
TEST(OpsTest, ShapeInferenceConcatenateShapes) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
{S({1, 2}), S({3, 4})}, {}, {}, {});
ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
shape_inference::ShapeHandle a = c.input(0);
@ -279,7 +279,7 @@ TEST(OpsTest, ShapeInferenceConcatenateShapes) {
TEST(OpsTest, DimensionHandleValueKnown) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
{S({1, 2}), S({3, 4})}, {}, {}, {});
TF_ShapeHandle* handle =
TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
@ -299,7 +299,7 @@ TEST(OpsTest, DimensionHandleValueKnown) {
TEST(OpsTest, ShapeInferenceSubshape) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
{S({10, 20, 30, 40, 50})}, {}, {}, {});
ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));

View File

@ -115,7 +115,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
// shapes. This object is abstracting the information that the ShapeInference
// function operates on.
tensorflow::shape_inference::InferenceContext c(
graph_version, node_def.get(), op_reg_data->op_def, input_shapes,
graph_version, *node_def, op_reg_data->op_def, input_shapes,
/*input_tensors=*/{}, /*input_tensors_as_shapes=*/{},
/*input_handle_shapes_and_types=*/{});
auto status = c.Run(op_reg_data->shape_inference_fn);

View File

@ -37,7 +37,7 @@ Status RunShapeInference(const NodeDef& ndef,
if (op_reg_data->shape_inference_fn == nullptr) return Status::OK();
shape_inference::InferenceContext ic(
TF_GRAPH_DEF_VERSION, &ndef, op_reg_data->op_def,
TF_GRAPH_DEF_VERSION, ndef, op_reg_data->op_def,
std::vector<shape_inference::ShapeHandle>(inputs.size()), {}, {}, {});
for (size_t i = 0; i < inputs.size(); i++) {
shape_inference::ShapeHandle shape;

View File

@ -191,7 +191,7 @@ Status ShapeRefiner::InferShapesForFunction(
Status ShapeRefiner::AddNode(const Node* node) {
// Create the inference context for this node with the existing input shapes.
std::unique_ptr<InferenceContext> ic(new InferenceContext(
graph_def_version_, &node->def(), node->op_def(),
graph_def_version_, node->def(), node->op_def(),
std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
TF_RETURN_IF_ERROR(ic->construction_status());

View File

@ -63,7 +63,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({}), S({10})}, {},
{}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
@ -82,15 +82,15 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({})}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({1, 23, 4, 4, 2})}, {}, {}, {});
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({1, 23, 4, 4, 2})},
{}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, 3}), S({3, 4})}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
@ -127,7 +127,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, -1}), S({3, 4})}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
@ -137,7 +137,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2}), S({3, 4})},
{}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
@ -147,7 +147,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, 3}), S({3, -1})}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
@ -157,7 +157,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, 5}), S({3, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
@ -168,7 +168,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
@ -186,7 +186,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({3, 2}), S({3, 4})}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
@ -204,7 +204,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, 3}), S({4, 3})}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
@ -420,8 +420,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
{S({2, 10}), S({10})}, {}, {}, {});
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10}), S({10})},
{}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -430,7 +430,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{Unknown(), Unknown()}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
@ -439,7 +439,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
@ -453,7 +453,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({2, 3, 4, 5}), S({3})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
@ -467,7 +467,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@ -479,7 +479,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({10, 11, 12}), S({11})}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
@ -488,7 +488,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3}), S({3})}, {},
{}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@ -501,7 +501,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({3})},
{}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@ -548,7 +548,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10})}, {}, {},
{});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
@ -557,7 +557,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({5, 7, 2, 10})},
{}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
@ -570,8 +570,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
{}, {}, {});
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3, 4, 5})}, {},
{}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@ -583,7 +583,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
@ -596,8 +596,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
{}, {}, {});
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({10, 11, 12})}, {},
{}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(11, c.Value(c.Dim(output, 0)));
@ -605,8 +605,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
{});
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3})}, {}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@ -617,7 +616,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3})}, {}, {},
{});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@ -1353,7 +1352,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{Unknown(), Unknown(), Unknown()}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1366,7 +1365,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1379,7 +1378,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({-1}), S({-1}), S({-1})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1393,7 +1392,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({5, 3}), S({4}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1407,7 +1406,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({5, 3}), S({5}), S({4})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1421,7 +1420,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({-1, 3}), S({5}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1434,7 +1433,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({5, 3}), S({-1}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1447,7 +1446,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({5, -1}), S({5}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1460,7 +1459,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({5, 3}), S({5}), S({-1})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1473,7 +1472,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
NodeDef def;
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
{S({5, 3}), S({5}), S({3})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());

View File

@ -31,15 +31,14 @@ constexpr int64 InferenceContext::kUnknownDim;
// Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext::InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
const std::vector<PartialTensorShape>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
const std::vector<
std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
node_def_(CHECK_NOTNULL(node_def)) {
: graph_def_version_(graph_def_version), node_def_(node_def) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
for (const PartialTensorShape& p : input_tensors_as_shapes) {
@ -84,14 +83,13 @@ InferenceContext::InferenceContext(
}
InferenceContext::InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes,
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
node_def_(CHECK_NOTNULL(node_def)) {
: graph_def_version_(graph_def_version), node_def_(node_def) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
@ -112,7 +110,7 @@ Status InferenceContext::Run(
#ifndef NDEBUG
for (int i = 0; i < num_outputs(); ++i) {
DCHECK(output(i).IsSet())
<< i << " for " << node_def_->name() << " of type " << node_def_->op();
<< i << " for " << node_def_.name() << " of type " << node_def_.op();
}
#endif // NDEBUG
return s;
@ -171,8 +169,8 @@ void InferenceContext::PreInputInit(
input_tensors_ = input_tensors;
input_tensors_as_shapes_ = input_tensors_as_shapes;
construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
&output_name_map_);
construction_status_ =
NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_);
if (!construction_status_.ok()) return;
int num_outputs = 0;
@ -290,7 +288,7 @@ string InferenceContext::DebugString(DimensionHandle d) {
string InferenceContext::DebugString() const {
return strings::StrCat("InferenceContext for node: ",
node_def_->DebugString());
node_def_.DebugString());
}
string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
@ -1119,7 +1117,7 @@ Status InferenceContext::AttachContext(const Status& status) {
}
string error_context = strings::StrCat(
" for '", node_def_->name(), "' (op: '", node_def_->op(),
" for '", node_def_.name(), "' (op: '", node_def_.op(),
"') with input shapes: ", absl::StrJoin(input_shapes, ", "));
if (!input_from_tensors_str.empty()) {
strings::StrAppend(&error_context, " and with computed input tensors: ",

View File

@ -161,9 +161,7 @@ class InferenceContext {
// known from analysis of the graph.
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
// Values of <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
InferenceContext(int graph_def_version, const NodeDef* node_def,
InferenceContext(int graph_def_version, const NodeDef& node_def,
const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
@ -179,11 +177,8 @@ class InferenceContext {
// partially known from analysis of the graph. <input_tensors_as_shapes>
// can have fewer elements than <input_shapes>. Values of
// <input_tensors_as_shapes> do not need to outlive the context.
//
// REQUIRES: <node_def> is not NULL, and must outlive the
// InferenceContext.
InferenceContext(
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
const std::vector<PartialTensorShape>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
@ -306,7 +301,7 @@ class InferenceContext {
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
AttrSlice attrs() const { return AttrSlice(*node_def_); }
AttrSlice attrs() const { return AttrSlice(node_def_); }
// idx can be negative for an offset from end of dimensions.
// idx must be in the range [-1 * s.rank, s.rank).
@ -737,7 +732,7 @@ class InferenceContext {
output_handle_shapes_and_types_;
const int graph_def_version_;
const NodeDef* node_def_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;
@ -784,7 +779,7 @@ inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
template <class T>
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
return GetNodeAttr(*node_def_, attr_name, value);
return GetNodeAttr(node_def_, attr_name, value);
}
} // namespace shape_inference

View File

@ -77,7 +77,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Attr("N", 3)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
{}, {}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
@ -114,7 +114,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@ -129,7 +129,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
TF_ASSERT_OK(c.construction_status());
{
@ -167,7 +167,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
def.set_op("foo_op");
// Error when no constant tensors were requested.
{
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
{});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
@ -186,7 +186,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
{
Tensor input_t =
::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
InferenceContext c(kVersion, def, MakeOpDef(2, 2),
{S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
@ -208,7 +208,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
// shapes provided.
{
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
{nullptr, &input_t}, {}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
@ -231,7 +231,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
// shape was provided.
{
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
TF_ASSERT_OK(c.construction_status());
auto fn = [](InferenceContext* c) {
@ -254,7 +254,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
InferenceContext c(kVersion, def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@ -295,7 +295,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
InferenceContext c(kVersion, def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
@ -309,8 +309,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
{}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -348,8 +348,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
{}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -386,8 +386,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
{}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -424,7 +424,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@ -467,8 +467,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
{}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {},
{}, {});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@ -530,7 +530,7 @@ TEST_F(ShapeInferenceTest, MergeDim) {
TEST_F(ShapeInferenceTest, RelaxDim) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2),
InferenceContext c(kVersion, def, MakeOpDef(1, 2),
{S({2, InferenceContext::kUnknownDim, 2, 1,
InferenceContext::kUnknownDim})},
{}, {}, {});
@ -578,7 +578,7 @@ TEST_F(ShapeInferenceTest, RelaxDim) {
TEST_F(ShapeInferenceTest, RelaxShape) {
NodeDef def;
InferenceContext c(
kVersion, &def, MakeOpDef(7, 2),
kVersion, def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
{}, {}, {});
@ -647,7 +647,7 @@ TEST_F(ShapeInferenceTest, RelaxShape) {
TEST_F(ShapeInferenceTest, MergeShape) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
InferenceContext c(kVersion, def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
{}, {}, {});
@ -753,7 +753,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
TEST_F(ShapeInferenceTest, MergePrefix) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
InferenceContext c(kVersion, def, MakeOpDef(4, 2),
{
Unknown(),
S({-1, 2}),
@ -808,7 +808,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
InferenceContext c(kVersion, def, MakeOpDef(2, 2),
{S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
ShapeHandle unknown = c.input(1);
@ -880,7 +880,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
InferenceContext c(kVersion, def, MakeOpDef(3, 2),
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});
auto in0 = c.input(0);
@ -907,7 +907,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
{}, {}, {});
auto in = c.input(0);
@ -939,7 +939,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
{}, {});
std::vector<DimensionHandle> dims;
@ -966,7 +966,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@ -978,7 +978,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s = c.MakeShape({1, 2, 3});
TensorShapeProto proto;
@ -992,7 +992,7 @@ TEST_F(ShapeInferenceTest, KnownShapeToProto) {
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto u0 = c.UnknownShape();
TensorShapeProto proto;
@ -1005,7 +1005,7 @@ TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@ -1016,7 +1016,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@ -1032,7 +1032,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@ -1054,7 +1054,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
{});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
@ -1115,7 +1115,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
// Test when the input shape is wrong.
{
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
{}, {});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
@ -1126,7 +1126,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
// With an unknown rank.
ShapeHandle out;
@ -1145,7 +1145,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
ShapeHandle out;
TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
@ -1159,7 +1159,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@ -1195,7 +1195,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@ -1209,7 +1209,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@ -1221,7 +1221,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@ -1234,7 +1234,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
{&t1, &t2}, {}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
@ -1246,8 +1246,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
{&t1, &t2}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2},
{}, {});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@ -1280,7 +1280,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@ -1288,7 +1288,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
{}, {});
auto s = c.input(0);
@ -1351,7 +1351,7 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
{});
auto s = c.input(0);
@ -1401,8 +1401,8 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
{}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {},
{});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1451,8 +1451,8 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
{}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {},
{});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1505,7 +1505,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@ -1518,8 +1518,8 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
{}, {});
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {},
{});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@ -1567,7 +1567,7 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
{});
auto s = c.input(0);
@ -1605,7 +1605,7 @@ TEST_F(ShapeInferenceTest, Max) {
void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
{});
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
ShapeHandle s;
@ -1716,7 +1716,7 @@ TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
{});
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
ShapeHandle s;

View File

@ -60,7 +60,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
}
shape_inference::InferenceContext c(
op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes,
op.graph_def_version, op.node_def, op_reg_data->op_def, in_shapes,
op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {

View File

@ -1168,7 +1168,7 @@ class SymbolicShapeRefiner {
std::vector<ShapeHandle> input_tensors_as_shapes;
node_ctx.inference_context.reset(new InferenceContext(
graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
input_tensors, input_tensors_as_shapes,
std::move(input_handle_shapes_and_types)));
const Status s = node_ctx.inference_context->construction_status();

View File

@ -198,7 +198,7 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
new std::vector<std::pair<PartialTensorShape, DataType>>(
{{PartialTensorShape(), DT_BOOL}}));
shape_inference::InferenceContext c(
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
{PartialTensorShape()}, {}, {}, handle_data);
TF_ASSERT_OK(c.construction_status());
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);

View File

@ -229,7 +229,7 @@ TEST(MathOpsTest, Select_ShapeFn) {
auto run_inference_for_handles = [&]() -> Status {
CHECK(op_reg_data->shape_inference_fn != nullptr);
c.reset(new shape_inference::InferenceContext(
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
{PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {},
{}, handle_data));
TF_CHECK_OK(c->construction_status());