Merge pull request #35961 from yongtang:35938-tf.cast-float64
PiperOrigin-RevId: 290962892 Change-Id: Ic0911f8593ac6f1c12f4fcde57f90d8a7e357984
This commit is contained in:
commit
466e818c1e
tensorflow/python
@ -213,6 +213,13 @@ class CastOpTest(test.TestCase):
|
||||
err = gradient_checker.compute_gradient_error(x, [], y, [])
|
||||
self.assertLess(err, 1e-3)
|
||||
|
||||
def testPythonDataTypes(self):
|
||||
with self.cached_session():
|
||||
# GitHub issue 35938, a of 0.2 is for python native type.
|
||||
a = 0.2
|
||||
b = math_ops.cast(a, dtypes.float64)
|
||||
self.assertAllEqual(a, self.evaluate(b))
|
||||
|
||||
|
||||
class SparseTensorCastTest(test.TestCase):
|
||||
|
||||
|
@ -745,7 +745,7 @@ def cast(x, dtype, name=None):
|
||||
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
|
||||
# allows some conversions that cast() can't do, e.g. casting numbers to
|
||||
# strings.
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor(x, dtype_hint=base_type, name="x")
|
||||
if x.dtype.base_dtype != base_type:
|
||||
x = gen_math_ops.cast(x, base_type, name=name)
|
||||
if x.dtype.is_complex and base_type.is_floating:
|
||||
|
Loading…
Reference in New Issue
Block a user