Take #2: Improve Conv2DBackpropInput to take input_sizes as a 2D shape.
When input_sizes is a 2D shape, the input batch size is taken from out_backprop and the input channel size from the filter. With this change, input_sizes is more likely to be a constant (e.g. even when the batch size is variable), so tf2tensorrt can convert more Conv2DBackpropInput ops to IDeconvolutionLayer. The tf2tensorrt changes will come in separate CLs.

tf2xla does not yet support input_sizes being a 2D shape and errors out for now, so the test added to conv_ops_test.py is disabled for XLA.

PiperOrigin-RevId: 303217218
Change-Id: I283106657c00f49be41a74c7131bf8be787742a8
This commit is contained in:
parent 55d96a7c83
commit 9dadbb283d

Changed directories:
tensorflow/compiler/tf2xla/kernels
tensorflow/core/framework
tensorflow/core/kernels
tensorflow/core/ops
tensorflow/python/kernel_tests
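For illustration (a sketch, not part of this commit): with the 2D form, input_sizes can stay a graph constant even when the batch size is variable, which is the property tf2tensorrt needs. The names deconv, out_grad, and filt below are hypothetical, and this assumes a TF build that includes this change:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 5, 5, 2], tf.float32)])
def deconv(out_grad):
  # Filter layout is [height, width, in_channels, out_channels].
  filt = tf.ones([3, 3, 4, 2])
  # input_sizes holds only [height, width]; the batch size comes from
  # out_backprop and the channel size from the filter.
  return tf.raw_ops.Conv2DBackpropInput(
      input_sizes=tf.constant([7, 7], dtype=tf.int32),
      filter=filt,
      out_backprop=out_grad,
      strides=[1, 1, 1, 1],
      padding="VALID")

# The inferred static output shape should be (None, 7, 7, 4): a 7x7 input
# convolved with a 3x3 VALID filter produces the 5x5 out_backprop above.
print(deconv.get_concrete_function().output_shapes)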
@@ -107,6 +107,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
     xla::Shape input_shape =
         TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
+    OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2,
+                errors::InvalidArgument(
+                    "The rank of the specified input shape must be "
+                    "num_spatial_dims + 2. Expected ",
+                    attrs_.num_spatial_dims + 2, " got ", input_shape.rank()));
 
     xla::StatusOr<xla::XlaOp> in_backprop =
         MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
@@ -822,6 +822,78 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
+  string data_format_str;
+  if (!c->GetAttr("data_format", &data_format_str).ok()) {
+    data_format_str = "NHWC";
+  }
+  TensorFormat data_format;
+  if (!FormatFromString(data_format_str, &data_format)) {
+    return errors::InvalidArgument("Invalid data format string: ",
+                                   data_format_str);
+  }
+
+  // For the rest of this function, output_grad_* describes out_backprop and
+  // input_grad_* describes in_backprop.
+  ShapeHandle output_grad_shape = c->input(2);
+  TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
+  ShapeHandle filter_shape = c->input(1);
+  TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));
+
+  DimensionHandle batch_size_dim;
+  DimensionHandle output_grad_depth_dim;
+  gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
+  TF_RETURN_IF_ERROR(DimensionsFromShape(
+      output_grad_shape, data_format, &batch_size_dim,
+      absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
+  DimensionHandle unused;
+  TF_RETURN_IF_ERROR(
+      c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));
+
+  ShapeHandle specified_input_grad_shape;
+  TF_RETURN_IF_ERROR(
+      c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
+  if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
+    TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
+                                   &specified_input_grad_shape));
+  }
+
+  // input_grad_depth_dim doesn't equal c->Dim(filter_shape, 2) when the
+  // number of groups is larger than 1. If input_sizes is a 4D shape, we
+  // collect input_grad_depth_dim from input_sizes; otherwise we compute it
+  // as c->Dim(filter_shape, 2).
+  DimensionHandle input_grad_depth_dim;
+  gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
+  int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
+  if (specified_input_grad_rank == 4) {
+    DimensionHandle specified_batch_size_dim;
+    TF_RETURN_IF_ERROR(DimensionsFromShape(
+        specified_input_grad_shape, data_format, &specified_batch_size_dim,
+        absl::MakeSpan(specified_input_grad_spatial_dims),
+        &input_grad_depth_dim, c));
+    TF_RETURN_IF_ERROR(
+        c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
+  } else if (specified_input_grad_rank == 2) {
+    specified_input_grad_spatial_dims[0] =
+        c->Dim(specified_input_grad_shape, 0);
+    specified_input_grad_spatial_dims[1] =
+        c->Dim(specified_input_grad_shape, 1);
+    input_grad_depth_dim = c->Dim(filter_shape, 2);
+  } else {
+    return errors::InvalidArgument(
+        "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
+        "values, but got: ",
+        specified_input_grad_rank);
+  }
+
+  ShapeHandle input_grad_shape;
+  TF_RETURN_IF_ERROR(ShapeFromDimensions(
+      batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
+      data_format, c, &input_grad_shape));
+  c->set_output(0, input_grad_shape);
+  return Status::OK();
+}
+
 namespace {
 
 Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
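A quick end-to-end check of the 2D form from Python (a sketch, assuming both this shape function and the kernel changes below are in the build):

import tensorflow as tf

out_grad = tf.ones([1, 5, 5, 2])
filt = tf.ones([3, 3, 4, 2])  # [height, width, in_channels, out_channels]
y = tf.raw_ops.Conv2DBackpropInput(
    input_sizes=tf.constant([7, 7], dtype=tf.int32),  # spatial dims only
    filter=filt,
    out_backprop=out_grad,
    strides=[1, 1, 1, 1],
    padding="VALID")
print(y.shape)  # (1, 7, 7, 4): batch from out_backprop, depth from the filter

Note the caveat in the comment above: with the 2D form the input depth is always read from the filter, so a grouped convolution, whose input depth is a multiple of the filter's third dimension, still needs the 4-value form.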
@@ -138,6 +138,9 @@ Status DepthwiseConv2DNativeShapeWithExplicitPadding(
 // explicit padding.
 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
 
+// Shape function for Conv2DBackpropInput.
+Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c);
+
 // Shape function for AvgPool-like operations.
 Status AvgPoolShape(shape_inference::InferenceContext* c);
 
@@ -450,14 +450,12 @@ class Conv2DBackpropInputOp : public OpKernel {
     const Tensor& input_sizes = context->input(0);
     const Tensor& filter = context->input(1);
     const Tensor& out_backprop = context->input(2);
-    OP_REQUIRES(
-        context, TensorShapeUtils::IsVector(input_sizes.shape()),
-        errors::InvalidArgument(
-            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
-            input_sizes.dims()));
+
     TensorShape input_shape;
-    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                input_sizes.vec<int32>(), &input_shape));
+    OP_REQUIRES_OK(context,
+                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
+                                                   out_backprop.shape(),
+                                                   data_format_, &input_shape));
 
     Tensor* in_backprop = nullptr;
     OP_REQUIRES_OK(context,
@@ -549,14 +547,12 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
     const Tensor& input_sizes = context->input(0);
     const Tensor& filter = context->input(1);
     const Tensor& out_backprop = context->input(2);
-    OP_REQUIRES(
-        context, TensorShapeUtils::IsVector(input_sizes.shape()),
-        errors::InvalidArgument(
-            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
-            input_sizes.dims()));
+
     TensorShape input_shape;
-    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                input_sizes.vec<int32>(), &input_shape));
+    OP_REQUIRES_OK(context,
+                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
+                                                   out_backprop.shape(),
+                                                   data_format_, &input_shape));
 
     ConvBackpropDimensions dims;
     OP_REQUIRES_OK(context,
@@ -166,4 +166,35 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims,
       dims);
 }
 
+Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes,
+                                       const TensorShape& filter_shape,
+                                       const TensorShape& out_backprop_shape,
+                                       const TensorFormat& data_format,
+                                       TensorShape* input_shape) {
+  if (!TensorShapeUtils::IsVector(input_sizes.shape())) {
+    return errors::InvalidArgument(
+        "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+        input_sizes.dims());
+  }
+
+  if (input_sizes.dim_size(0) == 4) {
+    return TensorShapeUtils::MakeShape(input_sizes.vec<int32>(), input_shape);
+  }
+
+  if (input_sizes.dim_size(0) == 2) {
+    const int batch_size = GetTensorDim(out_backprop_shape, data_format, 'N');
+    const int output_height = input_sizes.vec<int32>()(0);
+    const int output_width = input_sizes.vec<int32>()(1);
+    const int output_depth = filter_shape.dim_size(2);
+    *input_shape = ShapeFromFormat(data_format, batch_size, output_height,
+                                   output_width, output_depth);
+    return Status::OK();
+  }
+
+  return errors::InvalidArgument(
+      "Conv2DBackpropInput requires input_sizes to "
+      "contain 4 values or 2 values, but got: ",
+      input_sizes.dim_size(0));
+}
+
 }  // namespace tensorflow
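The same logic as a pure-Python sketch for the NHWC case (conv2d_backprop_input_shape is a hypothetical name; the real helper goes through GetTensorDim and ShapeFromFormat so NCHW is handled uniformly):

def conv2d_backprop_input_shape(input_sizes, filter_shape, out_backprop_shape):
  """Computes the 4D in_backprop shape; NHWC layout assumed."""
  if len(input_sizes) == 4:
    # Four values: trust input_sizes as given. Its depth may legitimately
    # differ from filter_shape[2] when the convolution is grouped.
    return list(input_sizes)
  if len(input_sizes) == 2:
    height, width = input_sizes
    batch = out_backprop_shape[0]  # the 'N' dimension of out_backprop
    depth = filter_shape[2]        # the filter's in_channels (groups == 1)
    return [batch, height, width, depth]
  raise ValueError("input_sizes must contain 4 or 2 values, got %d" %
                   len(input_sizes))

print(conv2d_backprop_input_shape([7, 7], [3, 3, 4, 2], [8, 5, 5, 2]))
# -> [8, 7, 7, 4]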
@@ -83,6 +83,13 @@ Status ConvBackpropComputeDimensionsV2(
     const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
     Padding padding, absl::Span<const int64> explicit_paddings,
     TensorFormat data_format, ConvBackpropDimensions* dims);
 
+// Computes the shape of the in_backprop.
+Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes,
+                                       const TensorShape& filter_shape,
+                                       const TensorShape& out_backprop_shape,
+                                       const TensorFormat& data_format,
+                                       TensorShape* input_shape);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_
@@ -357,13 +357,7 @@ REGISTER_OP("Conv2DBackpropInput")
     .Attr(GetExplicitPaddingsAttrString())
     .Attr(GetConvnetDataFormatAttrString())
     .Attr("dilations: list(int) = [1, 1, 1, 1]")
-    .SetShapeFn([](InferenceContext* c) {
-      ShapeHandle s;
-      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
-      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
-      c->set_output(0, s);
-      return Status::OK();
-    });
+    .SetShapeFn(shape_inference::Conv2DBackpropInputShape);
 
 // TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a
 // more general string attribute ('kernel_impl'?) that can be used to
@@ -320,6 +320,25 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
       "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
 }
 
+TEST(NNOpsTest, Conv2DBackpropInput_ShapeFn) {
+  ShapeInferenceTestOp op("Conv2DBackpropInput");
+
+  // Test rank error.
+  INFER_ERROR("input_sizes to contain 4 values or 2 values", op,
+              "[3];[?,?,?,?];[?,?,?,?]");
+  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
+              "[4];[?,?,?,?];[?,?,?]");
+
+  // When input_sizes is a 4D shape and the convolution is grouped, the
+  // channel size of the input grad doesn't always equal the input channel
+  // size of the filter. So, when input_sizes is a 4D shape, the channel size
+  // of the input grad is determined by the content of input_sizes.
+  INFER_OK(op, "[4];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,?]");
+  // When input_sizes is a 2D shape, the channel size of the input grad
+  // always matches the filter shape.
+  INFER_OK(op, "[2];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,d1_2]");
+}
+
 TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
   ShapeInferenceTestOp op("Conv3DBackpropInput");
 
@@ -836,8 +836,9 @@ class Conv2DTest(test.TestCase):
     x2 = self._CreateNumpyTensor(output_sizes)
     dilations = list(dilations)
     with test_util.device(use_gpu):
-      if data_format == "NCHW":
-        input_sizes = test_util.NHWCToNCHW(input_sizes)
+      if len(input_sizes) == 4:
+        if data_format == "NCHW":
+          input_sizes = test_util.NHWCToNCHW(input_sizes)
       t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
       t1 = constant_op.constant(x1, shape=filter_sizes)
       t2 = constant_op.constant(x2, shape=output_sizes)
@@ -1007,6 +1008,22 @@ class Conv2DTest(test.TestCase):
         use_gpu=use_gpu,
         err=1e-5)
 
+  @test_util.run_in_graph_and_eager_modes
+  @test_util.disable_xla("XLA requires input_sizes to be a 4D shape.")
+  def testConv2DInputSizesContainsOnlySpatialDimensionsBackpropInput(self):
+    expected_output = [5.0, 11.0, 17.0, 23.0]
+    for (data_format, use_gpu) in GetTestConfigs():
+      self._RunAndVerifyBackpropInput(
+          input_sizes=[2, 2],
+          filter_sizes=[2, 2, 1, 2],
+          output_sizes=[1, 1, 1, 2],
+          strides=[1, 1],
+          padding="VALID",
+          expected=expected_output,
+          data_format=data_format,
+          use_gpu=use_gpu,
+          err=1e-5)
+
   # Testing for backprops
   def _RunAndVerifyBackpropFilter(self,
                                   input_sizes,
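The expected_output in the new test can be checked by hand (a sketch, assuming _CreateNumpyTensor fills tensors with 1..n in row-major order, as it does elsewhere in this file):

import numpy as np

filt = np.arange(1, 9, dtype=np.float32).reshape(2, 2, 1, 2)      # [h, w, in, out]
out_grad = np.arange(1, 3, dtype=np.float32).reshape(1, 1, 1, 2)  # [n, h, w, out]

# With VALID padding, stride 1, and a 1x1 out_backprop, each in_backprop
# pixel is the dot product of the gradient with the matching filter tap:
#   in_grad[n, i, j, c] = sum_k out_grad[n, 0, 0, k] * filt[i, j, c, k]
in_grad = np.einsum('nhwk,ijck->nijc', out_grad, filt)
print(in_grad.reshape(-1))  # [ 5. 11. 17. 23.] == expected_output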