Don't allow floats to be converted to uint32/64
tf.constant(0, dtype=tf.uint32) < 0.5 would return False before this commit. PiperOrigin-RevId: 220871701
This commit is contained in:
parent
75aab0042e
commit
d00c00c109
@ -261,9 +261,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCompatibility(self):
|
||||
# TODO(nareshmodi): uint32, uint64 are not correctly handled in graph mode.
|
||||
integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
||||
dtypes.uint8, dtypes.uint16]
|
||||
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
|
||||
|
||||
# Floats are not compatible with ints
|
||||
for t in integer_types:
|
||||
|
@ -339,11 +339,29 @@ _TF_TO_IS_OK = {
|
||||
dtypes.string: [_FilterStr],
|
||||
dtypes.uint16: [_FilterInt],
|
||||
dtypes.uint8: [_FilterInt],
|
||||
dtypes.uint32: [_FilterInt],
|
||||
dtypes.uint64: [_FilterInt],
|
||||
}
|
||||
|
||||
|
||||
def _AssertCompatible(values, dtype):
|
||||
fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor])
|
||||
if dtype is None:
|
||||
fn_list = [_FilterNotTensor]
|
||||
else:
|
||||
try:
|
||||
fn_list = _TF_TO_IS_OK[dtype]
|
||||
except KeyError:
|
||||
# There isn't a specific fn_list, so we try to do the best possible.
|
||||
if dtype.is_integer:
|
||||
fn_list = [_FilterInt]
|
||||
elif dtype.is_floating:
|
||||
fn_list = [_FilterFloat]
|
||||
elif dtype.is_complex:
|
||||
fn_list = [_FilterComplex]
|
||||
elif dtype.is_quantized:
|
||||
fn_list = [_FilterInt, _FilterTuple]
|
||||
else:
|
||||
fn_list = [_FilterNotTensor]
|
||||
mismatch = _FirstNotNone([fn(values) for fn in fn_list])
|
||||
if mismatch is not None:
|
||||
if dtype is None:
|
||||
|
Loading…
Reference in New Issue
Block a user