Changed InferenceContext ctor to accept NodeDef as a const ref
It has been a const ref once, but cl/170078811 made it a non-const pointer to allow resetting it in ShapeRefiner::InferShapesForFunction. This code path is no longer used. PiperOrigin-RevId: 273381405
This commit is contained in:
parent
05a5da9097
commit
a6ac9040dd
@ -1111,7 +1111,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
}
|
||||
|
||||
// Create an inference context with dummy values, which will be updated later.
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
|
||||
std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
|
||||
{},
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
|
||||
|
@ -114,7 +114,7 @@ TEST(BitcastOpTest, TestShapeInference_LargerShape) {
|
||||
.Attr("T", DT_INT64)
|
||||
.Input(FakeInput(DT_INT64))
|
||||
.Finalize(&def));
|
||||
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
|
||||
shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
|
||||
std::vector<shape_inference::ShapeHandle> input_shapes;
|
||||
TF_CHECK_OK(c.input("input", &input_shapes));
|
||||
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
|
||||
@ -132,7 +132,7 @@ TEST(BitcastOpTest, TestShapeInference_SmallerShape) {
|
||||
.Attr("T", DT_INT8)
|
||||
.Input(FakeInput(DT_INT8))
|
||||
.Finalize(&def));
|
||||
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4, 8})}, {}, {},
|
||||
shape_inference::InferenceContext c(0, def, op_def, {S({3, 4, 8})}, {}, {},
|
||||
{});
|
||||
std::vector<shape_inference::ShapeHandle> input_shapes;
|
||||
TF_CHECK_OK(c.input("input", &input_shapes));
|
||||
@ -151,7 +151,7 @@ TEST(BitcastOpTest, TestShapeInference_SameShape) {
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(&def));
|
||||
shape_inference::InferenceContext c(0, &def, op_def, {S({3, 4})}, {}, {}, {});
|
||||
shape_inference::InferenceContext c(0, def, op_def, {S({3, 4})}, {}, {}, {});
|
||||
std::vector<shape_inference::ShapeHandle> input_shapes;
|
||||
TF_CHECK_OK(c.input("input", &input_shapes));
|
||||
ASSERT_EQ("[3,4]", c.DebugString(input_shapes[0]));
|
||||
|
@ -196,7 +196,7 @@ PartialTensorShape Unknown() { return PartialTensorShape(); }
|
||||
|
||||
TEST(OpsTest, ShapeInferenceWithRank) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
|
||||
shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
|
||||
{S({10, 20, 30})}, {}, {}, {});
|
||||
|
||||
shape_inference::ShapeHandle in0 = c.input(0);
|
||||
@ -236,7 +236,7 @@ TEST(OpsTest, ShapeInferenceWithRank) {
|
||||
|
||||
TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2),
|
||||
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
|
||||
shape_inference::ShapeHandle in0 = c.input(0);
|
||||
@ -260,7 +260,7 @@ TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
|
||||
|
||||
TEST(OpsTest, ShapeInferenceConcatenateShapes) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
|
||||
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
|
||||
{S({1, 2}), S({3, 4})}, {}, {}, {});
|
||||
ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
|
||||
shape_inference::ShapeHandle a = c.input(0);
|
||||
@ -279,7 +279,7 @@ TEST(OpsTest, ShapeInferenceConcatenateShapes) {
|
||||
|
||||
TEST(OpsTest, DimensionHandleValueKnown) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
|
||||
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
|
||||
{S({1, 2}), S({3, 4})}, {}, {}, {});
|
||||
TF_ShapeHandle* handle =
|
||||
TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
|
||||
@ -299,7 +299,7 @@ TEST(OpsTest, DimensionHandleValueKnown) {
|
||||
|
||||
TEST(OpsTest, ShapeInferenceSubshape) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
|
||||
shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
|
||||
{S({10, 20, 30, 40, 50})}, {}, {}, {});
|
||||
ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));
|
||||
|
||||
|
@ -115,7 +115,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
||||
// shapes. This object is abstracting the information that the ShapeInference
|
||||
// function operates on.
|
||||
tensorflow::shape_inference::InferenceContext c(
|
||||
graph_version, node_def.get(), op_reg_data->op_def, input_shapes,
|
||||
graph_version, *node_def, op_reg_data->op_def, input_shapes,
|
||||
/*input_tensors=*/{}, /*input_tensors_as_shapes=*/{},
|
||||
/*input_handle_shapes_and_types=*/{});
|
||||
auto status = c.Run(op_reg_data->shape_inference_fn);
|
||||
|
@ -37,7 +37,7 @@ Status RunShapeInference(const NodeDef& ndef,
|
||||
if (op_reg_data->shape_inference_fn == nullptr) return Status::OK();
|
||||
|
||||
shape_inference::InferenceContext ic(
|
||||
TF_GRAPH_DEF_VERSION, &ndef, op_reg_data->op_def,
|
||||
TF_GRAPH_DEF_VERSION, ndef, op_reg_data->op_def,
|
||||
std::vector<shape_inference::ShapeHandle>(inputs.size()), {}, {}, {});
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
shape_inference::ShapeHandle shape;
|
||||
|
@ -191,7 +191,7 @@ Status ShapeRefiner::InferShapesForFunction(
|
||||
Status ShapeRefiner::AddNode(const Node* node) {
|
||||
// Create the inference context for this node with the existing input shapes.
|
||||
std::unique_ptr<InferenceContext> ic(new InferenceContext(
|
||||
graph_def_version_, &node->def(), node->op_def(),
|
||||
graph_def_version_, node->def(), node->op_def(),
|
||||
std::vector<ShapeHandle>(node->num_inputs()), {}, {}, {}));
|
||||
TF_RETURN_IF_ERROR(ic->construction_status());
|
||||
|
||||
|
@ -63,7 +63,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
||||
.Input({{"data", 0, DT_FLOAT}})
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({}), S({10})}, {},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({}), S({10})}, {},
|
||||
{}, {});
|
||||
TF_EXPECT_OK(NoOutputs(&c));
|
||||
EXPECT_EQ(0, c.num_outputs());
|
||||
@ -82,15 +82,15 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
|
||||
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({})}, {}, {}, {});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({})}, {}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
}
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({1, 23, 4, 4, 2})}, {}, {}, {});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({1, 23, 4, 4, 2})},
|
||||
{}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, 3}), S({3, 4})}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -127,7 +127,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown inner dimension for one
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, -1}), S({3, 4})}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -137,7 +137,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Invalid rank.
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2}), S({3, 4})},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2}), S({3, 4})},
|
||||
{}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
@ -147,7 +147,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown outer dimension
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, 3}), S({3, -1})}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -157,7 +157,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, 5}), S({3, 4})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
@ -168,7 +168,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
@ -186,7 +186,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Attr("type", DT_FLOAT)
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({3, 2}), S({3, 4})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -204,7 +204,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Attr("type", DT_FLOAT)
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, 3}), S({4, 3})}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -420,8 +420,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
{S({2, 10}), S({10})}, {}, {}, {});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10}), S({10})},
|
||||
{}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -430,7 +430,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown ranks.
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{Unknown(), Unknown()}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -439,7 +439,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -453,7 +453,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({2, 3, 4, 5}), S({3})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -467,7 +467,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
@ -479,7 +479,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({10, 11, 12}), S({11})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -488,7 +488,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3}), S({3})}, {},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3}), S({3})}, {},
|
||||
{}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
@ -501,7 +501,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3}), S({3})},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3}), S({3})},
|
||||
{}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
@ -548,7 +548,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 10})}, {}, {},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 10})}, {}, {},
|
||||
{});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -557,7 +557,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({5, 7, 2, 10})},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({5, 7, 2, 10})},
|
||||
{}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -570,8 +570,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3, 4, 5})},
|
||||
{}, {}, {});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3, 4, 5})}, {},
|
||||
{}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -583,7 +583,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def,
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def,
|
||||
{S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
@ -596,8 +596,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({10, 11, 12})},
|
||||
{}, {}, {});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({10, 11, 12})}, {},
|
||||
{}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(11, c.Value(c.Dim(output, 0)));
|
||||
@ -605,8 +605,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({3})}, {}, {},
|
||||
{});
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({3})}, {}, {}, {});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
|
||||
@ -617,7 +616,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, op_def, {S({2, 3})}, {}, {},
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, op_def, {S({2, 3})}, {}, {},
|
||||
{});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
@ -1353,7 +1352,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{Unknown(), Unknown(), Unknown()}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1366,7 +1365,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapes) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({-1, -1}), S({-1}), S({-1})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1379,7 +1378,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownDims) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({-1}), S({-1}), S({-1})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1393,7 +1392,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({4}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1407,7 +1406,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({5}), S({4})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1421,7 +1420,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_InvalidRank) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({-1, 3}), S({5}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1434,7 +1433,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({-1}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1447,7 +1446,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({5, -1}), S({5}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1460,7 +1459,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({5}), S({-1})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
@ -1473,7 +1472,7 @@ TEST(CommonShapeFnsTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
|
||||
TEST(CommonShapeFnsTest, ValidateSparseTensor) {
|
||||
NodeDef def;
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, &def, MakeOpDef(3, 1),
|
||||
InferenceContext c(TF_GRAPH_DEF_VERSION, def, MakeOpDef(3, 1),
|
||||
{S({5, 3}), S({5}), S({3})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
@ -31,15 +31,14 @@ constexpr int64 InferenceContext::kUnknownDim;
|
||||
|
||||
// Same as above, but with PartialTensorShape instead of TensorShapeProto
|
||||
InferenceContext::InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
|
||||
const std::vector<PartialTensorShape>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
|
||||
const std::vector<
|
||||
std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
|
||||
input_handle_shapes_and_types)
|
||||
: graph_def_version_(graph_def_version),
|
||||
node_def_(CHECK_NOTNULL(node_def)) {
|
||||
: graph_def_version_(graph_def_version), node_def_(node_def) {
|
||||
std::vector<ShapeHandle> input_tensors_as_shape_handles;
|
||||
input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
|
||||
for (const PartialTensorShape& p : input_tensors_as_shapes) {
|
||||
@ -84,14 +83,13 @@ InferenceContext::InferenceContext(
|
||||
}
|
||||
|
||||
InferenceContext::InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
|
||||
input_handle_shapes_and_types)
|
||||
: graph_def_version_(graph_def_version),
|
||||
node_def_(CHECK_NOTNULL(node_def)) {
|
||||
: graph_def_version_(graph_def_version), node_def_(node_def) {
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
|
||||
if (!construction_status_.ok()) return;
|
||||
inputs_ = input_shapes;
|
||||
@ -112,7 +110,7 @@ Status InferenceContext::Run(
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < num_outputs(); ++i) {
|
||||
DCHECK(output(i).IsSet())
|
||||
<< i << " for " << node_def_->name() << " of type " << node_def_->op();
|
||||
<< i << " for " << node_def_.name() << " of type " << node_def_.op();
|
||||
}
|
||||
#endif // NDEBUG
|
||||
return s;
|
||||
@ -171,8 +169,8 @@ void InferenceContext::PreInputInit(
|
||||
input_tensors_ = input_tensors;
|
||||
input_tensors_as_shapes_ = input_tensors_as_shapes;
|
||||
|
||||
construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
|
||||
&output_name_map_);
|
||||
construction_status_ =
|
||||
NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_);
|
||||
if (!construction_status_.ok()) return;
|
||||
|
||||
int num_outputs = 0;
|
||||
@ -290,7 +288,7 @@ string InferenceContext::DebugString(DimensionHandle d) {
|
||||
|
||||
string InferenceContext::DebugString() const {
|
||||
return strings::StrCat("InferenceContext for node: ",
|
||||
node_def_->DebugString());
|
||||
node_def_.DebugString());
|
||||
}
|
||||
|
||||
string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
|
||||
@ -1119,7 +1117,7 @@ Status InferenceContext::AttachContext(const Status& status) {
|
||||
}
|
||||
|
||||
string error_context = strings::StrCat(
|
||||
" for '", node_def_->name(), "' (op: '", node_def_->op(),
|
||||
" for '", node_def_.name(), "' (op: '", node_def_.op(),
|
||||
"') with input shapes: ", absl::StrJoin(input_shapes, ", "));
|
||||
if (!input_from_tensors_str.empty()) {
|
||||
strings::StrAppend(&error_context, " and with computed input tensors: ",
|
||||
|
@ -161,9 +161,7 @@ class InferenceContext {
|
||||
// known from analysis of the graph.
|
||||
// <input_tensors_as_shapes> can have fewer elements than <input_shapes>.
|
||||
// Values of <input_tensors_as_shapes> do not need to outlive the context.
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
|
||||
InferenceContext(int graph_def_version, const NodeDef* node_def,
|
||||
InferenceContext(int graph_def_version, const NodeDef& node_def,
|
||||
const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
@ -179,11 +177,8 @@ class InferenceContext {
|
||||
// partially known from analysis of the graph. <input_tensors_as_shapes>
|
||||
// can have fewer elements than <input_shapes>. Values of
|
||||
// <input_tensors_as_shapes> do not need to outlive the context.
|
||||
//
|
||||
// REQUIRES: <node_def> is not NULL, and must outlive the
|
||||
// InferenceContext.
|
||||
InferenceContext(
|
||||
int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
|
||||
int graph_def_version, const NodeDef& node_def, const OpDef& op_def,
|
||||
const std::vector<PartialTensorShape>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<PartialTensorShape>& input_tensors_as_shapes,
|
||||
@ -306,7 +301,7 @@ class InferenceContext {
|
||||
Status output(StringPiece output_name,
|
||||
std::vector<ShapeHandle>* output) const;
|
||||
|
||||
AttrSlice attrs() const { return AttrSlice(*node_def_); }
|
||||
AttrSlice attrs() const { return AttrSlice(node_def_); }
|
||||
|
||||
// idx can be negative for an offset from end of dimensions.
|
||||
// idx must be in the range [-1 * s.rank, s.rank).
|
||||
@ -737,7 +732,7 @@ class InferenceContext {
|
||||
output_handle_shapes_and_types_;
|
||||
|
||||
const int graph_def_version_;
|
||||
const NodeDef* node_def_;
|
||||
const NodeDef& node_def_;
|
||||
NameRangeMap input_name_map_;
|
||||
NameRangeMap output_name_map_;
|
||||
|
||||
@ -784,7 +779,7 @@ inline DimensionOrConstant::DimensionOrConstant(int64 val) : val(val) {
|
||||
|
||||
template <class T>
|
||||
Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
|
||||
return GetNodeAttr(*node_def_, attr_name, value);
|
||||
return GetNodeAttr(node_def_, attr_name, value);
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
|
@ -77,7 +77,7 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
|
||||
.Attr("N", 3)
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(&def);
|
||||
InferenceContext c(kVersion, &def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
|
||||
InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
|
||||
{}, {}, {});
|
||||
|
||||
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
|
||||
@ -114,7 +114,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {});
|
||||
EXPECT_EQ(InferenceContext::kUnknownDim,
|
||||
c.Value(InferenceContext::kUnknownDim));
|
||||
EXPECT_EQ(1, c.Value(1));
|
||||
@ -129,7 +129,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
||||
NodeDef def;
|
||||
def.set_name("foo");
|
||||
def.set_op("foo_op");
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
|
||||
{
|
||||
@ -167,7 +167,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
def.set_op("foo_op");
|
||||
// Error when no constant tensors were requested.
|
||||
{
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
|
||||
{});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
@ -186,7 +186,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
{
|
||||
Tensor input_t =
|
||||
::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2),
|
||||
{S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
@ -208,7 +208,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
// shapes provided.
|
||||
{
|
||||
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
|
||||
{nullptr, &input_t}, {}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
@ -231,7 +231,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
// shape was provided.
|
||||
{
|
||||
Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({3}), S({4})},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
|
||||
{nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
auto fn = [](InferenceContext* c) {
|
||||
@ -254,7 +254,7 @@ TEST_F(ShapeInferenceTest, AttachContext) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(3, 2),
|
||||
{Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(2, c.num_outputs());
|
||||
@ -295,7 +295,7 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, NumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(3, 2),
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});
|
||||
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
|
||||
@ -309,8 +309,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
|
||||
{}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -348,8 +348,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
|
||||
{}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -386,8 +386,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
|
||||
{}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -424,7 +424,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithValue) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
|
||||
|
||||
auto d0 = c.Dim(c.input(0), 0);
|
||||
auto d1 = c.Dim(c.input(0), 1);
|
||||
@ -467,8 +467,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergeDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})},
|
||||
{}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {},
|
||||
{}, {});
|
||||
|
||||
auto d2 = c.Dim(c.input(0), 0);
|
||||
auto d_unknown = c.Dim(c.input(0), 1);
|
||||
@ -530,7 +530,7 @@ TEST_F(ShapeInferenceTest, MergeDim) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, RelaxDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2),
|
||||
{S({2, InferenceContext::kUnknownDim, 2, 1,
|
||||
InferenceContext::kUnknownDim})},
|
||||
{}, {}, {});
|
||||
@ -578,7 +578,7 @@ TEST_F(ShapeInferenceTest, RelaxDim) {
|
||||
TEST_F(ShapeInferenceTest, RelaxShape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(
|
||||
kVersion, &def, MakeOpDef(7, 2),
|
||||
kVersion, def, MakeOpDef(7, 2),
|
||||
{Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
|
||||
S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
|
||||
{}, {}, {});
|
||||
@ -647,7 +647,7 @@ TEST_F(ShapeInferenceTest, RelaxShape) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergeShape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(7, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(7, 2),
|
||||
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
|
||||
Unknown(), S({1})},
|
||||
{}, {}, {});
|
||||
@ -753,7 +753,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(4, 2),
|
||||
{
|
||||
Unknown(),
|
||||
S({-1, 2}),
|
||||
@ -808,7 +808,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Subshape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2),
|
||||
{S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});
|
||||
|
||||
ShapeHandle unknown = c.input(1);
|
||||
@ -880,7 +880,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2),
|
||||
InferenceContext c(kVersion, def, MakeOpDef(3, 2),
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
@ -907,7 +907,7 @@ TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
|
||||
{}, {}, {});
|
||||
|
||||
auto in = c.input(0);
|
||||
@ -939,7 +939,7 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
|
||||
{}, {});
|
||||
|
||||
std::vector<DimensionHandle> dims;
|
||||
@ -966,7 +966,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto u0 = c.UnknownShape();
|
||||
auto u1 = c.UnknownShape();
|
||||
@ -978,7 +978,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s = c.MakeShape({1, 2, 3});
|
||||
TensorShapeProto proto;
|
||||
@ -992,7 +992,7 @@ TEST_F(ShapeInferenceTest, KnownShapeToProto) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto u0 = c.UnknownShape();
|
||||
TensorShapeProto proto;
|
||||
@ -1005,7 +1005,7 @@ TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
|
||||
TEST_F(ShapeInferenceTest, Scalar) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s0 = c.Scalar();
|
||||
EXPECT_EQ("[]", c.DebugString(s0));
|
||||
@ -1016,7 +1016,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
|
||||
TEST_F(ShapeInferenceTest, Vector) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s0 = c.Vector(1);
|
||||
EXPECT_EQ("[1]", c.DebugString(s0));
|
||||
@ -1032,7 +1032,7 @@ TEST_F(ShapeInferenceTest, Vector) {
|
||||
TEST_F(ShapeInferenceTest, Matrix) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto s0 = c.Matrix(1, 2);
|
||||
EXPECT_EQ("[1,2]", c.DebugString(s0));
|
||||
@ -1054,7 +1054,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
auto create = [&](Tensor* t) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
|
||||
{});
|
||||
ShapeHandle out;
|
||||
Status s = c.MakeShapeFromShapeTensor(0, &out);
|
||||
@ -1115,7 +1115,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
// Test when the input shape is wrong.
|
||||
{
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
|
||||
{}, {});
|
||||
ShapeHandle out;
|
||||
EXPECT_EQ("Shape must be rank 1 but is rank 2",
|
||||
@ -1126,7 +1126,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
// With an unknown rank.
|
||||
ShapeHandle out;
|
||||
@ -1145,7 +1145,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
ShapeHandle out;
|
||||
TF_ASSERT_OK(c.MakeShapeFromTensorShape(TensorShape(), &out));
|
||||
@ -1159,7 +1159,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
TensorShapeProto proto;
|
||||
|
||||
// With a set unknown rank.
|
||||
@ -1195,7 +1195,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto d0 = c.MakeDim(1);
|
||||
auto d1 = c.MakeDim(1);
|
||||
@ -1209,7 +1209,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto d0 = c.UnknownDim();
|
||||
auto d1 = c.UnknownDim();
|
||||
@ -1221,7 +1221,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
|
||||
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
|
||||
@ -1234,7 +1234,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
|
||||
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
|
||||
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
|
||||
{&t1, &t2}, {}, {});
|
||||
|
||||
EXPECT_TRUE(c.input_tensor(0) == &t1);
|
||||
@ -1246,8 +1246,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
|
||||
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
|
||||
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})},
|
||||
{&t1, &t2}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2},
|
||||
{}, {});
|
||||
|
||||
DimensionHandle d;
|
||||
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
|
||||
@ -1280,7 +1280,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
.ok());
|
||||
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, op_reg_data.op_def, empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});
|
||||
string value;
|
||||
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
|
||||
EXPECT_EQ("bar", value);
|
||||
@ -1288,7 +1288,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Divide) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
|
||||
{}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
@ -1351,7 +1351,7 @@ TEST_F(ShapeInferenceTest, Divide) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Add) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
@ -1401,8 +1401,8 @@ TEST_F(ShapeInferenceTest, Add) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Subtract) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {},
|
||||
{}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {},
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1451,8 +1451,8 @@ TEST_F(ShapeInferenceTest, Subtract) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Multiply) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {},
|
||||
{}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {},
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1505,7 +1505,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
|
||||
TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});
|
||||
|
||||
// No rank or missing dimension information should return false.
|
||||
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
|
||||
@ -1518,8 +1518,8 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Min) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {},
|
||||
{}, {});
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {},
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1567,7 +1567,7 @@ TEST_F(ShapeInferenceTest, Min) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Max) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
@ -1605,7 +1605,7 @@ TEST_F(ShapeInferenceTest, Max) {
|
||||
|
||||
void ShapeInferenceTest::TestMergeHandles(bool input_not_output) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
|
||||
{});
|
||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||
ShapeHandle s;
|
||||
@ -1716,7 +1716,7 @@ TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) {
|
||||
|
||||
void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) {
|
||||
NodeDef def;
|
||||
InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
|
||||
InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
|
||||
{});
|
||||
auto make_shape = [&c](std::initializer_list<int64> dim_sizes) {
|
||||
ShapeHandle s;
|
||||
|
@ -60,7 +60,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
||||
}
|
||||
}
|
||||
shape_inference::InferenceContext c(
|
||||
op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes,
|
||||
op.graph_def_version, op.node_def, op_reg_data->op_def, in_shapes,
|
||||
op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types));
|
||||
TF_RETURN_IF_ERROR(c.construction_status());
|
||||
if (op_reg_data->shape_inference_fn == nullptr) {
|
||||
|
@ -1168,7 +1168,7 @@ class SymbolicShapeRefiner {
|
||||
std::vector<ShapeHandle> input_tensors_as_shapes;
|
||||
|
||||
node_ctx.inference_context.reset(new InferenceContext(
|
||||
graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
|
||||
graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
|
||||
input_tensors, input_tensors_as_shapes,
|
||||
std::move(input_handle_shapes_and_types)));
|
||||
const Status s = node_ctx.inference_context->construction_status();
|
||||
|
@ -198,7 +198,7 @@ TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
|
||||
new std::vector<std::pair<PartialTensorShape, DataType>>(
|
||||
{{PartialTensorShape(), DT_BOOL}}));
|
||||
shape_inference::InferenceContext c(
|
||||
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
|
||||
TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
|
||||
{PartialTensorShape()}, {}, {}, handle_data);
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
|
||||
|
@ -229,7 +229,7 @@ TEST(MathOpsTest, Select_ShapeFn) {
|
||||
auto run_inference_for_handles = [&]() -> Status {
|
||||
CHECK(op_reg_data->shape_inference_fn != nullptr);
|
||||
c.reset(new shape_inference::InferenceContext(
|
||||
TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
|
||||
TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
|
||||
{PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {},
|
||||
{}, handle_data));
|
||||
TF_CHECK_OK(c->construction_status());
|
||||
|
Loading…
Reference in New Issue
Block a user