"While" works like a function application. Add it to this list of function

applications that need to pass int32 tensors as host tensors.

PiperOrigin-RevId: 227047393
This commit is contained in:
A. Unique TensorFlower 2018-12-27 11:37:27 -08:00 committed by TensorFlower Gardener
parent ad3ccf59df
commit 1a8bc41039
2 changed files with 21 additions and 1 deletions

View File

@ -62,7 +62,7 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
bool IsFunctionCallOp(const string& op_type) {
return op_type == "SymbolicGradient" || op_type == "PartitionedCall" ||
op_type == "StatefulPartitionedCall";
op_type == "StatefulPartitionedCall" || op_type == "While";
}
} // namespace

View File

@ -770,6 +770,26 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
# Like above, but using int32 in order to ensure that int32 tensors don't get
# copied to the GPU during the application of the while.
def testWhileInt32(self):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.int32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[dtypes.int32] * 2)
def Body(n, x):
return n - 1, x + n
def Run(sess, n):
return sess.run(functional_ops.While([n, 0], Cond, Body))[1]
with self.session(graph=g, use_gpu=True) as sess:
self.assertAllEqual(Run(sess, 20), 210)
self.assertAllEqual(Run(sess, 100), 5050)
@test_util.run_deprecated_v1
def testWhileLowering(self):