Raise better error message if input to eager execute

contains keras symbolic tensor.

PiperOrigin-RevId: 256077279
This commit is contained in:
Zhenyu Tan 2019-07-01 18:36:26 -07:00 committed by TensorFlower Gardener
parent 99d4a96121
commit a59b74ccaa

View File

@ -66,8 +66,13 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
message = e.message
six.raise_from(core._status_to_exception(e.code, message), None)
except TypeError as e:
if any(ops._is_keras_symbolic_tensor(x) for x in inputs):
raise core._SymbolicException
keras_symbolic_tensors = [
x for x in inputs if ops._is_keras_symbolic_tensor(x)
]
if keras_symbolic_tensors:
raise core._SymbolicException(
"Inputs to eager execution function cannot be Keras symbolic "
"tensors, but found {}".format(keras_symbolic_tensors))
raise e
# pylint: enable=protected-access
return tensors
@ -202,9 +207,12 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
# TODO(slebedev): consider removing this as it leaks a Keras concept.
# pylint: disable=protected-access
if any(ops._is_keras_symbolic_tensor(x) for x in ret):
keras_symbolic_tensors = [x for x in ret if
ops._is_keras_symbolic_tensor(x)]
if keras_symbolic_tensors:
raise core._SymbolicException(
"Using the symbolic output of a Keras layer during eager execution.")
"Using symbolic output of a Keras layer during eager execution "
"{}".format(keras_symbolic_tensors))
# pylint: enable=protected-access
return dtype.as_datatype_enum, ret