Accept None as a trivial LazyExpr.
PiperOrigin-RevId: 342917973 Change-Id: I3d0e1a2e46c614dbedafb51137b21d351f7ecc2c
This commit is contained in:
parent
f8e2de9d0c
commit
f096affea0
@ -375,10 +375,13 @@ namespace {
|
|||||||
//
|
//
|
||||||
// Expects *only* instances of `DeviceArray`.
|
// Expects *only* instances of `DeviceArray`.
|
||||||
bool HasTrivialLazyExpr(py::handle device_array) {
|
bool HasTrivialLazyExpr(py::handle device_array) {
|
||||||
|
auto lexpr = py::getattr(device_array, "_lazy_expr");
|
||||||
|
if (lexpr.is_none()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static const auto* lazy_module =
|
static const auto* lazy_module =
|
||||||
new py::module(py::module::import("jax.lazy"));
|
new py::module(py::module::import("jax.lazy"));
|
||||||
|
|
||||||
auto lexpr = py::getattr(device_array, "_lazy_expr");
|
|
||||||
auto input = py::getattr(lexpr, "input");
|
auto input = py::getattr(lexpr, "input");
|
||||||
if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
|
if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user