TensorFlow: implement ExtractImagePatches shape fn, move Validation

of KnownDim strewn across a few files into shape_inference.h with
a simple unittest.
Change: 129152593
This commit is contained in:
Vijay Vasudevan 2016-08-02 14:57:06 -08:00 committed by TensorFlower Gardener
parent 5a828e3a9c
commit 21ca7e442c
6 changed files with 148 additions and 50 deletions

View File

@ -173,17 +173,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
namespace {
Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim,
const char* name) {
if (!c->ValueKnown(dim)) {
return errors::InvalidArgument("Cannot infer shape because dimension ",
name, " is not known.");
}
return Status::OK();
}
} // namespace
Status Conv2DShape(shape_inference::InferenceContext* c) {
const Shape* input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
@ -224,10 +213,10 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
const Dimension* output_depth_dim = c->Dim(filter_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
auto in_rows = c->Value(in_rows_dim);
auto in_cols = c->Value(in_cols_dim);
@ -292,12 +281,12 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
const Dimension* output_depth_dim = c->Dim(filter_shape, 4);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_planes_dim, "filter_planes"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_planes_dim, "filter_planes"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
auto in_planes = c->Value(in_planes_dim);
auto in_rows = c->Value(in_rows_dim);
@ -357,12 +346,12 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
const Dimension* depth_multiplier = c->Dim(filter_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, input_depth, "depth"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, depth_multiplier, "depth_multiplier"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(input_depth, "depth"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(depth_multiplier, "depth_multiplier"));
// Check that the input depths are compatible.
TF_RETURN_IF_ERROR(
@ -449,8 +438,8 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
const Dimension* output_depth_dim = c->Dim(input_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
@ -536,9 +525,9 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
const Dimension* in_depth_dim = c->Dim(input_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_depth_dim, "in_depth"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_depth_dim, "in_depth"));
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
@ -614,9 +603,9 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
const Dimension* output_depth_dim = c->Dim(input_shape, 4);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_planes_dim, "in_planes"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

View File

@ -261,6 +261,16 @@ class InferenceContext {
Status construction_status() const { return construction_status_; }
// Validates that 'dim' has a known value, and prints an error
// message containing 'name' if validation fails.
Status ValidateKnownDim(const Dimension* dim, const char* name) {
if (!ValueKnown(dim)) {
return errors::InvalidArgument("Cannot infer shape because dimension ",
name, " is not known.");
}
return Status::OK();
}
private:
const Dimension* GetDimension(const DimensionOrConstant& d);

View File

@ -887,5 +887,13 @@ TEST(ShapeInferenceTest, FullyDefined) {
EXPECT_TRUE(c.FullyDefined(c.Scalar()));
}
TEST(ShapeInferenceTest, ValidateKnownDim) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(0, 2), {}, {});
EXPECT_FALSE(c.ValidateKnownDim(c.UnknownDim(), "unknown").ok());
EXPECT_TRUE(c.ValidateKnownDim(c.Dim(c.Matrix(1, 2), 0), "known").ok());
}
} // namespace shape_inference
} // namespace tensorflow

View File

@ -2771,6 +2771,76 @@ REGISTER_OP("ExtractImagePatches")
.Attr("rates: list(int) >= 4")
.Attr("T: realnumbertype")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
const Shape* input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
std::vector<int32> ksizes;
TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
if (ksizes.size() != 4) {
return errors::InvalidArgument(
"ExtractImagePatches requires the ksizes attribute to contain 4 "
"values, but got: ",
ksizes.size());
}
std::vector<int32> strides;
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
if (strides.size() != 4) {
return errors::InvalidArgument(
"ExtractImagePatches requires the stride attribute to contain 4 "
"values, but got: ",
strides.size());
}
std::vector<int32> rates;
TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
if (rates.size() != 4) {
return errors::InvalidArgument(
"ExtractImagePatches requires the rates attribute to contain 4 "
"values, but got: ",
rates.size());
}
int32 ksize_rows = ksizes[1];
int32 ksize_cols = ksizes[2];
int32 stride_rows = strides[1];
int32 stride_cols = strides[2];
int32 rate_rows = rates[1];
int32 rate_cols = rates[2];
int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
const Dimension* batch_size_dim = c->Dim(input_shape, 0);
const Dimension* in_rows_dim = c->Dim(input_shape, 1);
const Dimension* in_cols_dim = c->Dim(input_shape, 2);
const Dimension* output_depth_dim = c->Dim(input_shape, 3);
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
auto in_rows = c->Value(in_rows_dim);
auto in_cols = c->Value(in_cols_dim);
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
int64 output_rows, output_cols;
int64 padding_before, padding_after;
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
in_rows, ksize_rows_eff, stride_rows, padding, &output_rows,
&padding_before, &padding_after));
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
in_cols, ksize_cols_eff, stride_cols, padding, &output_cols,
&padding_before, &padding_after));
const Shape* output_shape = c->MakeShape(
{batch_size_dim, output_rows, output_cols, output_depth_dim});
c->set_output(0, output_shape);
return Status::OK();
})
.Doc(R"doc(
Extract `patches` from `images` and put them in the "depth" output dimension.

View File

@ -877,4 +877,36 @@ TEST(ArrayOpsTest, OneHot_ShapeFn) {
INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
}
TEST(NNOpsTest, ExtractImagePatchesShapeTest) {
ShapeInferenceTestOp op("ExtractImagePatches");
auto set_op = [&op](const std::vector<int32>& ksizes,
const std::vector<int32>& strides,
const std::vector<int32>& rates, const string& padding) {
TF_CHECK_OK(NodeDefBuilder("test", "ExtractImagePatches")
.Input("input", 0, DT_FLOAT)
.Attr("ksizes", ksizes)
.Attr("strides", strides)
.Attr("rates", rates)
.Attr("padding", padding)
.Finalize(&op.node_def));
};
// Just tests that the ksize calculation with rates works. Most of
// the other code is boilerplate that is tested by a variety of
// other ops.
//
// ksizes is 2x2. rate rows and cols is 2, so ksize_rows and
// cols are changed to be 2 + (2 - 1) = 3. 7x7 input with 3x3
// filter and 1x1 stride gives a 5x5 output.
set_op({1, 2, 2, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,d0_3]");
// Bad ksize rank
set_op({1, 2, 2, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
INFER_ERROR(
"ExtractImagePatches requires the ksizes attribute to contain 4 values, "
"but got: 5",
op, "[1,7,7,2]");
}
} // end namespace tensorflow

View File

@ -968,17 +968,6 @@ output: Gradients w.r.t. the input of `max_pool`.
// --------------------------------------------------------------------------
namespace {
Status CheckKnownDim(shape_inference::InferenceContext* c, const Dimension* dim,
const char* name) {
if (!c->ValueKnown(dim)) {
return errors::InvalidArgument("Cannot infer shape because dimension ",
name, " is not known.");
}
return Status::OK();
}
} // namespace
REGISTER_OP("Dilation2D")
.Input("input: T")
.Input("filter: T")
@ -1029,10 +1018,10 @@ REGISTER_OP("Dilation2D")
c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
// At the moment we need to know the values of several fields.
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_rows_dim, "in_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(in_cols_dim, "in_cols"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_rows_dim, "filter_rows"));
TF_RETURN_IF_ERROR(c->ValidateKnownDim(filter_cols_dim, "filter_cols"));
auto in_rows = c->Value(in_rows_dim);
auto in_cols = c->Value(in_cols_dim);