Accept None as a trivial LazyExpr.

PiperOrigin-RevId: 342917973
Change-Id: I3d0e1a2e46c614dbedafb51137b21d351f7ecc2c
This commit is contained in:
Jean-Baptiste Lespiau 2020-11-17 12:14:19 -08:00 committed by TensorFlower Gardener
parent f8e2de9d0c
commit f096affea0

View File

@ -375,10 +375,13 @@ namespace {
//
// Expects *only* instances of `DeviceArray`.
bool HasTrivialLazyExpr(py::handle device_array) {
auto lexpr = py::getattr(device_array, "_lazy_expr");
if (lexpr.is_none()) {
return true;
}
static const auto* lazy_module =
new py::module(py::module::import("jax.lazy"));
auto lexpr = py::getattr(device_array, "_lazy_expr");
auto input = py::getattr(lexpr, "input");
if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
return false;