Don't allow floats to be converted to uint32/64
tf.constant(0, dtype=tf.uint32) < 0.5 would return False before this commit. PiperOrigin-RevId: 220871701
This commit is contained in:
parent
75aab0042e
commit
d00c00c109
@ -261,9 +261,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def testCompatibility(self):
|
def testCompatibility(self):
|
||||||
# TODO(nareshmodi): uint32, uint64 are not correctly handled in graph mode.
|
|
||||||
integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
integer_types = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
|
||||||
dtypes.uint8, dtypes.uint16]
|
dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64]
|
||||||
|
|
||||||
# Floats are not compatible with ints
|
# Floats are not compatible with ints
|
||||||
for t in integer_types:
|
for t in integer_types:
|
||||||
|
@ -339,11 +339,29 @@ _TF_TO_IS_OK = {
|
|||||||
dtypes.string: [_FilterStr],
|
dtypes.string: [_FilterStr],
|
||||||
dtypes.uint16: [_FilterInt],
|
dtypes.uint16: [_FilterInt],
|
||||||
dtypes.uint8: [_FilterInt],
|
dtypes.uint8: [_FilterInt],
|
||||||
|
dtypes.uint32: [_FilterInt],
|
||||||
|
dtypes.uint64: [_FilterInt],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _AssertCompatible(values, dtype):
|
def _AssertCompatible(values, dtype):
|
||||||
fn_list = _TF_TO_IS_OK.get(dtype, [_FilterNotTensor])
|
if dtype is None:
|
||||||
|
fn_list = [_FilterNotTensor]
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
fn_list = _TF_TO_IS_OK[dtype]
|
||||||
|
except KeyError:
|
||||||
|
# There isn't a specific fn_list, so we try to do the best possible.
|
||||||
|
if dtype.is_integer:
|
||||||
|
fn_list = [_FilterInt]
|
||||||
|
elif dtype.is_floating:
|
||||||
|
fn_list = [_FilterFloat]
|
||||||
|
elif dtype.is_complex:
|
||||||
|
fn_list = [_FilterComplex]
|
||||||
|
elif dtype.is_quantized:
|
||||||
|
fn_list = [_FilterInt, _FilterTuple]
|
||||||
|
else:
|
||||||
|
fn_list = [_FilterNotTensor]
|
||||||
mismatch = _FirstNotNone([fn(values) for fn in fn_list])
|
mismatch = _FirstNotNone([fn(values) for fn in fn_list])
|
||||||
if mismatch is not None:
|
if mismatch is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user