Add an experimental API to control behavior of outputting all intermediates when using v2 control flow inside Keras models in graph.
PiperOrigin-RevId: 261394829
This commit is contained in:
parent
bd687e87db
commit
6fba1efce3
@ -1180,10 +1180,19 @@ Status GraphConstructor::Convert() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (src_node != nullptr && src_index >= src_node->num_outputs()) {
|
if (src_node != nullptr && src_index >= src_node->num_outputs()) {
|
||||||
return errors::InvalidArgument(
|
std::ostringstream out;
|
||||||
"Node '", node_def.name(), "': Connecting to invalid output ",
|
out << "Node '" << node_def.name() << "': Connecting to invalid output "
|
||||||
tensor_id.index(), " of source node ", tensor_id.node(),
|
<< tensor_id.index() << " of source node " << tensor_id.node()
|
||||||
" which has ", src_node->num_outputs(), " outputs");
|
<< " which has " << src_node->num_outputs() << " outputs.";
|
||||||
|
|
||||||
|
if (src_node->type_string() == "If" ||
|
||||||
|
src_node->type_string() == "StatelessIf" ||
|
||||||
|
src_node->type_string() == "While" ||
|
||||||
|
src_node->type_string() == "StatelessWhile") {
|
||||||
|
out << " Try using "
|
||||||
|
<< "tf.compat.v1.experimental.output_all_intermediates(True).";
|
||||||
|
}
|
||||||
|
return errors::InvalidArgument(out.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
|
inputs.emplace_back(string(tensor_id.node()), src_node, src_index);
|
||||||
|
|||||||
@ -2591,11 +2591,24 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":control_flow_util",
|
":control_flow_util",
|
||||||
|
":control_flow_util_v2",
|
||||||
":framework_ops",
|
":framework_ops",
|
||||||
":util",
|
":util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "control_flow_v2_toggles_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["ops/control_flow_v2_toggles_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":control_flow_v2_toggles",
|
||||||
|
":control_flow_util_v2",
|
||||||
|
":client_testlib",
|
||||||
|
":platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "cond_v2",
|
name = "cond_v2",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|||||||
@ -1007,6 +1007,31 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
grad = gradients_impl.gradients(r, [x])[0]
|
grad = gradients_impl.gradients(r, [x])[0]
|
||||||
self.assertAllEqual(1.0, self.evaluate(grad))
|
self.assertAllEqual(1.0, self.evaluate(grad))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.enable_control_flow_v2
|
||||||
|
def testCondComputeGradAfterSessRunFails(self):
|
||||||
|
with self.cached_session():
|
||||||
|
x = constant_op.constant(10.0, name="x")
|
||||||
|
pred = math_ops.less(1, 2)
|
||||||
|
|
||||||
|
def true_fn():
|
||||||
|
a = x * x
|
||||||
|
return a * a
|
||||||
|
|
||||||
|
def false_fn():
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
r = control_flow_ops.cond(pred, true_fn, false_fn)
|
||||||
|
|
||||||
|
self.assertAllEqual(r, 10000.)
|
||||||
|
grad = gradients_impl.gradients(r, [x])[0]
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors_impl.InvalidArgumentError,
|
||||||
|
r"Connecting to invalid output 1 of source node cond which has 1 "
|
||||||
|
r"outputs. Try using "
|
||||||
|
"tf.compat.v1.experimental.output_all_intermediates\(True\)."):
|
||||||
|
self.evaluate(grad)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@test_util.enable_output_all_intermediates
|
@test_util.enable_output_all_intermediates
|
||||||
def testCondComputeGradAfterSessRun(self):
|
def testCondComputeGradAfterSessRun(self):
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import control_flow_util
|
from tensorflow.python.ops import control_flow_util
|
||||||
|
from tensorflow.python.ops import control_flow_util_v2
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -64,3 +65,30 @@ def control_flow_v2_enabled(): # pylint: disable=invalid-name
|
|||||||
Note: v2 control flow is always enabled inside of tf.function.
|
Note: v2 control flow is always enabled inside of tf.function.
|
||||||
"""
|
"""
|
||||||
return control_flow_util.EnableControlFlowV2(ops.get_default_graph())
|
return control_flow_util.EnableControlFlowV2(ops.get_default_graph())
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export(v1=["experimental.output_all_intermediates"])
|
||||||
|
def output_all_intermediates(state): # pylint: disable=invalid-name
|
||||||
|
"""Whether to output all intermediates from functional control flow ops.
|
||||||
|
|
||||||
|
The "default" behavior to is to output all intermediates when using v2 control
|
||||||
|
flow inside Keras models in graph mode (possibly inside Estimators). This is
|
||||||
|
needed to support taking gradients of v2 control flow. In graph mode, Keras
|
||||||
|
can sometimes freeze the forward graph before the gradient computation which
|
||||||
|
does not work for v2 control flow since it requires updating the forward ops
|
||||||
|
to output the needed intermediates. We work around this by proactively
|
||||||
|
outputting the needed intermediates when building the forward pass itself.
|
||||||
|
Ideally any such extra tensors should be pruned out at runtime. However, if
|
||||||
|
for any reason this doesn't work for you or if you have an infernce-only model
|
||||||
|
you can turn this behavior off using
|
||||||
|
`tf.compat.v1.experimental.output_all_intermediates(False)`.
|
||||||
|
|
||||||
|
If with the default behavior you are still seeing errors of the form
|
||||||
|
"Connecting to invalid output X of source node Y which has Z outputs" try
|
||||||
|
setting `tf.compat.v1.experimental.output_all_intermediates(True)` and
|
||||||
|
please file an issue at https://github.com/tensorflow/tensorflow/issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: True, False or None. None restores the default behavior.
|
||||||
|
"""
|
||||||
|
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = state # pylint: disable=protected-access
|
||||||
|
|||||||
44
tensorflow/python/ops/control_flow_v2_toggles_test.py
Normal file
44
tensorflow/python/ops/control_flow_v2_toggles_test.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for control_flow_v2_toggles.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.ops import control_flow_util_v2
|
||||||
|
from tensorflow.python.ops import control_flow_v2_toggles
|
||||||
|
from tensorflow.python.platform import googletest
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class ControlFlowV2TogglesTest(test.TestCase):
|
||||||
|
|
||||||
|
def testOutputAllIntermediates(self):
|
||||||
|
self.assertIsNone(
|
||||||
|
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||||
|
control_flow_v2_toggles.output_all_intermediates(True)
|
||||||
|
self.assertTrue(
|
||||||
|
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||||
|
control_flow_v2_toggles.output_all_intermediates(False)
|
||||||
|
self.assertFalse(
|
||||||
|
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||||
|
control_flow_v2_toggles.output_all_intermediates(None)
|
||||||
|
self.assertIsNone(
|
||||||
|
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
googletest.main()
|
||||||
@ -4,4 +4,8 @@ tf_module {
|
|||||||
name: "function_executor_type"
|
name: "function_executor_type"
|
||||||
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "output_all_intermediates"
|
||||||
|
argspec: "args=[\'state\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -363,6 +363,8 @@ renames = {
|
|||||||
'tf.compat.v1.estimator.tpu.TPUEstimatorSpec',
|
'tf.compat.v1.estimator.tpu.TPUEstimatorSpec',
|
||||||
'tf.estimator.tpu.experimental.EmbeddingSpec':
|
'tf.estimator.tpu.experimental.EmbeddingSpec':
|
||||||
'tf.compat.v1.estimator.tpu.experimental.EmbeddingSpec',
|
'tf.compat.v1.estimator.tpu.experimental.EmbeddingSpec',
|
||||||
|
'tf.experimental.output_all_intermediates':
|
||||||
|
'tf.compat.v1.experimental.output_all_intermediates',
|
||||||
'tf.expm1':
|
'tf.expm1':
|
||||||
'tf.math.expm1',
|
'tf.math.expm1',
|
||||||
'tf.fake_quant_with_min_max_args':
|
'tf.fake_quant_with_min_max_args':
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user