From 1a8bc4103932c6f81c26d5164c8f34210010e518 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Dec 2018 11:37:27 -0800 Subject: [PATCH] "While" works like a function application. Add it to this list of function applications that need to pass int32 tensors as host tensors. PiperOrigin-RevId: 227047393 --- tensorflow/core/framework/memory_types.cc | 2 +- .../kernel_tests/functional_ops_test.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 6dff6fe654a..8caea351be4 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -62,7 +62,7 @@ void MemoryTypesHelper(const NameRangeMap& name_map, bool IsFunctionCallOp(const string& op_type) { return op_type == "SymbolicGradient" || op_type == "PartitionedCall" || - op_type == "StatefulPartitionedCall"; + op_type == "StatefulPartitionedCall" || op_type == "While"; } } // namespace diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 3b3db429d8a..8988305bde5 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -770,6 +770,26 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(Run(sess, 20.), 210.) self.assertAllEqual(Run(sess, 100.), 5050.) + # Like above, but using int32 in order to ensure that int32 tensors don't get + # copied to the GPU during the application of the while. + def testWhileInt32(self): + with ops.Graph().as_default() as g: + + @function.Defun(*[dtypes.int32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.int32] * 2) + def Body(n, x): + return n - 1, x + n + + def Run(sess, n): + return sess.run(functional_ops.While([n, 0], Cond, Body))[1] + + with self.session(graph=g, use_gpu=True) as sess: + self.assertAllEqual(Run(sess, 20), 210) + self.assertAllEqual(Run(sess, 100), 5050) + @test_util.run_deprecated_v1 def testWhileLowering(self):