Merge pull request from yongtang:35938-tf.cast-float64

PiperOrigin-RevId: 290962892
Change-Id: Ic0911f8593ac6f1c12f4fcde57f90d8a7e357984
This commit is contained in:
TensorFlower Gardener 2020-01-22 08:57:47 -08:00
commit 466e818c1e
2 changed files with 8 additions and 1 deletions
tensorflow/python

View File

@ -213,6 +213,13 @@ class CastOpTest(test.TestCase):
err = gradient_checker.compute_gradient_error(x, [], y, [])
self.assertLess(err, 1e-3)
def testPythonDataTypes(self):
with self.cached_session():
# GitHub issue 35938, a of 0.2 is for python native type.
a = 0.2
b = math_ops.cast(a, dtypes.float64)
self.assertAllEqual(a, self.evaluate(b))
class SparseTensorCastTest(test.TestCase):

View File

@ -745,7 +745,7 @@ def cast(x, dtype, name=None):
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
# allows some conversions that cast() can't do, e.g. casting numbers to
# strings.
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor(x, dtype_hint=base_type, name="x")
if x.dtype.base_dtype != base_type:
x = gen_math_ops.cast(x, base_type, name=name)
if x.dtype.is_complex and base_type.is_floating: