Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op.

PiperOrigin-RevId: 168300911
This commit is contained in:
A. Unique TensorFlower 2017-09-11 16:09:49 -07:00 committed by TensorFlower Gardener
parent 3a98035fa8
commit ab7f22de6a
3 changed files with 47 additions and 49 deletions

View File

@ -52,7 +52,53 @@ REGISTER_OP("FusedConv2DBiasActivation")
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
.Attr("activation_mode: {'Relu'} = 'Relu'")
.SetShapeFn(shape_inference::FusedConvBiasActivationShape)
.SetShapeFn([](shape_inference::InferenceContext* c) {
using shape_inference::ShapeHandle;
using shape_inference::DimensionHandle;
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
string data_format_str, filter_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str));
TensorFormat data_format;
FormatFromString(data_format_str, &data_format);
FilterTensorFormat filter_format;
FilterFormatFromString(filter_format_str, &filter_format);
constexpr int num_spatial_dims = 2;
const int rank =
GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
DimensionHandle output_depth_dim =
c->Dim(filter_shape,
GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
int64 output_depth_dim_val = c->Value(output_depth_dim);
ShapeHandle bias_shape;
// Bias should be a 1-D tensor.
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape));
DimensionHandle bias_dim = c->Dim(bias_shape, 0);
int64 bias_dim_val = c->Value(bias_dim);
if (output_depth_dim_val != bias_dim_val) {
return errors::InvalidArgument(
"Output depth dimension (", output_depth_dim_val,
") and bias dimension (", bias_dim_val, ") do not match.");
}
// Check side input shape matches the output shape.
ShapeHandle side_input_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape));
if (c->Rank(side_input_shape) > 1) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
}
return Status::OK();
})
.Doc(R"doc(
Computes a fused kernel which implements: 2-D convolution, adds side input,
with separate scaling on convolution and side inputs, then adds bias and

View File

@ -202,51 +202,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
// input, filter, bias, output
Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(Conv2DShape(c));
string data_format_str, filter_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str));
TensorFormat data_format;
FormatFromString(data_format_str, &data_format);
FilterTensorFormat filter_format;
FilterFormatFromString(filter_format_str, &filter_format);
constexpr int num_spatial_dims = 2;
const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
DimensionHandle output_depth_dim = c->Dim(
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
int64 output_depth_dim_val = c->Value(output_depth_dim);
ShapeHandle bias_shape;
// Bias should be a 1-D tensor.
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape));
DimensionHandle bias_dim = c->Dim(bias_shape, 0);
int64 bias_dim_val = c->Value(bias_dim);
if (output_depth_dim_val != bias_dim_val) {
return errors::InvalidArgument(
"Output depth dimension (", output_depth_dim_val,
") and bias dimension (", bias_dim_val, ") do not match.");
}
// Check side input shape matches the output shape.
ShapeHandle side_input_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape));
if (c->Rank(side_input_shape) > 1) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
}
return Status::OK();
}
Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
const ShapeHandle shape_handle,
const string& tensor_name,

View File

@ -167,9 +167,6 @@ Status Conv2DShape(shape_inference::InferenceContext* c);
// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);
// Shape function for FusedConvBiasActivation operation.
Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c);
// Shape function for DepthwiseConv2D-like operations.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);