Move FusedConvBiasActivationShape out of common_shape_fns.cc to a lambda inside the op.
PiperOrigin-RevId: 168300911
This commit is contained in:
parent
3a98035fa8
commit
ab7f22de6a
@ -52,7 +52,53 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||
.Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
|
||||
.Attr("activation_mode: {'Relu'} = 'Relu'")
|
||||
.SetShapeFn(shape_inference::FusedConvBiasActivationShape)
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
using shape_inference::ShapeHandle;
|
||||
using shape_inference::DimensionHandle;
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
|
||||
string data_format_str, filter_format_str;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str));
|
||||
|
||||
TensorFormat data_format;
|
||||
FormatFromString(data_format_str, &data_format);
|
||||
FilterTensorFormat filter_format;
|
||||
FilterFormatFromString(filter_format_str, &filter_format);
|
||||
|
||||
constexpr int num_spatial_dims = 2;
|
||||
const int rank =
|
||||
GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
|
||||
ShapeHandle filter_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
|
||||
|
||||
DimensionHandle output_depth_dim =
|
||||
c->Dim(filter_shape,
|
||||
GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
|
||||
int64 output_depth_dim_val = c->Value(output_depth_dim);
|
||||
|
||||
ShapeHandle bias_shape;
|
||||
// Bias should be a 1-D tensor.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape));
|
||||
DimensionHandle bias_dim = c->Dim(bias_shape, 0);
|
||||
int64 bias_dim_val = c->Value(bias_dim);
|
||||
|
||||
if (output_depth_dim_val != bias_dim_val) {
|
||||
return errors::InvalidArgument(
|
||||
"Output depth dimension (", output_depth_dim_val,
|
||||
") and bias dimension (", bias_dim_val, ") do not match.");
|
||||
}
|
||||
|
||||
// Check side input shape matches the output shape.
|
||||
ShapeHandle side_input_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape));
|
||||
if (c->Rank(side_input_shape) > 1) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Computes a fused kernel which implements: 2-D convolution, adds side input,
|
||||
with separate scaling on convolution and side inputs, then adds bias and
|
||||
|
@ -202,51 +202,6 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// input, filter, bias, output
|
||||
Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(Conv2DShape(c));
|
||||
|
||||
string data_format_str, filter_format_str;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str));
|
||||
|
||||
TensorFormat data_format;
|
||||
FormatFromString(data_format_str, &data_format);
|
||||
FilterTensorFormat filter_format;
|
||||
FilterFormatFromString(filter_format_str, &filter_format);
|
||||
|
||||
constexpr int num_spatial_dims = 2;
|
||||
const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
|
||||
ShapeHandle filter_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
|
||||
|
||||
DimensionHandle output_depth_dim = c->Dim(
|
||||
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
|
||||
int64 output_depth_dim_val = c->Value(output_depth_dim);
|
||||
|
||||
ShapeHandle bias_shape;
|
||||
// Bias should be a 1-D tensor.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape));
|
||||
DimensionHandle bias_dim = c->Dim(bias_shape, 0);
|
||||
int64 bias_dim_val = c->Value(bias_dim);
|
||||
|
||||
if (output_depth_dim_val != bias_dim_val) {
|
||||
return errors::InvalidArgument(
|
||||
"Output depth dimension (", output_depth_dim_val,
|
||||
") and bias dimension (", bias_dim_val, ") do not match.");
|
||||
}
|
||||
|
||||
// Check side input shape matches the output shape.
|
||||
ShapeHandle side_input_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape));
|
||||
if (c->Rank(side_input_shape) > 1) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
|
||||
const ShapeHandle shape_handle,
|
||||
const string& tensor_name,
|
||||
|
@ -167,9 +167,6 @@ Status Conv2DShape(shape_inference::InferenceContext* c);
|
||||
// Shape function for Conv3D-like operations.
|
||||
Status Conv3DShape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for FusedConvBiasActivation operation.
|
||||
Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for DepthwiseConv2D-like operations.
|
||||
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user