From afc4a762f73205603f840afe867f3d29d56d1fa8 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Tue, 30 Jul 2019 00:58:28 -0400 Subject: [PATCH 1/7] Add a uint8 nearest neighbor resize test. --- tensorflow/compiler/tests/image_ops_test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index fb4b2711905..f7b186d9b7a 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -514,6 +514,24 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase): [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], dtype=np.float32)) + def testAlignCorners3x3To12x12_uint8(self): + # Ensure that resize with convolution works on XLA/GPU for integer types + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8), [12, 12], + expected=np.array([[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], + dtype=np.uint8)) + class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase): From dab29eebfa2f80fe7739ca34cdfdbc64ef24a31c Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Fri, 2 Aug 2019 01:32:39 -0400 Subject: [PATCH 2/7] Remove the threshold on tile_size_x for column reduction from IsReductionFromOrToContiguousDimensions, so that it returns true regardless of tile_size_x. 
--- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 404d3347772..b5c197f85b2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -220,9 +220,9 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { } // For column reduction, the tile block is tize_size_y x tile_size_x, and we - // are reducing along tile_size_y. Both tile_size_x and tile_size_y need to be + // are reducing along tile_size_y. tile_size_y needs to be // large enough to make the tiling implementation efficient. - return dims_in_elem[2] >= kWarpSize && dims_in_elem[1] >= kWarpSize; + return dims_in_elem[1] >= kWarpSize; } std::pair GetReductionKindAndContiguousComponents( From 94c7471e3a29e44681cb40d8a0406026867fe299 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Fri, 2 Aug 2019 14:11:25 -0400 Subject: [PATCH 3/7] Minor change to one comment. --- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index b5c197f85b2..78f8e22a857 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -220,7 +220,7 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { } // For column reduction, the tile block is tize_size_y x tile_size_x, and we - // are reducing along tile_size_y. tile_size_y needs to be + // are reducing along tile_size_y. Only tile_size_y needs to be // large enough to make the tiling implementation efficient. 
return dims_in_elem[1] >= kWarpSize; } From 09131b078f73ee0eda525b016349e1177dd86484 Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Sat, 3 Aug 2019 02:16:17 -0400 Subject: [PATCH 4/7] Add a test to ensure that column reduction with small x_tile_size is still marked as KInput fusion instead of KLoop. --- .../gpu/tests/gpu_kernel_tiling_test.cc | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a12932f573b..a629bbb6787 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -520,6 +520,53 @@ TEST_F(GpuKernelTilingTest, EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); } +TEST_F(GpuKernelTilingTest, ColumnReductionSmallXTileSize) { + const char *const kHloString = R"( + HloModule Test + + %scalar_add_computation.1 { + %scalar_lhs.1 = f32[] parameter(0) + %scalar_rhs.1 = f32[] parameter(1) + ROOT %add.6 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1) + } + + ENTRY Test { + + %param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3) + %constant_661 = f16[] constant(0), metadata={op_type="Relu" op_name="Relu_19"} + %broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(f16[] %constant_661), dimensions={}, metadata={op_type="Relu" op_name="Relu_19"} + %compare.42 = pred[512,2,9,9]{1,3,2,0} compare(f16[512,2,9,9]{1,3,2,0} %param_3.241, f16[512,2,9,9]{1,3,2,0} %broadcast.695), direction=GT, metadata={op_type="ReluGrad" op_name="gradients/Relu_19_grad/ReluGrad"} + %param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2) + %select.40 = f16[512,2,9,9]{1,3,2,0} select(pred[512,2,9,9]{1,3,2,0} %compare.42, f16[512,2,9,9]{1,3,2,0} %param_2.401, f16[512,2,9,9]{1,3,2,0} %broadcast.695), metadata={op_type="ReluGrad" op_name="gradients/Relu_19_grad/ReluGrad"} + %convert.196 = f32[512,2,9,9]{1,3,2,0} 
convert(f16[512,2,9,9]{1,3,2,0} %select.40), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + %param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1) + %copy.335 = f16[512,2,9,9]{1,3,2,0} copy(f16[512,2,9,9]{1,3,2,0} %param_1.809), metadata={op_name="XLA_Args"} + %convert.218 = f32[512,2,9,9]{1,3,2,0} convert(f16[512,2,9,9]{1,3,2,0} %copy.335), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + %param_0.668 = f32[2]{0} parameter(0) + %broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(f32[2]{0} %param_0.668), dimensions={1}, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + %subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(f32[512,2,9,9]{1,3,2,0} %convert.218, f32[512,2,9,9]{1,3,2,0} %broadcast.687), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + %multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(f32[512,2,9,9]{1,3,2,0} %convert.196, f32[512,2,9,9]{1,3,2,0} %subtract.136), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + %constant_485 = f32[] constant(0), metadata={op_type="L2Loss" op_name="L2Loss_21"} + %reduce.139 = f32[2]{0} reduce(f32[512,2,9,9]{1,3,2,0} %multiply.579, f32[] %constant_485), dimensions={0,2,3}, to_apply=%scalar_add_computation.1, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + %reduce.140.clone.1 = f32[2]{0} reduce(f32[512,2,9,9]{1,3,2,0} %convert.196, f32[] %constant_485), dimensions={0,2,3}, to_apply=%scalar_add_computation.1, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} + ROOT %tuple.102 = 
(f32[2]{0}, f32[2]{0}) tuple(f32[2]{0} %reduce.139, f32[2]{0} %reduce.140.clone.1) +})"; + + // Check that four calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) + .ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: reduce.0.loop_header +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) { const char *const kHloString = R"( HloModule reduction From 6e02abed0bcbdd6a76daf995732c8d8f8553c93e Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Sat, 3 Aug 2019 02:22:16 -0400 Subject: [PATCH 5/7] Fixed comments in the previous commit. --- .../compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a629bbb6787..ef74e0c2937 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -552,7 +552,7 @@ TEST_F(GpuKernelTilingTest, ColumnReductionSmallXTileSize) { ROOT %tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(f32[2]{0} %reduce.139, f32[2]{0} %reduce.140.clone.1) })"; - // Check that four calls to llvm.nvvm.atomic are generated. + // Check that no loop is generated for reduction. auto hlo_module = ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) .ValueOrDie(); From 8d7fe0a1e276156c3aa855e6afddf9c0950a217b Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Sat, 3 Aug 2019 02:26:32 -0400 Subject: [PATCH 6/7] Fixed test name in the previous commit. 
--- .../compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index ef74e0c2937..86059575c85 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -520,7 +520,7 @@ TEST_F(GpuKernelTilingTest, EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); } -TEST_F(GpuKernelTilingTest, ColumnReductionSmallXTileSize) { +TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) { const char *const kHloString = R"( HloModule Test From cc55b7f34636b3c1de611033ca20b7e8746d1d9d Mon Sep 17 00:00:00 2001 From: Yongfeng Gu Date: Mon, 5 Aug 2019 12:02:39 -0400 Subject: [PATCH 7/7] Update testing hlo generated by hlo_converter. --- .../gpu/tests/gpu_kernel_tiling_test.cc | 48 +++++++++---------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 86059575c85..96b15af1804 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -524,33 +524,31 @@ TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) { const char *const kHloString = R"( HloModule Test - %scalar_add_computation.1 { - %scalar_lhs.1 = f32[] parameter(0) - %scalar_rhs.1 = f32[] parameter(1) - ROOT %add.6 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1) + scalar_add_computation.1 { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1) } - ENTRY Test { - - %param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3) - %constant_661 = f16[] constant(0), metadata={op_type="Relu" 
op_name="Relu_19"} - %broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(f16[] %constant_661), dimensions={}, metadata={op_type="Relu" op_name="Relu_19"} - %compare.42 = pred[512,2,9,9]{1,3,2,0} compare(f16[512,2,9,9]{1,3,2,0} %param_3.241, f16[512,2,9,9]{1,3,2,0} %broadcast.695), direction=GT, metadata={op_type="ReluGrad" op_name="gradients/Relu_19_grad/ReluGrad"} - %param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2) - %select.40 = f16[512,2,9,9]{1,3,2,0} select(pred[512,2,9,9]{1,3,2,0} %compare.42, f16[512,2,9,9]{1,3,2,0} %param_2.401, f16[512,2,9,9]{1,3,2,0} %broadcast.695), metadata={op_type="ReluGrad" op_name="gradients/Relu_19_grad/ReluGrad"} - %convert.196 = f32[512,2,9,9]{1,3,2,0} convert(f16[512,2,9,9]{1,3,2,0} %select.40), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - %param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1) - %copy.335 = f16[512,2,9,9]{1,3,2,0} copy(f16[512,2,9,9]{1,3,2,0} %param_1.809), metadata={op_name="XLA_Args"} - %convert.218 = f32[512,2,9,9]{1,3,2,0} convert(f16[512,2,9,9]{1,3,2,0} %copy.335), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - %param_0.668 = f32[2]{0} parameter(0) - %broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(f32[2]{0} %param_0.668), dimensions={1}, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - %subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(f32[512,2,9,9]{1,3,2,0} %convert.218, f32[512,2,9,9]{1,3,2,0} %broadcast.687), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - %multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(f32[512,2,9,9]{1,3,2,0} %convert.196, f32[512,2,9,9]{1,3,2,0} %subtract.136), metadata={op_type="FusedBatchNormGradV2" 
op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - %constant_485 = f32[] constant(0), metadata={op_type="L2Loss" op_name="L2Loss_21"} - %reduce.139 = f32[2]{0} reduce(f32[512,2,9,9]{1,3,2,0} %multiply.579, f32[] %constant_485), dimensions={0,2,3}, to_apply=%scalar_add_computation.1, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - %reduce.140.clone.1 = f32[2]{0} reduce(f32[512,2,9,9]{1,3,2,0} %convert.196, f32[] %constant_485), dimensions={0,2,3}, to_apply=%scalar_add_computation.1, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"} - ROOT %tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(f32[2]{0} %reduce.139, f32[2]{0} %reduce.140.clone.1) -})"; + param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3) + constant_661 = f16[] constant(0) + broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={} + compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT + param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2) + select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695) + convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40) + param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1) + copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809) + convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335) + param_0.668 = f32[2]{0} parameter(0) + broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1} + subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687) + multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136) + constant_485 = f32[] constant(0) + reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 + reduce.140.clone.1 = f32[2]{0} reduce(convert.196, constant_485), dimensions={0,2,3}, 
to_apply=scalar_add_computation.1 + ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1) + })"; // Check that no loop is generated for reduction. auto hlo_module =