Improve error message when attributes are invalid
If we are unable to find any valid devices for a node, we can do a quick check to see if the node is even valid as per the op definition. This greatly improves the eager error message since there is no point in listing all the available kernels across all devices if we know none of them can match. Previous: NotFoundError: Could not find device for node: {{node GatherV2}} = GatherV2[Taxis=DT_INT32, Tindices=DT_FLOAT, Tparams=DT_INT32, batch_dims=0] All kernels registered for op GatherV2: device='CPU'; Tparams in [DT_INT64]; Tindices in [DT_INT32] device='CPU'; Tparams in [DT_INT64]; Tindices in [DT_INT64] device='CPU'; Tparams in [DT_INT32]; Tindices in [DT_INT32] ... Many more registrations ... New: InvalidArgumentError: Value for attr 'Tindices' of float is not in the list of allowed values: int32, int64 ; NodeDef: {{node GatherV2}}; ... PiperOrigin-RevId: 314963092 Change-Id: I8072e7ba9e6d316570a536780d78992691e620f1
This commit is contained in:
parent
86c745a112
commit
a565c473c1
@ -1497,6 +1497,13 @@ Status SupportedDeviceTypesForNode(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we were unable to find any valid devices let's validate if the node is
|
||||
// even valid.
|
||||
if (prioritized_device_types->empty()) {
|
||||
TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
|
||||
}
|
||||
|
||||
std::sort(prioritized_device_types->begin(),
|
||||
prioritized_device_types->end(),
|
||||
[](const std::pair<DeviceType, int32>& a,
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -212,6 +213,12 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
gather_t = array_ops.gather(params, indices, axis=axis)
|
||||
self.assertEqual(None, gather_t.shape)
|
||||
|
||||
def testBadIndicesType(self):
|
||||
with self.assertRaisesRegex(
|
||||
(TypeError, errors.InvalidArgumentError),
|
||||
"float.* not in.* list of allowed values: int32, int64"):
|
||||
self.evaluate(array_ops.gather([0], 0.))
|
||||
|
||||
@test_util.disable_xla(
|
||||
"Assertion inside an op is not supported in XLA. Instead XLA clamps the "
|
||||
"index to be in bounds and returns the indexed value there (Don't rely "
|
||||
|
Loading…
Reference in New Issue
Block a user