diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 65cfb1a90e6..5dfaa160d0c 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -173,17 +173,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { return Status::OK(); } -namespace { -Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim, - const char* name) { - if (!c->ValueKnown(dim)) { - return errors::InvalidArgument("Cannot infer shape because dimension ", - name, " is not known."); - } - return Status::OK(); -} -} // namespace - Status Conv2DShape(shape_inference::InferenceContext* c) { const Shape* input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); @@ -224,10 +213,10 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { const Dimension* output_depth_dim = c->Dim(filter_shape, 3); // At the moment we need to know the values of several fields. - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols")); auto in_rows = c->Value(in_rows_dim); auto in_cols = c->Value(in_cols_dim); @@ -292,12 +281,12 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { const Dimension* output_depth_dim = c->Dim(filter_shape, 4); // At the moment we need to know the values of several fields. 
- TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_planes_dim, "filter_planes")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_planes_dim, "filter_planes")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols")); auto in_planes = c->Value(in_planes_dim); auto in_rows = c->Value(in_rows_dim); @@ -357,12 +346,12 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { const Dimension* depth_multiplier = c->Dim(filter_shape, 3); // At the moment we need to know the values of several fields. 
- TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, input_depth, "depth")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, depth_multiplier, "depth_multiplier")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(input_depth, "depth")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(depth_multiplier, "depth_multiplier")); // Check that the input depths are compatible. TF_RETURN_IF_ERROR( @@ -449,8 +438,8 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { const Dimension* output_depth_dim = c->Dim(input_shape, 3); // At the moment we need to know the values of several fields. - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -536,9 +525,9 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { const Dimension* in_depth_dim = c->Dim(input_shape, 3); // At the moment we need to know the values of several fields. 
- TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_depth_dim, "in_depth")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_depth_dim, "in_depth")); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -614,9 +603,9 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { const Dimension* output_depth_dim = c->Dim(input_shape, 4); // At the moment we need to know the values of several fields. - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 6b35f82d13d..1aa51f50175 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -261,6 +261,16 @@ class InferenceContext { Status construction_status() const { return construction_status_; } + // Validates that 'dim' has a known value, and returns an error + // status whose message contains 'name' if validation fails. 
+ Status ValidateKnownDim(const Dimension* dim, const char* name) { + if (!ValueKnown(dim)) { + return errors::InvalidArgument("Cannot infer shape because dimension ", + name, " is not known."); + } + return Status::OK(); + } + private: const Dimension* GetDimension(const DimensionOrConstant& d); diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index 103909c6b50..fffb25da6de 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -887,5 +887,13 @@ TEST(ShapeInferenceTest, FullyDefined) { EXPECT_TRUE(c.FullyDefined(c.Scalar())); } +TEST(ShapeInferenceTest, ValidateKnownDim) { + NodeDef def; + InferenceContext c(&def, MakeOpDef(0, 2), {}, {}); + + EXPECT_FALSE(c.ValidateKnownDim(c.UnknownDim(), "unknown").ok()); + EXPECT_TRUE(c.ValidateKnownDim(c.Dim(c.Matrix(1, 2), 0), "known").ok()); +} + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 8619a1e1a90..f10bccd87ca 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2771,6 +2771,88 @@ REGISTER_OP("ExtractImagePatches") .Attr("rates: list(int) >= 4") .Attr("T: realnumbertype") .Attr(GetPaddingAttrString()) + .SetShapeFn([](InferenceContext* c) { + const Shape* input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + + std::vector<int32> ksizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes)); + if (ksizes.size() != 4) { + return errors::InvalidArgument( + "ExtractImagePatches requires the ksizes attribute to contain 4 " + "values, but got: ", + ksizes.size()); + } + + std::vector<int32> strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 4) { + return errors::InvalidArgument( + "ExtractImagePatches requires the strides attribute to contain 4 " + "values, but got: ", + strides.size()); + } + + std::vector<int32> rates; 
+ TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); + if (rates.size() != 4) { + return errors::InvalidArgument( + "ExtractImagePatches requires the rates attribute to contain 4 " + "values, but got: ", + rates.size()); + } + + int32 ksize_rows = ksizes[1]; + int32 ksize_cols = ksizes[2]; + + int32 stride_rows = strides[1]; + int32 stride_cols = strides[2]; + + int32 rate_rows = rates[1]; + int32 rate_cols = rates[2]; + + // Effective window size with (rate - 1) gaps between kernel elements. + int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); + int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); + + const Dimension* batch_size_dim = c->Dim(input_shape, 0); + const Dimension* in_rows_dim = c->Dim(input_shape, 1); + const Dimension* in_cols_dim = c->Dim(input_shape, 2); + const Dimension* in_depth_dim = c->Dim(input_shape, 3); + + // At the moment we need to know the values of several fields. + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + int64 output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_rows, ksize_rows_eff, stride_rows, padding, &output_rows, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, ksize_cols_eff, stride_cols, padding, &output_cols, + &padding_before, &padding_after)); + // Every output position stores a whole ksize_rows x ksize_cols patch + // flattened into the depth dimension, so the output depth is the input + // depth times the patch area. An unknown input depth stays unknown. + const Shape* output_shape; + if (c->ValueKnown(in_depth_dim)) { + const int64 output_depth = + c->Value(in_depth_dim) * ksize_rows * ksize_cols; + output_shape = c->MakeShape( + {batch_size_dim, output_rows, output_cols, output_depth}); + } else { + output_shape = c->MakeShape( + {batch_size_dim, output_rows, output_cols, c->UnknownDim()}); + } + c->set_output(0, output_shape); + return Status::OK(); + }) .Doc(R"doc( Extract `patches` from `images` and put them in the "depth" output dimension. 
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index ac6ba3f3a02..6345db128e7 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -877,4 +877,40 @@ TEST(ArrayOpsTest, OneHot_ShapeFn) { INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]"); } +TEST(ArrayOpsTest, ExtractImagePatchesShapeTest) { + ShapeInferenceTestOp op("ExtractImagePatches"); + auto set_op = [&op](const std::vector<int32>& ksizes, + const std::vector<int32>& strides, + const std::vector<int32>& rates, const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "ExtractImagePatches") + .Input("input", 0, DT_FLOAT) + .Attr("ksizes", ksizes) + .Attr("strides", strides) + .Attr("rates", rates) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // Tests the effective ksize calculation with rates and the output + // depth. Most of the other code is boilerplate that is tested by a + // variety of other ops. + // + // ksizes is 2x2 and the row/col rates are 2, so the effective + // kernel size is 2 + (2 - 1) * (2 - 1) = 3. A 7x7 input with a 3x3 + // filter and 1x1 stride gives a 5x5 output. Each patch holds 2 * 2 + // values per channel, so the depth grows from 2 to 8. + set_op({1, 2, 2, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID"); + INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,8]"); + + // An unknown input depth gives an unknown output depth. + INFER_OK(op, "[1,7,7,?]", "[d0_0,5,5,?]"); + + // Bad ksize rank + set_op({1, 2, 2, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID"); + INFER_ERROR( + "ExtractImagePatches requires the ksizes attribute to contain 4 values, " + "but got: 5", + op, "[1,7,7,2]"); +} + } // end namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index c80a3c27921..d38a5350831 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -968,17 +968,6 @@ output: Gradients w.r.t. the input of `max_pool`. 
// -------------------------------------------------------------------------- -namespace { -Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim, - const char* name) { - if (!c->ValueKnown(dim)) { - return errors::InvalidArgument("Cannot infer shape because dimension ", - name, " is not known."); - } - return Status::OK(); -} -} // namespace - REGISTER_OP("Dilation2D") .Input("input: T") .Input("filter: T") @@ -1029,10 +1018,10 @@ REGISTER_OP("Dilation2D") c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused)); // At the moment we need to know the values of several fields. - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); - TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols")); auto in_rows = c->Value(in_rows_dim); auto in_cols = c->Value(in_cols_dim);