Raise better error message if input to eager execute
contains keras symbolic tensor. PiperOrigin-RevId: 256077279
This commit is contained in:
parent
99d4a96121
commit
a59b74ccaa
@ -66,8 +66,13 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
|
||||
message = e.message
|
||||
six.raise_from(core._status_to_exception(e.code, message), None)
|
||||
except TypeError as e:
|
||||
if any(ops._is_keras_symbolic_tensor(x) for x in inputs):
|
||||
raise core._SymbolicException
|
||||
keras_symbolic_tensors = [
|
||||
x for x in inputs if ops._is_keras_symbolic_tensor(x)
|
||||
]
|
||||
if keras_symbolic_tensors:
|
||||
raise core._SymbolicException(
|
||||
"Inputs to eager execution function cannot be Keras symbolic "
|
||||
"tensors, but found {}".format(keras_symbolic_tensors))
|
||||
raise e
|
||||
# pylint: enable=protected-access
|
||||
return tensors
|
||||
@ -202,9 +207,12 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
|
||||
|
||||
# TODO(slebedev): consider removing this as it leaks a Keras concept.
|
||||
# pylint: disable=protected-access
|
||||
if any(ops._is_keras_symbolic_tensor(x) for x in ret):
|
||||
keras_symbolic_tensors = [x for x in ret if
|
||||
ops._is_keras_symbolic_tensor(x)]
|
||||
if keras_symbolic_tensors:
|
||||
raise core._SymbolicException(
|
||||
"Using the symbolic output of a Keras layer during eager execution.")
|
||||
"Using symbolic output of a Keras layer during eager execution "
|
||||
"{}".format(keras_symbolic_tensors))
|
||||
# pylint: enable=protected-access
|
||||
return dtype.as_datatype_enum, ret
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user