Improve error message when attributes are invalid

If we are unable to find any valid devices for a node, we can do a quick
check to see if the node is even valid as per the op definition. This
greatly improves the eager error message since there is no point in
listing all the available kernels across all devices if we know none of
them can match.

Previous:
NotFoundError: Could not find device for node: {{node GatherV2}} = GatherV2[Taxis=DT_INT32, Tindices=DT_FLOAT, Tparams=DT_INT32, batch_dims=0]
All kernels registered for op GatherV2:
  device='CPU'; Tparams in [DT_INT64]; Tindices in [DT_INT32]
  device='CPU'; Tparams in [DT_INT64]; Tindices in [DT_INT64]
  device='CPU'; Tparams in [DT_INT32]; Tindices in [DT_INT32]

... Many more registrations ...

New:
InvalidArgumentError: Value for attr 'Tindices' of float is not in the list of allowed values: int32, int64
	; NodeDef: {{node GatherV2}}; ...
PiperOrigin-RevId: 314963092
Change-Id: I8072e7ba9e6d316570a536780d78992691e620f1
This commit is contained in:
Gaurav Jain 2020-06-05 11:25:01 -07:00 committed by TensorFlower Gardener
parent 86c745a112
commit a565c473c1
2 changed files with 14 additions and 0 deletions

View File

@ -1497,6 +1497,13 @@ Status SupportedDeviceTypesForNode(
} }
} }
} }
// If we were unable to find any valid devices let's validate if the node is
// even valid.
if (prioritized_device_types->empty()) {
TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def));
}
std::sort(prioritized_device_types->begin(), std::sort(prioritized_device_types->begin(),
prioritized_device_types->end(), prioritized_device_types->end(),
[](const std::pair<DeviceType, int32>& a, [](const std::pair<DeviceType, int32>& a,

View File

@ -25,6 +25,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
@ -212,6 +213,12 @@ class GatherTest(test.TestCase, parameterized.TestCase):
gather_t = array_ops.gather(params, indices, axis=axis) gather_t = array_ops.gather(params, indices, axis=axis)
self.assertEqual(None, gather_t.shape) self.assertEqual(None, gather_t.shape)
def testBadIndicesType(self):
with self.assertRaisesRegex(
(TypeError, errors.InvalidArgumentError),
"float.* not in.* list of allowed values: int32, int64"):
self.evaluate(array_ops.gather([0], 0.))
@test_util.disable_xla( @test_util.disable_xla(
"Assertion inside an op is not supported in XLA. Instead XLA clamps the " "Assertion inside an op is not supported in XLA. Instead XLA clamps the "
"index to be in bounds and returns the indexed value there (Don't rely " "index to be in bounds and returns the indexed value there (Don't rely "