# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test
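

# Thin wrappers used as parameterized cases in testRandomOpsShape and
# testFillOpsShape below; each takes just a `shape` argument.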
def random_gamma(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(shape, 1.0)


def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  return random_ops.random_gamma(
      shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])


def random_poisson_v2(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, 1.0)


def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  return random_ops.random_poisson_v2(shape, [12.2, 3.3])


def fill(shape):  # pylint: disable=invalid-name
  return array_ops.fill(shape, 1.0)


class WhileV2Test(test.TestCase, parameterized.TestCase):

  @test_util.run_deprecated_v1
  def testSingleLoopVar(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testSingleLoopVarBackPropFalse(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8.,
        lambda v: v * v, [x],
        return_same_structure=False,
        back_prop=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(grad, [None])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)

  @test_util.run_deprecated_v1
  def testCustomGradient(self):
    x = constant_op.constant(2.)
    n = constant_op.constant(1., name="const-n")
    m = variables.Variable(1.0)
    self.evaluate(variables.global_variables_initializer())

    def body_fn(v):  # pylint: disable=invalid-name

      @custom_gradient.custom_gradient
      def inner_fn(v):  # pylint: disable=invalid-name

        def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
          return dy * 2 * v * n * m, [v * v]

        return v * v * m, grad_fn

      return inner_fn(v)

    ret = while_loop_v2(
        lambda v: v < 8., body_fn, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_v1_only("b/120545219")
  def testReturnSameStructureTrue(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session() as sess:
      eval_result = sess.run(ret)
      self.assertIsInstance(eval_result, list)
      self.assertLen(eval_result, 1)
      self.assertEqual(16., eval_result[0])
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testVerifyInputOutputTypesMatch(self):

    @def_function.function
    def BuildWhile():
      x = constant_op.constant(1., dtypes.float32)

      def Body(x):
        return math_ops.cast(x, dtypes.float16) + 1

      while_loop_v2(lambda x: x < 10, Body, [x])

    with self.assertRaisesRegex(
        TypeError,
        r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
        r"but has type <dtype: 'float16'> after 1 iteration."):
      BuildWhile()

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testGradientTapeResourceVariable(self, dtype):
    with context.eager_mode():
      v = variables.Variable(1., dtype=dtype)

      @def_function.function
      def fnWithLoop():  # pylint: disable=invalid-name
        with backprop.GradientTape() as tape:
          _, x = while_loop_v2(
              lambda i, _: i < 2,
              lambda i, x: (i + 1, x * v),
              [0, constant_op.constant(2., dtype=dtype)])
        return tape.gradient(x, v)

      self.assertAllEqual(fnWithLoop(), 4.0)
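
  # checkIteratedGradients repeatedly differentiates `func`, comparing
  # gradient_checker_v2's theoretical and numerical Jacobians at each order.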
  def checkIteratedGradients(self, func):
    with context.eager_mode():

      def _Grad(f):
        def _GradFunction(primal):
          with backprop.GradientTape() as tape:
            tape.watch(primal)
            primal_out = f(primal)
          return tape.gradient(primal_out, primal)
        return _GradFunction

      f = func
      one = constant_op.constant(1.)

      for _ in range(3):
        theoretical, numerical = gradient_checker_v2.compute_gradient(
            def_function.function(f), [one])
        self.assertAllClose(theoretical, numerical, rtol=1e-3)
        f = _Grad(f)
        self.assertAllClose(array_ops.reshape(numerical, []),
                            def_function.function(f)(one),
                            rtol=1e-3)

  def testIteratedGradients(self):

    def _Func(x):
      _, z = while_loop_v2(
          lambda i, _: i < 2,
          lambda i, y: (i + 1, math_ops.cos(y)),
          [0, x])
      return z

    self.checkIteratedGradients(_Func)

  def testIteratedGradientsWithList(self):

    def _Func(x):
      results = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32)

      def _LoopBody(i, y, handle):
        return (i + 1, math_ops.cos(y),
                list_ops.tensor_list_push_back(handle, y))

      _, z, results = while_loop_v2(
          lambda i, _, h: i < 2, _LoopBody, [0, x, results])
      return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
          results, dtypes.float32))

    self.checkIteratedGradients(_Func)

  def testGradWhileGradWhileWithVariable(self):
    with context.eager_mode():
      v = variables.Variable(1.)

      @def_function.function
      def _Func(x):

        def _Inner(a):
          with backprop.GradientTape() as tape:
            tape.watch(a)
            _, b = while_loop_v2(
                lambda i, _: i < 2,
                lambda i, y: (i + 1, math_ops.cos(v + y)),
                [0, a])
          return tape.gradient(b, a)

        _, z = while_loop_v2(
            lambda i, _: i < 2,
            lambda i, y: (i + 1, _Inner(y)),
            [0, x])
        return z

      with backprop.GradientTape(persistent=True) as tape:
        x = constant_op.constant(1.)
        tape.watch(x)
        y = _Func(x)
      dx, _ = tape.gradient(y, [x, v])
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          _Func, [x])
      self.assertAllClose(numerical, theoretical, rtol=1e-3)
      self.assertAllClose(array_ops.reshape(numerical, []),
                          dx, rtol=1e-3)

  def testThreeNestWithLists(self):
    with context.eager_mode():
      def _WrapInWhile(f):
        def _Wrapped(x):
          results = list_ops.empty_tensor_list(
              element_shape=[], element_dtype=dtypes.float32)

          def _LoopBody(i, y, handle):
            return (i + 1, f(math_ops.cos(y)),
                    list_ops.tensor_list_push_back(handle, y))

          _, z, results = control_flow_ops.while_loop(
              lambda i, _, h: i < 2, _LoopBody, [0, x, results])
          return z + math_ops.reduce_sum(list_ops.tensor_list_stack(
              results, dtypes.float32))
        return _Wrapped

      f = math_ops.sin

      target_function = _WrapInWhile(_WrapInWhile(_WrapInWhile(f)))

      @def_function.function
      def _TapeFromGraphMode(x):
        with backprop.GradientTape(persistent=True) as tape:
          tape.watch(x)
          y = target_function(x)
        return tape.gradient(y, x)

      x = constant_op.constant(1.)
      dx = _TapeFromGraphMode(x)
      theoretical, numerical = gradient_checker_v2.compute_gradient(
          target_function, [x])
      self.assertAllClose(numerical, theoretical, rtol=3e-3)
      self.assertAllClose(array_ops.reshape(numerical, []), dx, rtol=3e-3)

  def testDeviceLabelsInherited(self):
    def _LoopBody(i, y):
      result = math_ops.cos(y)
      self.assertIn("CPU:10", result.device)
      with ops.device("CPU:11"):
        result = array_ops.identity(result)
      self.assertIn("CPU:11", result.device)
      return i + 1, result

    @def_function.function
    def _FunctionWithWhileLoop():
      x = constant_op.constant(1.)
      with ops.device("CPU:10"):
        _, z = while_loop_v2(
            lambda i, _: i < 2,
            _LoopBody,
            [0, x])
      return z
    # The test assertion runs at trace time.
    _FunctionWithWhileLoop.get_concrete_function()

  def testExternalControlDependencies(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.)
      self.evaluate(v.initializer)
      op = v.assign_add(1.)

      def body_fn(i):  # pylint: disable=invalid-name
        with ops.control_dependencies([op]):
          return i + 1

      loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
      loop[0].op.run()
      self.assertAllEqual(self.evaluate(v), 2.0)

  @test_util.run_deprecated_v1
  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopNonscalarCond(self):
    x = constant_op.constant([[5.]])
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, w), [x, y],
        return_same_structure=False)
    # ret == [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    #   y = x + y
    ret = while_loop_v2(
        lambda v, _: v < 45.,
        lambda v, w: (v * w, v + w), [x, y],
        return_same_structure=False)
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + 2*y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
      self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
      self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
      self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
      self.assertSequenceEqual(self.evaluate(grady_0), [55.])
      self.assertSequenceEqual(self.evaluate(grady_1), [6.])
      self.assertSequenceEqual(self.evaluate(grady_2), [61.])

  @test_util.run_deprecated_v1
  def testGradientTape(self):
    with backprop.GradientTape() as t:
      x = constant_op.constant(2.)
      t.watch(x)
      ret = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
    grad = t.gradient(ret, x)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(grad), 4.0)

  @test_util.run_deprecated_v1
  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(
        lambda v: v < 4., lambda v: v * v, [x],
        return_same_structure=False)  # x**2
    ret2 = while_loop_v2(
        lambda v: v < 16., lambda v: v * v, [ret1],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  def testMultipleWhileLoopsWithFunc(self):
    x = constant_op.constant(2.)

    @def_function.function
    def Fn():
      ret1 = while_loop_v2(
          lambda v: v < 4.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_1")  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [x],
          return_same_structure=False,
          name="while_2")  # x**4
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "StatelessWhile")
    self.assertEqual(while_2.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertEmpty(while_2.control_inputs)

  def testMultipleWhileLoopsGradStateless(self):

    @def_function.function
    def Fn():
      x = constant_op.constant(2.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        ret1 = while_loop_v2(
            lambda v: v < 4.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_1")  # x**2
        ret2 = while_loop_v2(
            lambda v: v < 16.,
            lambda v: v * v, [x],
            return_same_structure=False,
            name="while_2")  # x**4
        loss = ret1 + ret2
      return tape.gradient(loss, x)

    graph = Fn.get_concrete_function().graph
    while_ops = [op for op in graph.get_operations() if "While" in op.type]
    self.assertAllEqual([op.type for op in while_ops], ["StatelessWhile"] * 4,
                        "Must have exactly 4 StatelessWhile ops.")
    for op in while_ops:
      self.assertEmpty(op.control_inputs,
                       "{} should not have any control inputs".format(op.name))

  def testMultipleWhileLoopsWithDeps(self):
    x = variables.Variable(2.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x.assign(x)
        return v * x

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x.assign(x)
        return v * x * x

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x
      return ret1, ret2

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)

  def testMultipleWhileLoopsWithVarsDeps(self):
    x1 = variables.Variable(2.)
    x2 = variables.Variable(3.)
    c = constant_op.constant(2.)

    @def_function.function
    def Fn():

      def Body1(v):
        x1.assign(x1)
        return v * x1

      ret1 = while_loop_v2(
          lambda v: v < 4.,
          Body1, [c],
          return_same_structure=False,
          name="while_1")  # 2x

      def Body2(v):
        x1.assign(x1)
        return v * x1 * x1

      ret2 = while_loop_v2(
          lambda v: v < 16.,
          Body2, [c],
          return_same_structure=False,
          name="while_2")  # 4x

      def Body3(v):
        x2.assign(x2)
        return v * x2

      ret3 = while_loop_v2(
          lambda v: v < 4.,
          Body3, [c],
          return_same_structure=False,
          name="while_3")  # 3x

      def Body4(v):
        x2.assign(x2)
        return v * x2 * x2

      ret4 = while_loop_v2(
          lambda v: v < 16.,
          Body4, [c],
          return_same_structure=False,
          name="while_4")  # 9x
      ret5 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [c],
          return_same_structure=False,
          name="while_stateless")  # x**2
      return ret1, ret2, ret3, ret4, ret5

    concrete_fn = Fn.get_concrete_function()
    while_1 = concrete_fn.graph.get_operation_by_name("while_1")
    while_2 = concrete_fn.graph.get_operation_by_name("while_2")
    while_3 = concrete_fn.graph.get_operation_by_name("while_3")
    while_4 = concrete_fn.graph.get_operation_by_name("while_4")
    while_stateless = concrete_fn.graph.get_operation_by_name(
        "while_stateless")
    self.assertEqual(while_1.type, "While")
    self.assertEqual(while_2.type, "While")
    self.assertEqual(while_3.type, "While")
    self.assertEqual(while_4.type, "While")
    self.assertEqual(while_stateless.type, "StatelessWhile")
    self.assertEmpty(while_1.control_inputs)
    self.assertLen(while_2.control_inputs, 1)
    self.assertIs(while_2.control_inputs[0], while_1)
    self.assertEmpty(while_3.control_inputs)
    self.assertLen(while_4.control_inputs, 1)
    self.assertIs(while_4.control_inputs[0], while_3)
    self.assertEmpty(while_stateless.control_inputs)

  @test_util.run_deprecated_v1
  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v**2, [x],
        return_same_structure=False)  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])
      self.assertSequenceEqual(self.evaluate(grad_grad), [48.])

  @test_util.run_v2_only
  def testMultipleWhileLoopsEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret1 = while_loop_v2(
          lambda v: v < 4., lambda v: v * v, [x],
          return_same_structure=False)  # x**2
      ret2 = while_loop_v2(
          lambda v: v < 16.,
          lambda v: v * v, [ret1],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret2, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return grad, grad_grad

    grad, grad_grad = Func()
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  @test_util.run_v2_only
  def testDoubleDerivativeEager(self):

    @def_function.function
    def Func():
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v**2, [x],
          return_same_structure=False)  # x**4
      grad = gradients_impl.gradients(ret, [x])[0]  # 4x**3
      grad_grad = gradients_impl.gradients(grad, [x])[0]  # 12x**2
      return ret, grad, grad_grad

    ret, grad, grad_grad = Func()
    self.assertEqual(ret.numpy(), 16.)
    self.assertEqual(grad.numpy(), 32.)
    self.assertEqual(grad_grad.numpy(), 48.)

  def _testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.run_deprecated_v1
  def testPruningV1(self):
    self._testPruning()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testPruningV2(self):
    self._testPruning()

  def _testDoNotAccumulateInvariants(self):
    push_op = ("TensorListPushBack"
               if control_flow_v2_toggles.control_flow_v2_enabled() else
               "StackPushV2")

    # Tests that loop invariants, i.e., tensors that are "captured" by the
    # while loop and not passed as loop variables are not accumulated in
    # gradient computation.
    v = constant_op.constant(5.0, name="v")

    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    # The gradient for v * x requires the value of both v and x. Since v is a
    # loop invariant it is not accumulated so we have just one accumulator for
    # x.
    self.assertLen([n for n in g.node if n.op == push_op], 1)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateInvariantsV1(self):
    self._testDoNotAccumulateInvariants()

  @test_util.run_deprecated_v1
  @test_util.enable_control_flow_v2
  def testDoNotAccumulateInvariantsV2(self):
    self._testDoNotAccumulateInvariants()

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    x = constant_op.constant(0)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 25

    def Body(x, tl):

      def InnerCond(inner_x, unused_outer_x, unused_tl):
        return inner_x < 5

      def InnerBody(inner_x, outer_x, tl):
        return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

      inner_x = constant_op.constant(0)
      return control_flow_ops.while_loop(InnerCond, InnerBody,
                                         [inner_x, x, tl])[1:]

    outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is pruned out.
    self.assertEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
    self.assertEmpty([n for n in g.node if n.op == "_While"])

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
    # away, causing an extra Enter node.
    # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
    # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
    # Test that the TensorList is not pruned out.
    self.assertNotEmpty([
        n for n in g.node if n.op == "Enter" and
        n.attr["T"].type == dtypes.variant.as_datatype_enum
    ])
    self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested2(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    p = array_ops.placeholder(dtype=dtypes.int32)

    def MidBodyBuilder(iterations):

      def MidBody(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
            (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return MidBody

    def OuterBody(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          MidBodyBuilder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def CreateWhileLoop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            OuterBody, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    output = CreateWhileLoop()
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testPruningNested3(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")

    def CreateWhileLoop():
      r = control_flow_ops.while_loop(
          lambda _: True,
          lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r)

    r = CreateWhileLoop()
    output = gradients_impl.gradients(r, v)[0]
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(output)

    g = GetOptimizedGraph()
    self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)

  def _assertNotAccumulated(self, while_op, index):
    """Asserts that `while_op` input at `index` is not accumulated."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    placeholder = body_graph.inputs[index]
    self.assertNotIn("TensorListPushBack",
                     [op.type for op in placeholder.consumers()])

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopCounterAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    self._assertNotAccumulated(while_op, 0)

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  @test_util.enable_output_all_intermediates
  def testDoNotOutputLoopInvariantAsIntermediate(self):
    assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

    def GetInputIndex(op, tensor):
      for index, inp in enumerate(op.inputs):
        if inp is tensor:
          return index

    v = constant_op.constant(5.0, name="v")
    r = control_flow_ops.while_loop(
        lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
    # Skip over Identity.
    while_op = r.op.inputs[0].op
    # We can't directly use while_op.inputs.index() because Tensors are not
    # hashable.
    index = GetInputIndex(while_op, v)
    self._assertNotAccumulated(while_op, index)

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(
        lambda v: v + y < 9.,
        lambda v: v * 3., [x],
        return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(
        lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    with self.cached_session():
      self.assertEqual(self.evaluate(ret), 18.)
      self.assertSequenceEqual(self.evaluate(grad), [9.])

  @test_util.run_deprecated_v1
  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)
    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(
        Cond, Body, [x, tensor_list], return_same_structure=False)

    for op in ops.get_default_graph().get_operations():
      if op.type == "While" or op.type == "StatelessWhile":
        while_op = op

    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
    x_input_t = body_graph.inputs[x_input_index]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(self.evaluate(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  @test_util.run_deprecated_v1
  def testAccumulatorElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension is
      # None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
      body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8.,
                        lambda v, u: (math_ops.pow(v, u), u),
                        [x, y],
                        return_same_structure=True)
    while_op = ret[0].op.inputs[0].op
    # Gradient pass.
    grad = gradients_impl.gradients(ret[0], x)
    # Note: There is an Identity b/w grad[0] and the While op.
    grad_while_op = grad[0].op.inputs[0].op

    # Get the TensorList output of While op containing the accumulated values
    # of y.
    x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
    output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Take second derivative to generate intermediate grad_while_op outputs
    gradients_impl.gradients(grad, x)

    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_x (note that grad_x is needed by the second derivative).
    # grad_while_op.inputs:
    grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                                grad_output_index)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

  def _createWhile(self, name):
    """Helper function testDefaultName."""
    output = while_v2.while_loop(
        lambda i: i < 3,
        lambda i: i + 1, [constant_op.constant(0)],
        return_same_structure=False)
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    return while_op

  def testDefaultName(self):
    with ops.Graph().as_default():
      while_op = self._createWhile(None)
      self.assertEqual(while_op.name, "while")
      self.assertRegex(while_op.get_attr("cond").name, r"while_cond_\d*")
      self.assertRegex(while_op.get_attr("body").name, r"while_body_\d*")

    with ops.Graph().as_default():
      with ops.name_scope("foo"):
        while1_op = self._createWhile("")
        self.assertEqual(while1_op.name, "foo/while")
        self.assertRegex(while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
        self.assertRegex(while1_op.get_attr("body").name, r"foo_while_body_\d*")

        while2_op = self._createWhile(None)
        self.assertEqual(while2_op.name, "foo/while_1")
        self.assertRegex(
            while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
        self.assertRegex(
            while2_op.get_attr("body").name, r"foo_while_1_body_\d*")

  @test_util.enable_control_flow_v2
  @test_util.run_deprecated_v1
  def testWhileAndTensorArray(self):
    param = constant_op.constant(2.0)
    y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
    # map_fn uses TensorArray internally.
    r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
    grad = gradients_impl.gradients(r, param)[0]
    self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
    self.assertAllClose(21.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testNestedWhile(self):
    # Compute sum of geometric progression: n^0 + n^1 + ... + n^m
    # We compute the pow using a while loop.
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)
      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          lambda c, v: (c - 1., v * n), [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testNestedWhileWithLegacyDefun(self):
    n = constant_op.constant(3.)
    m = constant_op.constant(5.)
    sum_of_powers = constant_op.constant(0.)

    def Body(i, previous_sum):
      prod = constant_op.constant(1.)

      def InnerBodyWrapper(c, v):

        @function.Defun(dtypes.float32, dtypes.float32)
        def InnerBody(c, v):
          return c - 1., v * n

        results = InnerBody(c, v)
        results[0].set_shape([])
        results[1].set_shape([])
        return results

      return i - 1., previous_sum + while_loop_v2(
          lambda c, _: c > 0,
          InnerBodyWrapper, [i, prod],
          return_same_structure=False)[1]

    result = while_loop_v2(
        lambda i, _: i >= 0,
        Body, [m, sum_of_powers],
        return_same_structure=False)[1]
    grad = gradients_impl.gradients(result, [n])
    self.assertEqual(self.evaluate(result), 364.)
    self.assertSequenceEqual(self.evaluate(grad), [547.])

  @test_util.run_deprecated_v1
  def testIdentityNodeInBody(self):

    def Body(v):
      v = array_ops.identity(v)
      v = array_ops.identity(v)
      return v * v

    x = constant_op.constant(2.)
    ret = while_loop_v2(
        lambda v: v < 8., Body, [x], return_same_structure=False)
    grad = gradients_impl.gradients(ret, [x])
    self.assertEqual(self.evaluate(ret), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])

  @test_util.run_deprecated_v1
  def testForwardPassRewrite(self):
    x = constant_op.constant(1.0, name="x")
    output = while_v2.while_loop(lambda x: x < 10.0,
                                 lambda x: x * 2.0,
                                 [x])[0]
    while_op = output.op.inputs[0].op
    self.assertEqual(while_op.type, "StatelessWhile")
    # outputs = [loop_counter, max_iters, x]
    self.assertLen(while_op.outputs, 3)

    gradients_impl.gradients(output, x)
    # while_op should have been rewritten to output intermediates.
    # outputs = [loop_counter, max_iters, x, x_accumulator]
    self.assertLen(while_op.outputs, 4)

    gradients_impl.gradients(output, x)
    # Computing the gradient again shouldn't rewrite while_op again.
    self.assertLen(while_op.outputs, 4)

  @parameterized.named_parameters(
      ("RandomUniform", random_ops.random_uniform, [5, 3]),
      ("RandomNormal", random_ops.random_normal, [5, 3]),
      ("ParameterizedTruncatedNormal",
       random_ops.parameterized_truncated_normal, [5, 3]),
      ("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
      ("RandomGamma", random_gamma, [5, 3]),
      ("RandomPoissonV2", random_poisson_v2, [5, 3]),
      ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
      ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
  )
  @test_util.run_deprecated_v1
  def testRandomOpsShape(self, random_fn, expected_shape):
    shape = constant_op.constant([3])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = random_fn(shape_extended)
      assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros(expected_shape, dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testReshapeShape(self):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = array_ops.reshape(u, [-1])
      assert u.shape.as_list() == [60], str(u.shape.as_list())
      u = array_ops.reshape(u, shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @parameterized.named_parameters(
      ("Zeros", array_ops.zeros),
      ("Ones", array_ops.ones),
      ("Fill", fill),
  )
  @test_util.run_deprecated_v1
  def testFillOpsShape(self, fill_fn):
    shape = constant_op.constant([3, 4])

    def Body(i, u):
      shape_extended = array_ops.concat([[5], shape], axis=0)
      u = fill_fn(shape_extended)
      assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
      return i + 1, u

    _, _ = while_loop_v2(
        cond=lambda i, _: i < 3,
        body=Body,
        loop_vars=[
            0,
            array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
        ])

  @test_util.run_deprecated_v1
  def testExternalColocationGrad(self):
    external_t = constant_op.constant(2.)
    v0 = constant_op.constant(2.)

    def Body(v):
      with ops.colocate_with(external_t):
        return v * v

    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    grad = gradients_impl.gradients(ret, [v0])[0]
    self.assertAllEqual(ret, 16.)
    self.assertAllEqual(grad, 32.)

  @test_util.run_deprecated_v1
  def testDoNotAccumulateConstNodes(self):

    def Body(v):
      return v * 2.0

    v0 = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
    # Gradients computation has the side-effect of updating the forward op
    # which is what we want to test.
    unused_grad = gradients_impl.gradients(ret, [v0])[0]
    # ret is separated from the `While` op by an `Identity` so we skip over
    # that.
    forward_while_op = ret.op.inputs[0].op
    body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph")
    push_back_nodes = [
        o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
    ]
    # Gradient of `Mul` requires accumulating both its inputs. But since one
    # of those is a Const (2.0), we should have just one accumulator.
    self.assertLen(push_back_nodes, 1)

  def testDoNotAccumulateForwardTensorsForReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        x = constant_op.constant(2.)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def SquaredWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_array_ops.shape(x)
              assert shape.graph is forward_graph
              rank = gen_array_ops.rank(x)
              assert rank.graph is forward_graph
              size = gen_array_ops.size(x)
              assert size.graph is forward_graph
              zeros = array_ops.zeros(shape)
              assert zeros.graph is gradient_graph
              return zeros

            return x * 2, Grad

          return i + 1, SquaredWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      grad = tape.gradient(result, x)
      return grad

    Fn()

  def testDoNotAccumulateForwardTensorsForTensorListReductionOps(self):

    @def_function.function
    def Fn():
      with backprop.GradientTape() as tape:
        e = constant_op.constant(2.)
        x = list_ops.empty_tensor_list(
            element_dtype=dtypes.float32, element_shape=e.shape)
        x = list_ops.tensor_list_push_back(x, e)
        tape.watch(x)

        def Body(i, x):
          forward_graph = ops.get_default_graph()

          @custom_gradient.custom_gradient
          def IdentityWithZeroGrad(x):

            def Grad(unused_g, variables=None):  # pylint: disable=redefined-outer-name
              del variables
              gradient_graph = ops.get_default_graph()
              shape = gen_list_ops.tensor_list_element_shape(
                  x, shape_type=dtypes.int32)
              assert shape.graph is forward_graph
              size = gen_list_ops.tensor_list_length(x)
              assert size.graph is forward_graph
              zeros = gen_list_ops.tensor_list_reserve(shape, size,
                                                       dtypes.float32)
              assert zeros.graph is gradient_graph
              return zeros

            return x, Grad

          return i + 1, IdentityWithZeroGrad(x)

        _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x])
      ones_like = list_ops.tensor_list_from_tensor(
          array_ops.ones_like(
              list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)),
          element_shape=tensor_shape.TensorShape([]))
      grad = tape.gradient(result, x, output_gradients=[ones_like])
      return grad

    Fn()

  @test_util.run_v2_only
  def testInheritParentNameScope(self):

    @def_function.function
    def F():
      with ops.name_scope("foo"):

        def Cond(unused_i):
          with ops.name_scope("cond"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/cond"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return False

        def Body(i):
          with ops.name_scope("body"):
            actual_name_scope = ops.get_name_scope()
            expected_name_scope = "foo/while/body"
            assert actual_name_scope == expected_name_scope, (
                "%s does not match %s" %
                (actual_name_scope, expected_name_scope))
          return i

        return while_v2.while_loop(Cond, Body, [0.])

    F()

  @test_util.run_deprecated_v1  # Need to pass RunMetadata.
  def testDisableLowering(self):
    old = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = True
    with self.session() as sess:
      x = constant_op.constant(2.)
      ret = while_loop_v2(
          lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)

      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(sess.run(ret, options=opts, run_metadata=run_metadata),
                       16)
      for dev_stat in run_metadata.step_stats.dev_stats:
        for ns in dev_stat.node_stats:
          self.assertNotIn("switch", ns.node_name)
    control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old

  def _runBasicWithConfig(self, config):
    with ops.device("/cpu:0"):
      x = constant_op.constant(0)
      ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x])
    with self.cached_session(config=config):
      self.assertEqual(1000, self.evaluate(ret))

  @test_util.run_deprecated_v1
  def testRunKernelsInline(self):
    config = config_pb2.ConfigProto()
    config.inter_op_parallelism_threads = -1
    self._runBasicWithConfig(config)

  @test_util.run_deprecated_v1
  def testSingleThreadedExecution(self):
    config = config_pb2.ConfigProto()
    config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
    self._runBasicWithConfig(config)

  def testIsControlFlowGraph(self):
    x = constant_op.constant(0)

    @def_function.function
    def F(c):

      def Cond(i):
        self.assertTrue(i.graph.is_control_flow_graph)
        return i < 2

      def Body(i):
        i = i + 1
        self.assertTrue(i.graph.is_control_flow_graph)
        return i

      return while_loop_v2(Cond, Body, [c])

    ret, = F(x)
    self.assertEqual(2, self.evaluate(ret))
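
  # The next test round-trips a while loop through a serialized GraphDef (the
  # `serialized` literal below) and checks that gradients still work after
  # import; the snippet that generated the graph is reproduced in a comment
  # after the literal.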
  def testImportFromSerializedWithFunctionInBody(self):
    serialized = """node {
      name: "Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
            }
            float_val: 1.0
          }
        }
      }
    }
    node {
      name: "while/maximum_iterations"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: -1
          }
        }
      }
    }
    node {
      name: "while/loop_counter"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node {
      name: "while"
      op: "StatelessWhile"
      input: "while/loop_counter"
      input: "while/maximum_iterations"
      input: "Const"
      attr {
        key: "T"
        value {
          list {
            type: DT_INT32
            type: DT_INT32
            type: DT_FLOAT
          }
        }
      }
      attr {
        key: "_lower_using_switch_merge"
        value {
          b: true
        }
      }
      attr {
        key: "_num_original_outputs"
        value {
          i: 3
        }
      }
      attr {
        key: "_read_only_resource_inputs"
        value {
          list {
          }
        }
      }
      attr {
        key: "body"
        value {
          func {
            name: "while_body_822"
          }
        }
      }
      attr {
        key: "cond"
        value {
          func {
            name: "while_cond_821"
          }
        }
      }
      attr {
        key: "output_shapes"
        value {
          list {
            shape {
            }
            shape {
            }
            shape {
            }
          }
        }
      }
      attr {
        key: "parallel_iterations"
        value {
          i: 10
        }
      }
    }
    node {
      name: "while/Identity"
      op: "Identity"
      input: "while"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_1"
      op: "Identity"
      input: "while:1"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    node {
      name: "while/Identity_2"
      op: "Identity"
      input: "while:2"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
      function {
        signature {
          name: "while_body_822"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations_0"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "add"
            type: DT_INT32
          }
          output_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          output_arg {
            name: "partitionedcall"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "PartitionedCall"
          op: "PartitionedCall"
          input: "placeholder"
          attr {
            key: "Tin"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "Tout"
            value {
              list {
                type: DT_FLOAT
              }
            }
          }
          attr {
            key: "_collective_manager_ids"
            value {
              list {
              }
            }
          }
          attr {
            key: "_read_only_resource_inputs"
            value {
              list {
              }
            }
          }
          attr {
            key: "config"
            value {
              s: ""
            }
          }
          attr {
            key: "config_proto"
            value {
              s: ""
            }
          }
          attr {
            key: "executor_type"
            value {
              s: ""
            }
          }
          attr {
            key: "f"
            value {
              func {
                name: "__inference_f_841"
              }
            }
          }
          experimental_debug_info {
            original_node_names: "PartitionedCall"
          }
        }
        node_def {
          name: "add/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_INT32
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_INT32
                tensor_shape {
                }
                int_val: 1
              }
            }
          }
          experimental_debug_info {
            original_node_names: "add/y"
          }
        }
        node_def {
          name: "add_0"
          op: "AddV2"
          input: "while_loop_counter"
          input: "add/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_INT32
            }
          }
          experimental_debug_info {
            original_node_names: "add"
          }
        }
        ret {
          key: "add"
          value: "add_0:z:0"
        }
        ret {
          key: "partitionedcall"
          value: "PartitionedCall:output:0"
        }
        ret {
          key: "while_maximum_iterations"
          value: "while_maximum_iterations_0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "while_cond_821"
          input_arg {
            name: "while_loop_counter"
            type: DT_INT32
          }
          input_arg {
            name: "while_maximum_iterations"
            type: DT_INT32
          }
          input_arg {
            name: "placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "less"
            type: DT_BOOL
          }
        }
        node_def {
          name: "Less/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 5.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "Less/y"
          }
        }
        node_def {
          name: "Less"
          op: "Less"
          input: "placeholder"
          input: "Less/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Less"
          }
        }
        ret {
          key: "less"
          value: "Less:z:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 1
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
        arg_attr {
          key: 2
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
      function {
        signature {
          name: "__inference_f_841"
          input_arg {
            name: "mul_placeholder"
            type: DT_FLOAT
          }
          output_arg {
            name: "identity"
            type: DT_FLOAT
          }
        }
        node_def {
          name: "mul/y"
          op: "Const"
          attr {
            key: "dtype"
            value {
              type: DT_FLOAT
            }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape {
                }
                float_val: 2.0
              }
            }
          }
          experimental_debug_info {
            original_node_names: "mul/y"
          }
        }
        node_def {
          name: "mul"
          op: "Mul"
          input: "mul_placeholder"
          input: "mul/y:output:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "mul"
          }
        }
        node_def {
          name: "Identity"
          op: "Identity"
          input: "mul:z:0"
          attr {
            key: "T"
            value {
              type: DT_FLOAT
            }
          }
          experimental_debug_info {
            original_node_names: "Identity"
          }
        }
        ret {
          key: "identity"
          value: "Identity:output:0"
        }
        arg_attr {
          key: 0
          value {
            attr {
              key: "_output_shapes"
              value {
                list {
                  shape {
                  }
                }
              }
            }
          }
        }
      }
    }
    versions {
      producer: 399
      min_consumer: 12
    }
    """
    # Code for generating above graph:
    #
    # def Body(i):
    #   @tf.function
    #   def f():
    #     return i * 2
    #   return f()
    # tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
    graph_def = graph_pb2.GraphDef()
    text_format.Parse(serialized, graph_def)

    @def_function.function
    def F():
      x, y = importer.import_graph_def(
          graph_def, return_elements=["Const:0", "while:2"])
      grad_out, = gradients_impl.gradients(y, x)
      return grad_out

    self.assertAllEqual(F(), 8.0)

  def testIndexedSlicesInIncomingGrads(self):

    @def_function.function
    def F():
      x = constant_op.constant([2.])
      # Computes x^4
      ret = while_loop_v2(
          lambda _: True, lambda v: v * v, [x], return_same_structure=False,
          maximum_iterations=2)
      v = array_ops.gather(ret, [0])
      return gradients_impl.gradients(v, [x])[0]  # 4*x^3

    self.assertAllEqual(self.evaluate(F()), [32.])
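

# ScalarShape returns an empty 1-D int32 tensor, i.e. the element_shape of a
# scalar, for the TensorList-based tests above.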
def ScalarShape():
  return ops.convert_to_tensor([], dtype=dtypes.int32)
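

# GetOptimizedGraph runs Grappler on the default graph (constant folding off,
# manual memory optimization) and returns the optimized GraphDef; the pruning
# tests above inspect it for leftover Enter and TensorListPushBack nodes.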
def GetOptimizedGraph():
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)


if __name__ == "__main__":
  test.main()
|