"While" works like a function application. Add it to this list of function
applications that need to pass int32 tensors as host tensors. PiperOrigin-RevId: 227047393
This commit is contained in:
parent
ad3ccf59df
commit
1a8bc41039
@ -62,7 +62,7 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
|
||||
|
||||
bool IsFunctionCallOp(const string& op_type) {
|
||||
return op_type == "SymbolicGradient" || op_type == "PartitionedCall" ||
|
||||
op_type == "StatefulPartitionedCall";
|
||||
op_type == "StatefulPartitionedCall" || op_type == "While";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -770,6 +770,26 @@ class FunctionalOpsTest(test.TestCase):
|
||||
self.assertAllEqual(Run(sess, 20.), 210.)
|
||||
self.assertAllEqual(Run(sess, 100.), 5050.)
|
||||
|
||||
# Like above, but using int32 in order to ensure that int32 tensors don't get
|
||||
# copied to the GPU during the application of the while.
|
||||
def testWhileInt32(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
|
||||
@function.Defun(*[dtypes.int32] * 2)
|
||||
def Cond(n, unused_x):
|
||||
return n > 0
|
||||
|
||||
@function.Defun(*[dtypes.int32] * 2)
|
||||
def Body(n, x):
|
||||
return n - 1, x + n
|
||||
|
||||
def Run(sess, n):
|
||||
return sess.run(functional_ops.While([n, 0], Cond, Body))[1]
|
||||
|
||||
with self.session(graph=g, use_gpu=True) as sess:
|
||||
self.assertAllEqual(Run(sess, 20), 210)
|
||||
self.assertAllEqual(Run(sess, 100), 5050)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWhileLowering(self):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user