[XLA:CPU] Run ScatterExpander much earlier in CPU pipeline.

Previously, ScatterExpander ran after fusion (!), so none of the instructions
it emitted could ever be fused.
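
To illustrate why the order matters, here is a minimal toy sketch. It is not XLA code: `ToyModule`, `ExpandScatter`, `Fuse`, `CountUnfused`, and `DemoOrdering` are hypothetical names, a module is modeled as a flat list of op names, and the expansion of `scatter` into `select` plus `add` is an invented simplification of what the real expander emits.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for an HLO module: a flat list of op names (hypothetical).
using ToyModule = std::vector<std::string>;

// Ops that the toy fusion pass is allowed to merge (an assumption).
bool IsFusible(const std::string& op) {
  return op == "add" || op == "multiply" || op == "select";
}

// Toy expander: rewrites each "scatter" into simpler, fusible ops.
ToyModule ExpandScatter(const ToyModule& in) {
  ToyModule out;
  for (const std::string& op : in) {
    if (op == "scatter") {
      out.push_back("select");
      out.push_back("add");
    } else {
      out.push_back(op);
    }
  }
  return out;
}

// Toy fusion: collapses each maximal run of adjacent fusible ops into a
// single "fusion" op and leaves every other op untouched.
ToyModule Fuse(const ToyModule& in) {
  ToyModule out;
  bool in_run = false;
  for (const std::string& op : in) {
    if (IsFusible(op)) {
      if (!in_run) out.push_back("fusion");
      in_run = true;
    } else {
      out.push_back(op);
      in_run = false;
    }
  }
  return out;
}

// Counts fusible ops that are still standing alone after the pipeline ran.
std::size_t CountUnfused(const ToyModule& m) {
  std::size_t n = 0;
  for (const std::string& op : m) {
    if (IsFusible(op)) ++n;
  }
  return n;
}

// Runs both pass orders on the same input and checks the difference.
void DemoOrdering() {
  const ToyModule m = {"multiply", "scatter", "add"};
  // Expander first: everything it emits is visible to fusion.
  assert(CountUnfused(Fuse(ExpandScatter(m))) == 0);
  // Fusion first: the expander's output arrives too late to be fused.
  assert(CountUnfused(ExpandScatter(Fuse(m))) == 2);
}
```

In this toy model, running the expander first collapses the whole module into a single `fusion` op, while running it after fusion leaves the expanded `select` and `add` unfused.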

On my machine, this is good for a 3.2/2.6 = 1.2x speedup on the testcase from
https://github.com/google/jax/issues/695.

PiperOrigin-RevId: 248950865
Justin Lebar 2019-05-19 11:44:16 -07:00 committed by TensorFlower Gardener
parent abdee716ce
commit ee4657facf

@@ -297,6 +297,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                         /*allow_mixed_precision=*/false);
+  pass.AddPass<ScatterExpander>();
   pass.AddPass<BatchNormExpander>(
       /*rewrite_training_op=*/true,
       /*rewrite_inference_op=*/true,
@@ -340,8 +341,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<CpuInstructionFusion>();
-  pipeline.AddPass<ScatterExpander>();
   ReducePrecisionInsertion::AddPasses(
       &pipeline, module->config().debug_options(),
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);