Add python test for While op lowering.

Test that fetching values of while outputs in sess.run by tensor name works. This tests that an IdentityN node with the same name and outputs as the original while op was added to the graph during lowering.

PiperOrigin-RevId: 211827934
This commit is contained in:
Saurabh Saxena 2018-09-06 10:20:55 -07:00 committed by TensorFlower Gardener
parent 43a3c393d7
commit 84f091dff8
2 changed files with 36 additions and 0 deletions

View File

@ -1634,6 +1634,7 @@ cuda_py_test(
srcs = ["functional_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
def testWhileLowering(self):
def Run(n, fetch_by_name):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@function.Defun(*[dtypes.float32] * 2)
def Cond(n, unused_x):
return n > 0
@function.Defun(*[dtypes.float32] * 2)
def Body(n, x):
return n - 1, x + n
# outputs: [0, n*(n+1)/2]
outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
# `outputs` is the list of output tensors of the While op. We
# arbitrarily choose the 0th tensor to get the While op and set the
# lowering attribute on it.
outputs[0].op._set_attr("_lower_using_switch_merge",
attr_value_pb2.AttrValue(b=True))
if not fetch_by_name:
fetch = outputs[1]
else:
fetch = "my_while:1"
with self.test_session(graph=g, use_gpu=use_gpu) as sess:
return sess.run(fetch)
self.assertAllEqual(Run(20., False), 210.)
self.assertAllEqual(Run(20., True), 210.)
self.assertAllEqual(Run(100., False), 5050.)
self.assertAllEqual(Run(100., True), 5050.)
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g: