Update testing hlo generated by hlo_converter.
This commit is contained in:
parent
8d7fe0a1e2
commit
cc55b7f346
@ -524,33 +524,31 @@ TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) {
|
||||
const char *const kHloString = R"(
|
||||
HloModule Test
|
||||
|
||||
%scalar_add_computation.1 {
|
||||
%scalar_lhs.1 = f32[] parameter(0)
|
||||
%scalar_rhs.1 = f32[] parameter(1)
|
||||
ROOT %add.6 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
|
||||
scalar_add_computation.1 {
|
||||
scalar_lhs.1 = f32[] parameter(0)
|
||||
scalar_rhs.1 = f32[] parameter(1)
|
||||
ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1)
|
||||
}
|
||||
|
||||
ENTRY Test {
|
||||
|
||||
%param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3)
|
||||
%constant_661 = f16[] constant(0), metadata={op_type="Relu" op_name="Relu_19"}
|
||||
%broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(f16[] %constant_661), dimensions={}, metadata={op_type="Relu" op_name="Relu_19"}
|
||||
%compare.42 = pred[512,2,9,9]{1,3,2,0} compare(f16[512,2,9,9]{1,3,2,0} %param_3.241, f16[512,2,9,9]{1,3,2,0} %broadcast.695), direction=GT, metadata={op_type="ReluGrad" op_name="gradients/Relu_19_grad/ReluGrad"}
|
||||
%param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2)
|
||||
%select.40 = f16[512,2,9,9]{1,3,2,0} select(pred[512,2,9,9]{1,3,2,0} %compare.42, f16[512,2,9,9]{1,3,2,0} %param_2.401, f16[512,2,9,9]{1,3,2,0} %broadcast.695), metadata={op_type="ReluGrad" op_name="gradients/Relu_19_grad/ReluGrad"}
|
||||
%convert.196 = f32[512,2,9,9]{1,3,2,0} convert(f16[512,2,9,9]{1,3,2,0} %select.40), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
%param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1)
|
||||
%copy.335 = f16[512,2,9,9]{1,3,2,0} copy(f16[512,2,9,9]{1,3,2,0} %param_1.809), metadata={op_name="XLA_Args"}
|
||||
%convert.218 = f32[512,2,9,9]{1,3,2,0} convert(f16[512,2,9,9]{1,3,2,0} %copy.335), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
%param_0.668 = f32[2]{0} parameter(0)
|
||||
%broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(f32[2]{0} %param_0.668), dimensions={1}, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
%subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(f32[512,2,9,9]{1,3,2,0} %convert.218, f32[512,2,9,9]{1,3,2,0} %broadcast.687), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
%multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(f32[512,2,9,9]{1,3,2,0} %convert.196, f32[512,2,9,9]{1,3,2,0} %subtract.136), metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
%constant_485 = f32[] constant(0), metadata={op_type="L2Loss" op_name="L2Loss_21"}
|
||||
%reduce.139 = f32[2]{0} reduce(f32[512,2,9,9]{1,3,2,0} %multiply.579, f32[] %constant_485), dimensions={0,2,3}, to_apply=%scalar_add_computation.1, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
%reduce.140.clone.1 = f32[2]{0} reduce(f32[512,2,9,9]{1,3,2,0} %convert.196, f32[] %constant_485), dimensions={0,2,3}, to_apply=%scalar_add_computation.1, metadata={op_type="FusedBatchNormGradV2" op_name="gradients/batch_normalization_19/FusedBatchNormV2_grad/FusedBatchNormGradV2"}
|
||||
ROOT %tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(f32[2]{0} %reduce.139, f32[2]{0} %reduce.140.clone.1)
|
||||
})";
|
||||
param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3)
|
||||
constant_661 = f16[] constant(0)
|
||||
broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={}
|
||||
compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT
|
||||
param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2)
|
||||
select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695)
|
||||
convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40)
|
||||
param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1)
|
||||
copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809)
|
||||
convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335)
|
||||
param_0.668 = f32[2]{0} parameter(0)
|
||||
broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1}
|
||||
subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687)
|
||||
multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136)
|
||||
constant_485 = f32[] constant(0)
|
||||
reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1
|
||||
reduce.140.clone.1 = f32[2]{0} reduce(convert.196, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1
|
||||
ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1)
|
||||
})";
|
||||
|
||||
// Check that no loop is generated for reduction.
|
||||
auto hlo_module =
|
||||
|
Loading…
x
Reference in New Issue
Block a user