Don't allow floats to be converted to uint32/64

tf.constant(0, dtype=tf.uint32) < 0.5 would return False before this commit.

PiperOrigin-RevId: 220871701
This commit is contained in:
Akshay Modi 2018-11-09 15:05:42 -08:00 committed by TensorFlower Gardener
parent 75aab0042e
commit d00c00c109
2 changed files with 20 additions and 3 deletions

View File

@ -261,9 +261,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
@test_util.run_in_graph_and_eager_modes
def testCompatibility(self):
# TODO(nareshmodi): uint32, uint64 are not correctly handled in graph mode.
integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
dtypes.uint8, dtypes.uint16]
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
# Floats are not compatible with ints
for t in integer_types:

View File

@ -339,11 +339,29 @@ _TF_TO_IS_OK = {
dtypes.string: [_FilterStr],
dtypes.uint16: [_FilterInt],
dtypes.uint8: [_FilterInt],
dtypes.uint32: [_FilterInt],
dtypes.uint64: [_FilterInt],
}
def _AssertCompatible(values, dtype):
fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor])
if dtype is None:
fn_list = [_FilterNotTensor]
else:
try:
fn_list = _TF_TO_IS_OK[dtype]
except KeyError:
# There isn't a specific fn_list, so we try to do the best possible.
if dtype.is_integer:
fn_list = [_FilterInt]
elif dtype.is_floating:
fn_list = [_FilterFloat]
elif dtype.is_complex:
fn_list = [_FilterComplex]
elif dtype.is_quantized:
fn_list = [_FilterInt, _FilterTuple]
else:
fn_list = [_FilterNotTensor]
mismatch = _FirstNotNone([fn(values) for fn in fn_list])
if mismatch is not None:
if dtype is None: