Internal change
PiperOrigin-RevId: 290967854 Change-Id: I3c2d64e64414a805555bd3a5228d277199163ce3
This commit is contained in:
parent
5708850f62
commit
c4b3910052
|
@ -213,13 +213,6 @@ class CastOpTest(test.TestCase):
|
||||||
err = gradient_checker.compute_gradient_error(x, [], y, [])
|
err = gradient_checker.compute_gradient_error(x, [], y, [])
|
||||||
self.assertLess(err, 1e-3)
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
def testPythonDataTypes(self):
|
|
||||||
with self.cached_session():
|
|
||||||
# GitHub issue 35938, a of 0.2 is for python native type.
|
|
||||||
a = 0.2
|
|
||||||
b = math_ops.cast(a, dtypes.float64)
|
|
||||||
self.assertAllEqual(a, self.evaluate(b))
|
|
||||||
|
|
||||||
|
|
||||||
class SparseTensorCastTest(test.TestCase):
|
class SparseTensorCastTest(test.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -745,7 +745,7 @@ def cast(x, dtype, name=None):
|
||||||
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
|
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
|
||||||
# allows some conversions that cast() can't do, e.g. casting numbers to
|
# allows some conversions that cast() can't do, e.g. casting numbers to
|
||||||
# strings.
|
# strings.
|
||||||
x = ops.convert_to_tensor(x, dtype_hint=base_type, name="x")
|
x = ops.convert_to_tensor(x, name="x")
|
||||||
if x.dtype.base_dtype != base_type:
|
if x.dtype.base_dtype != base_type:
|
||||||
x = gen_math_ops.cast(x, base_type, name=name)
|
x = gen_math_ops.cast(x, base_type, name=name)
|
||||||
if x.dtype.is_complex and base_type.is_floating:
|
if x.dtype.is_complex and base_type.is_floating:
|
||||||
|
|
Loading…
Reference in New Issue