Add an experimental API to control behavior of outputting all intermediates when using v2 control flow inside Keras models in graph.
PiperOrigin-RevId: 261394829
This commit is contained in:
parent
bd687e87db
commit
6fba1efce3
@ -1180,10 +1180,19 @@ Status GraphConstructor::Convert() {
|
||||
}
|
||||
|
||||
if (src_node != nullptr && src_index >= src_node->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"Node '", node_def.name(), "': Connecting to invalid output ",
|
||||
tensor_id.index(), " of source node ", tensor_id.node(),
|
||||
" which has ", src_node->num_outputs(), " outputs");
|
||||
std::ostringstream out;
|
||||
out << "Node '" << node_def.name() << "': Connecting to invalid output "
|
||||
<< tensor_id.index() << " of source node " << tensor_id.node()
|
||||
<< " which has " << src_node->num_outputs() << " outputs.";
|
||||
|
||||
if (src_node->type_string() == "If" ||
|
||||
src_node->type_string() == "StatelessIf" ||
|
||||
src_node->type_string() == "While" ||
|
||||
src_node->type_string() == "StatelessWhile") {
|
||||
out << " Try using "
|
||||
<< "tf.compat.v1.experimental.output_all_intermediates(True).";
|
||||
}
|
||||
return errors::InvalidArgument(out.str());
|
||||
}
|
||||
|
||||
inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
|
||||
|
@ -2591,11 +2591,24 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":control_flow_util",
|
||||
":control_flow_util_v2",
|
||||
":framework_ops",
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "control_flow_v2_toggles_test",
|
||||
size = "small",
|
||||
srcs = ["ops/control_flow_v2_toggles_test.py"],
|
||||
additional_deps = [
|
||||
":control_flow_v2_toggles",
|
||||
":control_flow_util_v2",
|
||||
":client_testlib",
|
||||
":platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cond_v2",
|
||||
srcs = [
|
||||
|
@ -1007,6 +1007,31 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
||||
grad = gradients_impl.gradients(r, [x])[0]
|
||||
self.assertAllEqual(1.0, self.evaluate(grad))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.enable_control_flow_v2
|
||||
def testCondComputeGradAfterSessRunFails(self):
|
||||
with self.cached_session():
|
||||
x = constant_op.constant(10.0, name="x")
|
||||
pred = math_ops.less(1, 2)
|
||||
|
||||
def true_fn():
|
||||
a = x * x
|
||||
return a * a
|
||||
|
||||
def false_fn():
|
||||
return x * x
|
||||
|
||||
r = control_flow_ops.cond(pred, true_fn, false_fn)
|
||||
|
||||
self.assertAllEqual(r, 10000.)
|
||||
grad = gradients_impl.gradients(r, [x])[0]
|
||||
with self.assertRaisesRegexp(
|
||||
errors_impl.InvalidArgumentError,
|
||||
r"Connecting to invalid output 1 of source node cond which has 1 "
|
||||
r"outputs. Try using "
|
||||
"tf.compat.v1.experimental.output_all_intermediates\(True\)."):
|
||||
self.evaluate(grad)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.enable_output_all_intermediates
|
||||
def testCondComputeGradAfterSessRun(self):
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -64,3 +65,30 @@ def control_flow_v2_enabled(): # pylint: disable=invalid-name
|
||||
Note: v2 control flow is always enabled inside of tf.function.
|
||||
"""
|
||||
return control_flow_util.EnableControlFlowV2(ops.get_default_graph())
|
||||
|
||||
|
||||
@tf_export(v1=["experimental.output_all_intermediates"])
|
||||
def output_all_intermediates(state): # pylint: disable=invalid-name
|
||||
"""Whether to output all intermediates from functional control flow ops.
|
||||
|
||||
The "default" behavior to is to output all intermediates when using v2 control
|
||||
flow inside Keras models in graph mode (possibly inside Estimators). This is
|
||||
needed to support taking gradients of v2 control flow. In graph mode, Keras
|
||||
can sometimes freeze the forward graph before the gradient computation which
|
||||
does not work for v2 control flow since it requires updating the forward ops
|
||||
to output the needed intermediates. We work around this by proactively
|
||||
outputting the needed intermediates when building the forward pass itself.
|
||||
Ideally any such extra tensors should be pruned out at runtime. However, if
|
||||
for any reason this doesn't work for you or if you have an infernce-only model
|
||||
you can turn this behavior off using
|
||||
`tf.compat.v1.experimental.output_all_intermediates(False)`.
|
||||
|
||||
If with the default behavior you are still seeing errors of the form
|
||||
"Connecting to invalid output X of source node Y which has Z outputs" try
|
||||
setting `tf.compat.v1.experimental.output_all_intermediates(True)` and
|
||||
please file an issue at https://github.com/tensorflow/tensorflow/issues.
|
||||
|
||||
Args:
|
||||
state: True, False or None. None restores the default behavior.
|
||||
"""
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = state # pylint: disable=protected-access
|
||||
|
44
tensorflow/python/ops/control_flow_v2_toggles_test.py
Normal file
44
tensorflow/python/ops/control_flow_v2_toggles_test.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for control_flow_v2_toggles.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ControlFlowV2TogglesTest(test.TestCase):
|
||||
|
||||
def testOutputAllIntermediates(self):
|
||||
self.assertIsNone(
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||
control_flow_v2_toggles.output_all_intermediates(True)
|
||||
self.assertTrue(
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||
control_flow_v2_toggles.output_all_intermediates(False)
|
||||
self.assertFalse(
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||
control_flow_v2_toggles.output_all_intermediates(None)
|
||||
self.assertIsNone(
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
@ -4,4 +4,8 @@ tf_module {
|
||||
name: "function_executor_type"
|
||||
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "output_all_intermediates"
|
||||
argspec: "args=[\'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -363,6 +363,8 @@ renames = {
|
||||
'tf.compat.v1.estimator.tpu.TPUEstimatorSpec',
|
||||
'tf.estimator.tpu.experimental.EmbeddingSpec':
|
||||
'tf.compat.v1.estimator.tpu.experimental.EmbeddingSpec',
|
||||
'tf.experimental.output_all_intermediates':
|
||||
'tf.compat.v1.experimental.output_all_intermediates',
|
||||
'tf.expm1':
|
||||
'tf.math.expm1',
|
||||
'tf.fake_quant_with_min_max_args':
|
||||
|
Loading…
Reference in New Issue
Block a user