Merge pull request #31259 from yongfeng-nv:xla-modify-column-reduction-threshold
PiperOrigin-RevId: 263012404
This commit is contained in:
commit
7685639e2f
@ -514,6 +514,27 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase):
|
|||||||
[7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]],
|
[7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]],
|
||||||
dtype=np.float32))
|
dtype=np.float32))
|
||||||
|
|
||||||
|
def testAlignCorners3x3To12x12_uint8(self):
|
||||||
|
# TODO(b/72099414): enable the test for TPU when the issue is fixed.
|
||||||
|
if (self.device not in ["XLA_GPU", "XLA_CPU"]):
|
||||||
|
return
|
||||||
|
# Ensure that resize with convolution works on XLA/GPU for integer types
|
||||||
|
self._assertForwardOpMatchesExpected(
|
||||||
|
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8), [12, 12],
|
||||||
|
expected=np.array([[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3],
|
||||||
|
[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3],
|
||||||
|
[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3],
|
||||||
|
[4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6],
|
||||||
|
[4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6],
|
||||||
|
[4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6],
|
||||||
|
[4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6],
|
||||||
|
[4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6],
|
||||||
|
[4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6],
|
||||||
|
[7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9],
|
||||||
|
[7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9],
|
||||||
|
[7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]],
|
||||||
|
dtype=np.uint8))
|
||||||
|
|
||||||
|
|
||||||
class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase):
|
class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@ -220,9 +220,9 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// For column reduction, the tile block is tile_size_y x tile_size_x, and we
|
// For column reduction, the tile block is tile_size_y x tile_size_x, and we
|
||||||
// are reducing along tile_size_y. Both tile_size_x and tile_size_y need to be
|
// are reducing along tile_size_y. Only tile_size_y needs to be
|
||||||
// large enough to make the tiling implementation efficient.
|
// large enough to make the tiling implementation efficient.
|
||||||
return dims_in_elem[2] >= kWarpSize && dims_in_elem[1] >= kWarpSize;
|
return dims_in_elem[1] >= kWarpSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
|
std::pair<bool, DimensionVector> GetReductionKindAndContiguousComponents(
|
||||||
|
@ -536,6 +536,51 @@ TEST_F(GpuKernelTilingTest,
|
|||||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) {
|
||||||
|
const char *const kHloString = R"(
|
||||||
|
HloModule Test
|
||||||
|
|
||||||
|
scalar_add_computation.1 {
|
||||||
|
scalar_lhs.1 = f32[] parameter(0)
|
||||||
|
scalar_rhs.1 = f32[] parameter(1)
|
||||||
|
ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1)
|
||||||
|
}
|
||||||
|
ENTRY Test {
|
||||||
|
param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3)
|
||||||
|
constant_661 = f16[] constant(0)
|
||||||
|
broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={}
|
||||||
|
compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT
|
||||||
|
param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2)
|
||||||
|
select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695)
|
||||||
|
convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40)
|
||||||
|
param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1)
|
||||||
|
copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809)
|
||||||
|
convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335)
|
||||||
|
param_0.668 = f32[2]{0} parameter(0)
|
||||||
|
broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1}
|
||||||
|
subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687)
|
||||||
|
multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136)
|
||||||
|
constant_485 = f32[] constant(0)
|
||||||
|
reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1
|
||||||
|
reduce.140.clone.1 = f32[2]{0} reduce(convert.196, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1
|
||||||
|
ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1)
|
||||||
|
})";
|
||||||
|
|
||||||
|
// Check that no loop is generated for reduction.
|
||||||
|
auto hlo_module =
|
||||||
|
ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
|
||||||
|
.ValueOrDie();
|
||||||
|
CompileAndVerifyIr(std::move(hlo_module),
|
||||||
|
R"(
|
||||||
|
; CHECK-LABEL: define void @fusion
|
||||||
|
; CHECK-NOT: reduce.0.loop_header
|
||||||
|
; CHECK: }
|
||||||
|
)",
|
||||||
|
/*match_optimized_ir=*/true);
|
||||||
|
// Check that the kernel runs correctly.
|
||||||
|
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) {
|
TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) {
|
||||||
const char *const kHloString = R"(
|
const char *const kHloString = R"(
|
||||||
HloModule reduction
|
HloModule reduction
|
||||||
|
Loading…
x
Reference in New Issue
Block a user