[XLA:Python] Plumb xla_gpu_enable_fast_min_max into the XLA:Python client.

Disable it by default to get correct NaN semantics for min/max.

Will fix https://github.com/google/jax/issues/1072 when deployed in jaxlib.
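
As a rough sketch of the intended effect (not part of this change, and assuming a jaxlib built with the new default), GPU min/max should propagate NaNs the way NumPy does:

    import numpy as np
    import jax.numpy as jnp

    # IEEE-style NaN propagation: the NaN operand wins.
    print(np.minimum(np.nan, 1.0))                 # nan
    # With xla_gpu_enable_fast_min_max=False this matches NumPy on GPU;
    # with the fast path enabled it could return 1.0 instead.
    print(jnp.minimum(np.float32(np.nan), 1.0))    # nan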

PiperOrigin-RevId: 260567980
Author: Peter Hawkins, 2019-07-29 13:30:03 -07:00 (committed by TensorFlower Gardener)
parent f7dc3f6961
commit df2fbb8958
2 changed files with 5 additions and 1 deletion


@@ -425,7 +425,10 @@ PYBIND11_MODULE(xla_extension, m) {
                     &DebugOptions::set_xla_cpu_fast_math_honor_nans)
       .def_property("xla_cpu_fast_math_honor_division",
                     &DebugOptions::xla_cpu_fast_math_honor_division,
-                    &DebugOptions::set_xla_cpu_fast_math_honor_division);
+                    &DebugOptions::set_xla_cpu_fast_math_honor_division)
+      .def_property("xla_gpu_enable_fast_min_max",
+                    &DebugOptions::xla_gpu_enable_fast_min_max,
+                    &DebugOptions::set_xla_gpu_enable_fast_min_max);
   py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
       .def(py::init<>())
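
The new property mirrors the existing fast-math bindings, so the flag is readable and writable from Python through the xla_extension module. A minimal sketch of that round trip, assuming ExecutableBuildOptions exposes a mutable debug_options the way the Python hunk below uses it:

    # Import path used by xla_client at the time (assumption).
    from tensorflow.compiler.xla.python import xla_extension as _xla

    options = _xla.ExecutableBuildOptions()
    # Setter and getter generated by the .def_property added above.
    options.debug_options.xla_gpu_enable_fast_min_max = False
    assert not options.debug_options.xla_gpu_enable_fast_min_max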


@@ -109,6 +109,7 @@ class LocalBackend(Backend):
     options.debug_options.xla_cpu_fast_math_honor_infs = True
     options.debug_options.xla_cpu_fast_math_honor_nans = True
     options.debug_options.xla_cpu_fast_math_honor_division = True
+    options.debug_options.xla_gpu_enable_fast_min_max = False
     return _xla.LocalExecutable.Compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,