Add test for creating variable inside function passed to tf.vectorized_map.

Switch pfor logic to use tf.function.

PiperOrigin-RevId: 285504833
Change-Id: I1d0712349dae0cbe49a05684a0bdbba48eb517ff
This commit is contained in:
A. Unique TensorFlower 2019-12-13 17:41:15 -08:00 committed by TensorFlower Gardener
parent 539eb9f8b2
commit d4c014f6fe
2 changed files with 2 additions and 33 deletions

View File

@ -22,7 +22,7 @@ from __future__ import print_function
import functools import functools
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import function
from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
@ -185,7 +185,7 @@ def pfor(loop_fn, iters, parallel_iterations=None):
# XLA compilation. The latter is so that we don't compile operations like # XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body. # tf.placeholder that are created by the loop body.
if context.executing_eagerly() or _is_under_xla_context(): if context.executing_eagerly() or _is_under_xla_context():
f = def_function.function(f) f = function.defun(f)
return f() return f()

View File

@ -1477,36 +1477,5 @@ class PartitionedCallTest(PForTestCase):
self._test_loop_fn(loop_fn, 4) self._test_loop_fn(loop_fn, 4)
class VariableTest(PForTestCase):
def test_create_variable_once(self):
x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
a_var = []
def f(z):
if not a_var:
a_var.append(variables.Variable(lambda: y, name="a"))
return math_ops.matmul(z, a_var[0] / 16)
pfor_control_flow_ops.vectorized_map(f, x)
@test_util.run_v2_only
def test_create_variable_repeated(self):
x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
def f(z):
a_var = variables.Variable(lambda: y, name="a") / 4
return math_ops.matmul(z, a_var / 16)
# Note that this error is only raised under v2 behavior.
with self.assertRaisesRegexp(
ValueError,
"tf.function-decorated function tried to create variables on non-first"
):
pfor_control_flow_ops.vectorized_map(f, x)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()