[XLA:Python] Plumb xla_gpu_enable_fast_min_max into the XLA:Python client.
Disable it by default to get correct NaN semantics for min/max. Will fix https://github.com/google/jax/issues/1072 when deployed in jaxlib. PiperOrigin-RevId: 260567980
This commit is contained in:
parent
f7dc3f6961
commit
df2fbb8958
@ -425,7 +425,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
&DebugOptions::set_xla_cpu_fast_math_honor_nans)
|
||||
.def_property("xla_cpu_fast_math_honor_division",
|
||||
&DebugOptions::xla_cpu_fast_math_honor_division,
|
||||
&DebugOptions::set_xla_cpu_fast_math_honor_division);
|
||||
&DebugOptions::set_xla_cpu_fast_math_honor_division)
|
||||
.def_property("xla_gpu_enable_fast_min_max",
|
||||
&DebugOptions::xla_gpu_enable_fast_min_max,
|
||||
&DebugOptions::set_xla_gpu_enable_fast_min_max);
|
||||
|
||||
py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
|
||||
.def(py::init<>())
|
||||
|
@ -109,6 +109,7 @@ class LocalBackend(Backend):
|
||||
options.debug_options.xla_cpu_fast_math_honor_infs = True
|
||||
options.debug_options.xla_cpu_fast_math_honor_nans = True
|
||||
options.debug_options.xla_cpu_fast_math_honor_division = True
|
||||
options.debug_options.xla_gpu_enable_fast_min_max = False
|
||||
return _xla.LocalExecutable.Compile(c_computation,
|
||||
compile_options.argument_layouts,
|
||||
options, self.client,
|
||||
|
Loading…
x
Reference in New Issue
Block a user