Add test for creating variable inside function passed to tf.vectorized_map.
Switch pfor logic to use tf.function. PiperOrigin-RevId: 285504833 Change-Id: I1d0712349dae0cbe49a05684a0bdbba48eb517ff
This commit is contained in:
parent
539eb9f8b2
commit
d4c014f6fe
@ -22,7 +22,7 @@ from __future__ import print_function
|
|||||||
import functools
|
import functools
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -185,7 +185,7 @@ def pfor(loop_fn, iters, parallel_iterations=None):
|
|||||||
# XLA compilation. The latter is so that we don't compile operations like
|
# XLA compilation. The latter is so that we don't compile operations like
|
||||||
# tf.placeholder that are created by the loop body.
|
# tf.placeholder that are created by the loop body.
|
||||||
if context.executing_eagerly() or _is_under_xla_context():
|
if context.executing_eagerly() or _is_under_xla_context():
|
||||||
f = def_function.function(f)
|
f = function.defun(f)
|
||||||
return f()
|
return f()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1477,36 +1477,5 @@ class PartitionedCallTest(PForTestCase):
|
|||||||
self._test_loop_fn(loop_fn, 4)
|
self._test_loop_fn(loop_fn, 4)
|
||||||
|
|
||||||
|
|
||||||
class VariableTest(PForTestCase):
|
|
||||||
|
|
||||||
def test_create_variable_once(self):
|
|
||||||
x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
|
|
||||||
y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
|
|
||||||
a_var = []
|
|
||||||
|
|
||||||
def f(z):
|
|
||||||
if not a_var:
|
|
||||||
a_var.append(variables.Variable(lambda: y, name="a"))
|
|
||||||
return math_ops.matmul(z, a_var[0] / 16)
|
|
||||||
|
|
||||||
pfor_control_flow_ops.vectorized_map(f, x)
|
|
||||||
|
|
||||||
@test_util.run_v2_only
|
|
||||||
def test_create_variable_repeated(self):
|
|
||||||
x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
|
|
||||||
y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
|
|
||||||
|
|
||||||
def f(z):
|
|
||||||
a_var = variables.Variable(lambda: y, name="a") / 4
|
|
||||||
return math_ops.matmul(z, a_var / 16)
|
|
||||||
|
|
||||||
# Note that this error is only raised under v2 behavior.
|
|
||||||
with self.assertRaisesRegexp(
|
|
||||||
ValueError,
|
|
||||||
"tf.function-decorated function tried to create variables on non-first"
|
|
||||||
):
|
|
||||||
pfor_control_flow_ops.vectorized_map(f, x)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user