[TF2XLA] Print a helpful log message when the argument which has to be constant is not a loop invariant
PiperOrigin-RevId: 354230507 Change-Id: I8010486d560e85f11bc57780c590d08a09b65120
This commit is contained in:
parent
da6545d303
commit
227e5f1e86
@ -113,20 +113,21 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
if (compile_time_const_arg_indices[i]) {
|
||||
// Check that this input is actually a loop invariant.
|
||||
// NOTE(srbs): Ideally this should raise an error if the loop body
|
||||
// requires the input at this index to be a compile time const but it is
|
||||
// not a loop invariant. However, that causes problems because const
|
||||
// analysis is performed for the entire graph (in the
|
||||
// MarkForCompilationPass for example) and not just for the ops
|
||||
// that will actually be run using XLA kernels. So we silently return
|
||||
// here and let the error be raised during the actual compilation of the
|
||||
// XLA graph.
|
||||
Node* arg_i = fbody->arg_nodes[i];
|
||||
Node* ret_i = fbody->ret_nodes[i];
|
||||
const Node* ret_i_input_0;
|
||||
TF_RETURN_IF_ERROR(ret_i->input_node(0, &ret_i_input_0));
|
||||
if (ret_i_input_0->id() == arg_i->id()) {
|
||||
const_input_idxs->push_back(i);
|
||||
} else {
|
||||
// TODO(b/178546817): Verify that it's OK and raise an error if we are
|
||||
// using this branch from jit_compile=True.
|
||||
VLOG(1) << "Argument " << i << " to while-loop "
|
||||
<< node.ShortDebugString()
|
||||
<< " has to be constant, but it's not a loop invariant, "
|
||||
"cluster compilation likely to fail at compile time: "
|
||||
<< arg_i->def().ShortDebugString() << " vs. "
|
||||
<< ret_i->def().ShortDebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user