From f096affea02c69d53aa7cf2cdba82846ad99a5a2 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Tue, 17 Nov 2020 12:14:19 -0800 Subject: [PATCH] Accept `None` as a trivial LazyExpr. PiperOrigin-RevId: 342917973 Change-Id: I3d0e1a2e46c614dbedafb51137b21d351f7ecc2c --- tensorflow/compiler/xla/python/jax_jit.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index b7d833e5948..af02c1ef0d4 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -375,10 +375,13 @@ namespace { // // Expects *only* instances of `DeviceArray`. bool HasTrivialLazyExpr(py::handle device_array) { + auto lexpr = py::getattr(device_array, "_lazy_expr"); + if (lexpr.is_none()) { + return true; + } + static const auto* lazy_module = new py::module(py::module::import("jax.lazy")); - - auto lexpr = py::getattr(device_array, "_lazy_expr"); auto input = py::getattr(lexpr, "input"); if (!input.get_type().is(lazy_module->attr("ArrayVar"))) { return false;