Add test for creating variable inside function passed to tf.vectorized_map.
Switch pfor logic to use tf.function. PiperOrigin-RevId: 285504833 Change-Id: I1d0712349dae0cbe49a05684a0bdbba48eb517ff
This commit is contained in:
parent
539eb9f8b2
commit
d4c014f6fe
@ -22,7 +22,7 @@ from __future__ import print_function
|
||||
import functools
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -185,7 +185,7 @@ def pfor(loop_fn, iters, parallel_iterations=None):
|
||||
# XLA compilation. The latter is so that we don't compile operations like
|
||||
# tf.placeholder that are created by the loop body.
|
||||
if context.executing_eagerly() or _is_under_xla_context():
|
||||
f = def_function.function(f)
|
||||
f = function.defun(f)
|
||||
return f()
|
||||
|
||||
|
||||
|
@ -1477,36 +1477,5 @@ class PartitionedCallTest(PForTestCase):
|
||||
self._test_loop_fn(loop_fn, 4)
|
||||
|
||||
|
||||
class VariableTest(PForTestCase):
|
||||
|
||||
def test_create_variable_once(self):
|
||||
x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
|
||||
y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
|
||||
a_var = []
|
||||
|
||||
def f(z):
|
||||
if not a_var:
|
||||
a_var.append(variables.Variable(lambda: y, name="a"))
|
||||
return math_ops.matmul(z, a_var[0] / 16)
|
||||
|
||||
pfor_control_flow_ops.vectorized_map(f, x)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_create_variable_repeated(self):
|
||||
x = array_ops.ones(shape=(3, 2, 2), dtype=dtypes.float32)
|
||||
y = array_ops.ones(shape=(2, 3), dtype=dtypes.float32)
|
||||
|
||||
def f(z):
|
||||
a_var = variables.Variable(lambda: y, name="a") / 4
|
||||
return math_ops.matmul(z, a_var / 16)
|
||||
|
||||
# Note that this error is only raised under v2 behavior.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"tf.function-decorated function tried to create variables on non-first"
|
||||
):
|
||||
pfor_control_flow_ops.vectorized_map(f, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user