From ab7f22de6aba356fae564e3e8bbb0beb9a98acb4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 11 Sep 2017 16:09:49 -0700 Subject: [PATCH] Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op. PiperOrigin-RevId: 168300911 --- .../ops/fused_conv2d_bias_activation_op.cc | 48 ++++++++++++++++++- tensorflow/core/framework/common_shape_fns.cc | 45 ----------------- tensorflow/core/framework/common_shape_fns.h | 3 -- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 48f058b4c53..c9d0e1f41c9 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,7 +52,53 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") - .SetShapeFn(shape_inference::FusedConvBiasActivationShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + using shape_inference::ShapeHandle; + using shape_inference::DimensionHandle; + TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); + + string data_format_str, filter_format_str; + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str)); + + TensorFormat data_format; + FormatFromString(data_format_str, &data_format); + FilterTensorFormat filter_format; + FilterFormatFromString(filter_format_str, &filter_format); + + constexpr int num_spatial_dims = 2; + const int rank = + GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); + ShapeHandle filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + + DimensionHandle output_depth_dim = + c->Dim(filter_shape, + GetFilterDimIndex(filter_format, 'O')); + int64 output_depth_dim_val = c->Value(output_depth_dim); + + ShapeHandle bias_shape; + // Bias should be a 1-D tensor. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); + DimensionHandle bias_dim = c->Dim(bias_shape, 0); + int64 bias_dim_val = c->Value(bias_dim); + + if (output_depth_dim_val != bias_dim_val) { + return errors::InvalidArgument( + "Output depth dimension (", output_depth_dim_val, + ") and bias dimension (", bias_dim_val, ") do not match."); + } + + // Check side input shape matches the output shape. + ShapeHandle side_input_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape)); + if (c->Rank(side_input_shape) > 1) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused)); + } + + return Status::OK(); + }) .Doc(R"doc( Computes a fused kernel which implements: 2-D convolution, adds side input, with separate scaling on convolution and side inputs, then adds bias and diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index ab21f47282e..9e0f6d3be11 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -202,51 +202,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { return Status::OK(); } -// input, filter, bias, output -Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) { - TF_RETURN_IF_ERROR(Conv2DShape(c)); - - string data_format_str, filter_format_str; - TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); - TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str)); - - TensorFormat data_format; - FormatFromString(data_format_str, &data_format); - FilterTensorFormat filter_format; - FilterFormatFromString(filter_format_str, &filter_format); - - constexpr int num_spatial_dims = 2; - const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); - ShapeHandle filter_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); - - DimensionHandle output_depth_dim = c->Dim( - filter_shape, GetFilterDimIndex(filter_format, 'O')); - int64 output_depth_dim_val = c->Value(output_depth_dim); - - ShapeHandle bias_shape; - // Bias should be a 1-D tensor. - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); - DimensionHandle bias_dim = c->Dim(bias_shape, 0); - int64 bias_dim_val = c->Value(bias_dim); - - if (output_depth_dim_val != bias_dim_val) { - return errors::InvalidArgument( - "Output depth dimension (", output_depth_dim_val, - ") and bias dimension (", bias_dim_val, ") do not match."); - } - - // Check side input shape matches the output shape. - ShapeHandle side_input_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape)); - if (c->Rank(side_input_shape) > 1) { - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused)); - } - - return Status::OK(); -} - Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, const ShapeHandle shape_handle, const string& tensor_name, diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index fb79df07a4f..a3c79afc0b0 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -167,9 +167,6 @@ Status Conv2DShape(shape_inference::InferenceContext* c); // Shape function for Conv3D-like operations. Status Conv3DShape(shape_inference::InferenceContext* c); -// Shape function for FusedConvBiasActivation operation. -Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c); - // Shape function for DepthwiseConv2D-like operations. Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);