Special case 1D resizes and no-op resizes.
Bilinear resize from [X, Y] -> [X, Y] is a no-op, but previously, ResizeBilinear would generated a 1x1 convolution. Bilinear resize from [X, A] -> [X, B] or [A, Y] -> [B, Y] are one dimensional, but previously the code used the two dimensional general case if the kernel size was sufficiently small. PiperOrigin-RevId: 251713592
This commit is contained in:
parent
50abb98b7d
commit
ac034fd4e6
@ -560,6 +560,8 @@ class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase):
|
|||||||
("72x72To456x456", 72, 72, 456, 456),
|
("72x72To456x456", 72, 72, 456, 456),
|
||||||
("86x86To456x456", 86, 86, 456, 456),
|
("86x86To456x456", 86, 86, 456, 456),
|
||||||
("100x100To456x456", 100, 100, 456, 456),
|
("100x100To456x456", 100, 100, 456, 456),
|
||||||
|
("64x64To224x224", 64, 64, 224, 224),
|
||||||
|
("224x224To224x224", 224, 224, 224, 224),
|
||||||
# This test is disabled because it is very slow. It is slow because
|
# This test is disabled because it is very slow. It is slow because
|
||||||
# 383 is prime, 383 and 2047 are coprime, and 2048 is large.
|
# 383 is prime, 383 and 2047 are coprime, and 2048 is large.
|
||||||
# ("Disabled_384x72To2048x384", 384, 72, 2048, 384),
|
# ("Disabled_384x72To2048x384", 384, 72, 2048, 384),
|
||||||
|
@ -312,8 +312,11 @@ xla::XlaOp ResizeUsingDilationAndConvolution(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Split convolutions into independent dimensions if they would be a very
|
// Split convolutions into independent dimensions if they would be a very
|
||||||
// large kernel.
|
// large kernel or if one or more of the dimensions are already equal.
|
||||||
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
|
bool decompose_resize =
|
||||||
|
in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
|
||||||
|
dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
|
||||||
|
if (!decompose_resize) {
|
||||||
xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
|
xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
|
||||||
channels, is_kernel_bilinear);
|
channels, is_kernel_bilinear);
|
||||||
output =
|
output =
|
||||||
|
Loading…
Reference in New Issue
Block a user