From a565c473c14da0c1bb35bd1771bdc195a72f7982 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Fri, 5 Jun 2020 11:25:01 -0700 Subject: [PATCH] Improve error message when attributes are invalid If we are unable to find any valid devices for a node, we can do a quick check to see if the node is even valid as per the op definition. This greatly improves the eager error message since there is no point in listing all the available kernels across all devices if we know none of them can match. Previous: NotFoundError: Could not find device for node: {{node GatherV2}} = GatherV2[Taxis=DT_INT32, Tindices=DT_FLOAT, Tparams=DT_INT32, batch_dims=0] All kernels registered for op GatherV2: device='CPU'; Tparams in [DT_INT64]; Tindices in [DT_INT32] device='CPU'; Tparams in [DT_INT64]; Tindices in [DT_INT64] device='CPU'; Tparams in [DT_INT32]; Tindices in [DT_INT32] ... Many more registrations ... New: InvalidArgumentError: Value for attr 'Tindices' of float is not in the list of allowed values: int32, int64 ; NodeDef: {{node GatherV2}}; ... PiperOrigin-RevId: 314963092 Change-Id: I8072e7ba9e6d316570a536780d78992691e620f1 --- tensorflow/core/framework/op_kernel.cc | 7 +++++++ tensorflow/python/kernel_tests/gather_op_test.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 2e7747380b4..abf73cb57df 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -1497,6 +1497,13 @@ Status SupportedDeviceTypesForNode( } } } + + // If we were unable to find any valid devices let's validate if the node is + // even valid. + if (prioritized_device_types->empty()) { + TF_RETURN_IF_ERROR(ValidateNodeDef(def, op_reg_data->op_def)); + } + std::sort(prioritized_device_types->begin(), prioritized_device_types->end(), [](const std::pair& a, diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index b966110963c..c5b42dd60a7 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -212,6 +213,12 @@ class GatherTest(test.TestCase, parameterized.TestCase): gather_t = array_ops.gather(params, indices, axis=axis) self.assertEqual(None, gather_t.shape) + def testBadIndicesType(self): + with self.assertRaisesRegex( + (TypeError, errors.InvalidArgumentError), + "float.* not in.* list of allowed values: int32, int64"): + self.evaluate(array_ops.gather([0], 0.)) + @test_util.disable_xla( "Assertion inside an op is not supported in XLA. Instead XLA clamps the " "index to be in bounds and returns the indexed value there (Don't rely "