diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index ddd46c167bf..f61d8478177 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -261,9 +261,8 @@ class TFETensorTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes def testCompatibility(self): - # TODO(nareshmodi): uint32, uint64 are not correctly handled in graph mode. integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16] + dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64] # Floats are not compatible with ints for t in integer_types: diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 0582e986032..879addde538 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -339,11 +339,29 @@ _TF_TO_IS_OK = { dtypes.string: [_FilterStr], dtypes.uint16: [_FilterInt], dtypes.uint8: [_FilterInt], + dtypes.uint32: [_FilterInt], + dtypes.uint64: [_FilterInt], } def _AssertCompatible(values, dtype): - fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor]) + if dtype is None: + fn_list = [_FilterNotTensor] + else: + try: + fn_list = _TF_TO_IS_OK[dtype] + except KeyError: + # There isn't a specific fn_list, so we try to do the best possible. + if dtype.is_integer: + fn_list = [_FilterInt] + elif dtype.is_floating: + fn_list = [_FilterFloat] + elif dtype.is_complex: + fn_list = [_FilterComplex] + elif dtype.is_quantized: + fn_list = [_FilterInt, _FilterTuple] + else: + fn_list = [_FilterNotTensor] mismatch = _FirstNotNone([fn(values) for fn in fn_list]) if mismatch is not None: if dtype is None: