From 84f091dff8e1bcd93ac2d69d2cc11faca3790ac9 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 6 Sep 2018 10:20:55 -0700 Subject: [PATCH] Add python test for While op lowering. Test that fetching values of while outputs in sess.run by tensor name works. This tests that an IdentityN node with the same name and outputs as the original while op was added to the graph during lowering. PiperOrigin-RevId: 211827934 --- tensorflow/python/kernel_tests/BUILD | 1 + .../kernel_tests/functional_ops_test.py | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3026c7755ad..58c8975daa1 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1634,6 +1634,7 @@ cuda_py_test( srcs = ["functional_ops_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 3ddb5e06c92..e39daf1371d 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import iterator_ops @@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(Run(sess, 20.), 210.) self.assertAllEqual(Run(sess, 100.), 5050.) + def testWhileLowering(self): + + def Run(n, fetch_by_name): + for use_gpu in (True, False): + with ops.Graph().as_default() as g: + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # outputs: [0, n*(n+1)/2] + outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while") + + # `outputs` is the list of output tensors of the While op. We + # arbitrarily choose the 0th tensor to get the While op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + if not fetch_by_name: + fetch = outputs[1] + else: + fetch = "my_while:1" + with self.test_session(graph=g, use_gpu=use_gpu) as sess: + return sess.run(fetch) + + self.assertAllEqual(Run(20., False), 210.) + self.assertAllEqual(Run(20., True), 210.) + self.assertAllEqual(Run(100., False), 5050.) + self.assertAllEqual(Run(100., True), 5050.) + def testWhileError(self): for use_gpu in (True, False): with ops.Graph().as_default() as g: