Fallback when faced with other runtimes.
PiperOrigin-RevId: 336173651 Change-Id: Ia585abd06c4f24f77d5b9cabc4487d238b4bb4ac
This commit is contained in:
parent
ebe1d5c6c0
commit
d483dea0c6
@ -280,6 +280,8 @@ class CompiledFunction {
|
|||||||
py::object out_and_fastpath_data);
|
py::object out_and_fastpath_data);
|
||||||
bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }
|
bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }
|
||||||
|
|
||||||
|
bool always_fallback_to_python_ = false;
|
||||||
|
|
||||||
const py::function fun_; // The Python function to jit.
|
const py::function fun_; // The Python function to jit.
|
||||||
// See JAX _cpp_jit in api.py for documentation.
|
// See JAX _cpp_jit in api.py for documentation.
|
||||||
const py::function cache_miss_;
|
const py::function cache_miss_;
|
||||||
@ -762,6 +764,9 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
|
|||||||
}
|
}
|
||||||
|
|
||||||
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
||||||
|
if (always_fallback_to_python_) {
|
||||||
|
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
|
||||||
|
}
|
||||||
// Delayed values are retrieved on the first call to `Call`.
|
// Delayed values are retrieved on the first call to `Call`.
|
||||||
if (!default_device_) {
|
if (!default_device_) {
|
||||||
// As we are calling Python code, that may release the GIL, we first hold
|
// As we are calling Python code, that may release the GIL, we first hold
|
||||||
@ -775,8 +780,14 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
|
|||||||
jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
|
jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
|
||||||
if (!default_device_) {
|
if (!default_device_) {
|
||||||
py::object device_and_is_committed = get_device_();
|
py::object device_and_is_committed = get_device_();
|
||||||
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
|
try {
|
||||||
device_and_is_committed.attr("default_device"));
|
default_pydevice_ = py::cast<ClientAndPtr<PjRtDevice>>(
|
||||||
|
device_and_is_committed.attr("default_device"));
|
||||||
|
} catch (const py::cast_error& e) {
|
||||||
|
// Pathways and Cloud TPU 2VM runtime.
|
||||||
|
always_fallback_to_python_ = true;
|
||||||
|
return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
|
||||||
|
}
|
||||||
default_pyclient_ = default_pydevice_.client;
|
default_pyclient_ = default_pydevice_.client;
|
||||||
default_device_ = default_pydevice_.contents;
|
default_device_ = default_pydevice_.contents;
|
||||||
is_committed_ =
|
is_committed_ =
|
||||||
|
Loading…
Reference in New Issue
Block a user