[TF2XLA] Print a helpful log message when the argument which has to be constant is not a loop invariant

PiperOrigin-RevId: 354230507
Change-Id: I8010486d560e85f11bc57780c590d08a09b65120
This commit is contained in:
George Karpenkov 2021-01-27 20:13:27 -08:00 committed by TensorFlower Gardener
parent da6545d303
commit 227e5f1e86

View File

@ -113,20 +113,21 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
for (int i = 0; i < num_inputs; i++) {
if (compile_time_const_arg_indices[i]) {
// Check that this input is actually a loop invariant.
// NOTE(srbs): Ideally this should raise an error if the loop body
// requires the input at this index to be a compile time const but it is
// not a loop invariant. However, that causes problems because const
// analysis is performed for the entire graph (in the
// MarkForCompilationPass for example) and not just for the ops
// that will actually be run using XLA kernels. So we silently return
// here and let the error be raised during the actual compilation of the
// XLA graph.
Node* arg_i = fbody->arg_nodes[i];
Node* ret_i = fbody->ret_nodes[i];
const Node* ret_i_input_0;
TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
if (ret_i_input_0->id() == arg_i->id()) {
const_input_idxs->push_back(i);
} else {
// TODO(b/178546817): Verify that it's OK and raise an error if we are
// using this branch from jit_compile=True.
VLOG(1) << "Argument " << i << " to while-loop "
<< node.ShortDebugString()
<< " has to be constant, but it's not a loop invariant, "
"cluster compilation likely to fail at compile time: "
<< arg_i->def().ShortDebugString() << " vs. "
<< ret_i->def().ShortDebugString();
}
}
}