Add python test for While op lowering.
Test that fetching values of while outputs in sess.run by tensor name works. This tests that an IdentityN node with the same name and outputs as the original while op was added to the graph during lowering. PiperOrigin-RevId: 211827934
This commit is contained in:
parent
43a3c393d7
commit
84f091dff8
@ -1634,6 +1634,7 @@ cuda_py_test(
|
|||||||
srcs = ["functional_ops_test.py"],
|
srcs = ["functional_ops_test.py"],
|
||||||
additional_deps = [
|
additional_deps = [
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase):
|
|||||||
self.assertAllEqual(Run(sess, 20.), 210.)
|
self.assertAllEqual(Run(sess, 20.), 210.)
|
||||||
self.assertAllEqual(Run(sess, 100.), 5050.)
|
self.assertAllEqual(Run(sess, 100.), 5050.)
|
||||||
|
|
||||||
|
def testWhileLowering(self):
|
||||||
|
|
||||||
|
def Run(n, fetch_by_name):
|
||||||
|
for use_gpu in (True, False):
|
||||||
|
with ops.Graph().as_default() as g:
|
||||||
|
|
||||||
|
@function.Defun(*[dtypes.float32] * 2)
|
||||||
|
def Cond(n, unused_x):
|
||||||
|
return n > 0
|
||||||
|
|
||||||
|
@function.Defun(*[dtypes.float32] * 2)
|
||||||
|
def Body(n, x):
|
||||||
|
return n - 1, x + n
|
||||||
|
|
||||||
|
# outputs: [0, n*(n+1)/2]
|
||||||
|
outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
|
||||||
|
|
||||||
|
# `outputs` is the list of output tensors of the While op. We
|
||||||
|
# arbitrarily choose the 0th tensor to get the While op and set the
|
||||||
|
# lowering attribute on it.
|
||||||
|
outputs[0].op._set_attr("_lower_using_switch_merge",
|
||||||
|
attr_value_pb2.AttrValue(b=True))
|
||||||
|
if not fetch_by_name:
|
||||||
|
fetch = outputs[1]
|
||||||
|
else:
|
||||||
|
fetch = "my_while:1"
|
||||||
|
with self.test_session(graph=g, use_gpu=use_gpu) as sess:
|
||||||
|
return sess.run(fetch)
|
||||||
|
|
||||||
|
self.assertAllEqual(Run(20., False), 210.)
|
||||||
|
self.assertAllEqual(Run(20., True), 210.)
|
||||||
|
self.assertAllEqual(Run(100., False), 5050.)
|
||||||
|
self.assertAllEqual(Run(100., True), 5050.)
|
||||||
|
|
||||||
def testWhileError(self):
|
def testWhileError(self):
|
||||||
for use_gpu in (True, False):
|
for use_gpu in (True, False):
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
|
Loading…
Reference in New Issue
Block a user