Fallback when faced with other runtimes.

PiperOrigin-RevId: 336173651
Change-Id: Ia585abd06c4f24f77d5b9cabc4487d238b4bb4ac
This commit is contained in:
Jean-Baptiste Lespiau 2020-10-08 15:08:37 -07:00 committed by TensorFlower Gardener
parent ebe1d5c6c0
commit d483dea0c6

View File

@ -280,6 +280,8 @@ class CompiledFunction {
py::object out_and_fastpath_data);
bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }
bool always_fallback_to_python_ = false;
const py::function fun_; // The Python function to jit.
// See JAX _cpp_jit in api.py for documentation.
const py::function cache_miss_;
@ -762,6 +764,9 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
}
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
if (always_fallback_to_python_) {
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
// Delayed values are retrieved on the first call to `Call`.
if (!default_device_) {
// As we are calling Python code, that may release the GIL, we first hold
@ -775,8 +780,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
if (!default_device_) {
py::object device_and_is_committed = get_device_();
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
device_and_is_committed.attr("default_device"));
try {
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
device_and_is_committed.attr("default_device"));
} catch (const py::cast_error& e) {
// Pathways and Cloud TPU 2VM runtime.
always_fallback_to_python_ = true;
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
}
default_pyclient_ = default_pydevice_.client;
default_device_ = default_pydevice_.contents;
is_committed_ =