cond_v2: use optional tensors instead of FakeParams.
The purpose of this change is to avoid wasting memory by allocating large FakeParams, which is especially important on GPU. This also adds a few other fixes needed to get optional variants working with cond_v2, including on GPU.

PiperOrigin-RevId: 223260005
parent 22ff3ec66e
commit f6ee54c9b1
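The sketch below is not part of the commit; it only illustrates, assuming a TF 1.x build where `gen_dataset_ops` exposes the `Optional*` ops used in the diff, why an optional is cheaper than a FakeParam: the branch that is not taken can emit a valueless optional instead of materializing a dummy tensor of the intermediate's full shape.

```python
# Illustrative sketch only (not from this commit). Assumes TF 1.x graph mode
# and the internal gen_dataset_ops module; the tensor `big` is made up.
import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops

with tf.Graph().as_default(), tf.Session() as sess:
  big = tf.ones([1000, 1000])  # an intermediate only one branch produces

  # The branch that has the intermediate wraps it in an optional...
  wrapped = gen_dataset_ops.optional_from_value([big])
  # ...while the other branch pads its outputs with a valueless optional,
  # which carries no tensor storage (unlike a [1000, 1000] FakeParam).
  empty = gen_dataset_ops.optional_none()

  print(sess.run([gen_dataset_ops.optional_has_value(wrapped),
                  gen_dataset_ops.optional_has_value(empty)]))  # [True, False]

  # The gradient code later unwraps the captured optional to get the value.
  value = gen_dataset_ops.optional_get_value(
      wrapped, [big.dtype], [big.shape])[0]
  print(sess.run(value).shape)  # (1000, 1000)
```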
@@ -1002,7 +1002,7 @@ Status Placer::Run() {
     int assigned_device = -1;

     // Heuristic A application.
-    if (IsGeneratorNode(node)) {
+    if (IsGeneratorNode(node) && !node->out_edges().empty()) {
       const Node* output = (*node->out_edges().begin())->dst();
       int output_device_name = output->assigned_device_name_index();
@@ -78,6 +78,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
 REGISTER_GPU_SWITCH(uint64);
+TF_CALL_variant(REGISTER_GPU_SWITCH);

 #undef REGISTER_CPU_SWITCH
 #undef REGISTER_CPU_REF_SWITCH
@@ -1972,6 +1972,15 @@ py_library(
     ],
 )

+py_library(
+    name = "optional_grad",
+    srcs = ["ops/optional_grad.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework_ops",
+    ],
+)
+
 py_library(
     name = "sets",
     srcs = [
@@ -2151,6 +2160,7 @@ py_library(
         ":graph_to_function_def",
         ":pywrap_tensorflow",
         ":util",
+        "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/eager:function",
     ],
 )
@@ -2295,6 +2305,7 @@ py_library(
         ":manip_ops",
         ":math_grad",
         ":math_ops",
+        ":optional_grad",
         ":platform",
         ":random_grad",
         ":resource_variable_ops",
@@ -126,7 +126,7 @@ class CondV2Test(test.TestCase):
       self.assertEqual(sess.run(out, {pred: False}), (2.0,))

   def _createCond(self, name):
-    """Helper function for testDefaultName."""
+    """Creates a cond_v2 call and returns the output tensor and the cond op."""
     pred = constant_op.constant(True, name="pred")
     x = constant_op.constant(1.0, name="x")
@@ -139,11 +139,11 @@ class CondV2Test(test.TestCase):
     output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
     cond_op = output.op.inputs[0].op
     self.assertEqual(cond_op.type, "If")
-    return cond_op
+    return output, cond_op

   def testDefaultName(self):
     with ops.Graph().as_default():
-      cond_op = self._createCond(None)
+      _, cond_op = self._createCond(None)
       self.assertEqual(cond_op.name, "cond")
       self.assertRegexpMatches(
           cond_op.get_attr("then_branch").name, r"cond_true_\d*")
@@ -152,14 +152,14 @@ class CondV2Test(test.TestCase):

     with ops.Graph().as_default():
       with ops.name_scope("foo"):
-        cond1_op = self._createCond("")
+        _, cond1_op = self._createCond("")
         self.assertEqual(cond1_op.name, "foo/cond")
         self.assertRegexpMatches(
             cond1_op.get_attr("then_branch").name, r"foo_cond_true_\d*")
         self.assertRegexpMatches(
             cond1_op.get_attr("else_branch").name, r"foo_cond_false_\d*")

-        cond2_op = self._createCond(None)
+        _, cond2_op = self._createCond(None)
         self.assertEqual(cond2_op.name, "foo/cond_1")
         self.assertRegexpMatches(
             cond2_op.get_attr("then_branch").name, r"foo_cond_1_true_\d*")
@@ -612,11 +612,11 @@ class CondV2Test(test.TestCase):
   def testLowering(self):
     with ops.Graph().as_default() as g:
       with self.session(graph=g) as sess:
-        out_cond = self._createCond("cond")
+        cond_output, _ = self._createCond("cond")

         run_options = config_pb2.RunOptions(output_partition_graphs=True)
         run_metadata = config_pb2.RunMetadata()
-        sess.run(out_cond, options=run_options, run_metadata=run_metadata)
+        sess.run(cond_output, options=run_options, run_metadata=run_metadata)

         # If lowering was enabled, there should be a `Switch` node
         switch_found = any(
@@ -641,12 +641,12 @@ class CondV2Test(test.TestCase):
       # Build the cond_v2 in an XLA context
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
-      out_cond = self._createCond("cond")
+      cond_output, _ = self._createCond("cond")
       xla_context.Exit()

       run_options = config_pb2.RunOptions(output_partition_graphs=True)
       run_metadata = config_pb2.RunMetadata()
-      sess.run(out_cond, options=run_options, run_metadata=run_metadata)
+      sess.run(cond_output, options=run_options, run_metadata=run_metadata)

       # Lowering disabled in XLA, there should be no `Switch` node
       switch_found = any(
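Both lowering tests above follow the same pattern: request partition graphs and scan them for a `Switch` node. A standalone approximation of that check is sketched below; it uses the public `tf.cond` as a stand-in, since `cond_v2` is internal, and all names are illustrative.

```python
# Hedged sketch of the partition-graph check used by testLowering above.
# Assumes TF 1.x graph mode; tf.cond stands in for the internal cond_v2 call.
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2

with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
  pred = tf.placeholder_with_default(True, shape=[])
  cond_output = tf.cond(pred, lambda: tf.constant(1.0), lambda: tf.constant(2.0))

  run_options = config_pb2.RunOptions(output_partition_graphs=True)
  run_metadata = config_pb2.RunMetadata()
  sess.run(cond_output, options=run_options, run_metadata=run_metadata)

  # Lowered control flow shows up as Switch/Merge nodes in the partition graphs.
  switch_found = any(
      node.op == "Switch"
      for graph_def in run_metadata.partition_graphs
      for node in graph_def.node)
  print("Switch node present:", switch_found)
```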
@@ -446,8 +446,7 @@ class ControlFlowTest(test.TestCase):
       g = gradients_impl.gradients(y, x)[0]

       self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
-      # TODO(b/119791601): Enable this.
-      # self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
+      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])

   @test_util.disable_control_flow_v2("b/113293074")
   def testCondIndexedSlicesDifferentTypes(self):
@@ -2168,11 +2167,8 @@ class ControlFlowTest(test.TestCase):
     self.assertAllClose(512.0, self.evaluate(r))

   def testNestedWhileCondWhileGrad(self):
-    if control_flow_ops.ENABLE_WHILE_V2 and test_util.is_gpu_available():
-      self.skipTest("b/118459209")
     self._testNestedWhileCondWhileGrad(use_gpu=False)

-  @test_util.disable_control_flow_v2("b/118459209")
   def testNestedWhileCondWhileGradGpu(self):
     self._testNestedWhileCondWhileGrad(use_gpu=True)
@@ -30,7 +30,9 @@ from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import function_def_to_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_util_v2 as util
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_functional_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gradients_impl
@@ -149,7 +151,9 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
       true_inputs, false_inputs)

   # Add all intermediate tensors as function outputs so they're available for
-  # the gradient computation.
+  # the gradient computation. Since the outputs of the two functions must match,
+  # we wrap all the intermediates in optionals. Each intermediate output will
+  # have a value iff its corresponding branch is taken.

   true_intermediates = _get_intermediates(true_graph)
   false_intermediates = _get_intermediates(false_graph)
@@ -157,12 +161,28 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
   # Save the original number of outputs to return to the caller.
   num_cond_outputs = len(true_graph.outputs)

-  # Make the number/type of new intermediate outputs match.
-  extra_true_outputs, extra_false_outputs = _pad_params(
-      true_graph, false_graph, true_intermediates, false_intermediates)
+  if control_flow_util.InXlaContext(ops.get_default_graph()):
+    # XLA does not yet support optionals, so output intermediates directly and
+    # make them match via FakeParams, which can be converted to zeros in XLA.
+    # TODO(skyewm,jpienaar): can XLA support optionals?
+    extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
+        true_graph, false_graph, true_intermediates, false_intermediates)
+  else:
+    # Wrap intermediates in optionals.
+    wrapped_true_intermediates = _wrap_intermediates(true_graph,
+                                                     true_intermediates)
+    wrapped_false_intermediates = _wrap_intermediates(false_graph,
+                                                      false_intermediates)
+
+    # Make outputs match by adding none optionals.
+    extra_true_outputs, extra_false_outputs = _make_intermediates_match(
+        true_graph, false_graph,
+        wrapped_true_intermediates, wrapped_false_intermediates)

   true_graph.outputs.extend(extra_true_outputs)
   false_graph.outputs.extend(extra_false_outputs)
+  # TODO(skyewm): somehow indicate it's a bug if this fails.
   _check_same_outputs(true_graph, false_graph)

   # Create the If op.
   tensors = gen_functional_ops._if(  # pylint: disable=protected-access
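The new `else` branch above is the heart of the change. Below is a distilled sketch of what "make outputs match by adding none optionals" means; `pad_with_none_optionals` is a made-up name, graph placement (`func_graph.as_default()`) is omitted, and the real helpers are `_wrap_intermediates`, `_make_intermediates_match`, and `_create_none_optionals` further down in the diff.

```python
# Illustration only, not the actual implementation.
from tensorflow.python.ops import gen_dataset_ops

def pad_with_none_optionals(true_optionals, false_optionals):
  """Mirrors each branch's wrapped intermediates in the other branch.

  Every extra output is a DT_VARIANT optional, so both branch functions end up
  with the same number and type of outputs without allocating dummy tensors.
  """
  new_true = true_optionals + [
      gen_dataset_ops.optional_none() for _ in false_optionals]
  new_false = [
      gen_dataset_ops.optional_none() for _ in true_optionals] + false_optionals
  return new_true, new_false
```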
@@ -175,7 +195,8 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
       name=name)

   # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
-  util.maybe_set_lowering_attr(tensors[0].op)
+  if_op = tensors[0].op
+  util.maybe_set_lowering_attr(if_op)

   # Return identities for each output of the If op, rather than the output of
   # the If op directly. This makes pruning work if the output of cond() is
@@ -187,6 +208,9 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
   # correct output structure
   tensors = [array_ops.identity(t) for t in tensors]

+  # Prevent fetching since the variant outputs can't be fetched directly.
+  if_op.graph.prevent_fetching(if_op)
+
   return tensors[:num_cond_outputs]
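On the `prevent_fetching` call added above: variant-dtype outputs can't be fetched directly, so the If op is marked unfetchable and callers get identities instead. A small standalone sketch of that Graph API, with made-up tensor names:

```python
# Sketch only. Graph.prevent_fetching marks an op so session.run cannot fetch
# its outputs directly; fetching goes through an identity instead.
import tensorflow as tf

with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
  raw = tf.constant(1.0, name="raw")
  safe = tf.identity(raw)        # callers fetch this identity instead
  g.prevent_fetching(raw.op)

  print(sess.run(safe))          # 1.0
  try:
    sess.run(raw)                # raises: op has been marked as not fetchable
  except ValueError as e:
    print("cannot fetch raw:", e)
```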
@@ -278,7 +302,7 @@ def _create_grad_func(func_graph, grads, name):
   return func_graph_module.func_graph_from_py_func(
       name,
       lambda: _grad_fn(func_graph, grads), [], {},
-      func_graph=util.CondBranchFuncGraph(name, read_only_collections=False))
+      func_graph=_CondGradFuncGraph(name, func_graph))


 def _resolve_grad_inputs(cond_graph, grad_graph):
@@ -360,28 +384,39 @@ def _separate_unique_inputs(true_inputs, false_inputs):
   return list(shared_inputs), list(true_only_inputs), list(false_only_inputs)


-def _pad_params(true_graph, false_graph, true_params, false_params):
-  """Returns new param lists that have matching signatures.
+def _make_intermediates_match(true_graph, false_graph,
+                              true_optionals, false_optionals):
+  """Returns new optionals lists that have matching signatures.

-  This is done by mirroring each param list in the other using dummy params.
-  There is no merging of params.
+  This is done by mirroring each list in the other using none optionals.
+  There is no merging of like optionals.

   Args:
     true_graph: FuncGraph
     false_graph: FuncGraph
-    true_params: a list of Tensors from true_graph
-    false_params: a list of Tensors from false_graph
+    true_optionals: a list of optional Tensors from true_graph
+    false_optionals: a list of optional Tensors from false_graph

   Returns:
     A new list of Tensors in true_graph and a new list of Tensors in
-    false_graph. The two lists have the same number of Tensors, with matching
-    types and shapes across the lists.
+    false_graph. The two lists have the same number of Tensors, all of which
+    will be optionals of the same shape/type.
   """
-  new_true_params = (true_params +
-                     _create_dummy_params(true_graph, false_params))
-  new_false_inputs = (_create_dummy_params(false_graph, true_params)
-                      + false_params)
-  return new_true_params, new_false_inputs
+  new_true_optionals = (true_optionals +
+                        _create_none_optionals(true_graph, false_optionals))
+  new_false_optionals = (_create_none_optionals(false_graph, true_optionals)
+                         + false_optionals)
+  return new_true_optionals, new_false_optionals
+
+
+def _make_intermediates_match_xla(true_graph, false_graph, true_intermediates,
+                                  false_intermediates):
+  """Like _make_intermediates_match but for the XLA case."""
+  new_true_intermediates = (true_intermediates +
+                            _create_fakeparams(true_graph, false_intermediates))
+  new_false_intermediates = (_create_fakeparams(false_graph, true_intermediates)
+                             + false_intermediates)
+  return new_true_intermediates, new_false_intermediates


 def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
@@ -416,11 +451,11 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
   true_graph.inputs = (
       [true_input_to_param[t] for t in shared_inputs] +
       [true_input_to_param[t] for t in true_only_inputs] +
-      _create_dummy_params(true_graph, false_only_inputs))
+      _create_dummy_inputs(true_graph, false_only_inputs))

   false_graph.inputs = (
       [false_input_to_param[t] for t in shared_inputs] +
-      _create_dummy_params(false_graph, true_only_inputs) +
+      _create_dummy_inputs(false_graph, true_only_inputs) +
       [false_input_to_param[t] for t in false_only_inputs])

   # Rewrite the FuncGraphs' state to reflect the new inputs.
@@ -432,7 +467,12 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
   return new_inputs


-def _create_dummy_params(func_graph, template_tensors):
+def _wrap_intermediates(func_graph, intermediates):
+  with func_graph.as_default():
+    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
+
+
+def _create_dummy_inputs(func_graph, template_tensors):
   """Creates tensors in func_graph to represent template_tensors.

   Args:
@@ -442,6 +482,27 @@ def _create_dummy_params(func_graph, template_tensors):
   Returns:
     A list of tensors in func_graph.
   """
   with func_graph.as_default():
+    return [array_ops.placeholder(t.dtype, shape=t.shape)
+            for t in template_tensors]
+
+
+def _create_none_optionals(func_graph, template_tensors):
+  """Creates none optionals in func_graph to represent template_tensors.
+
+  Args:
+    func_graph: FuncGraph.
+    template_tensors: a list of tensors in func_graph.
+
+  Returns:
+    A list of tensors in func_graph.
+  """
+  with func_graph.as_default():
+    return [gen_dataset_ops.optional_none() for _ in template_tensors]
+
+
+def _create_fakeparams(func_graph, template_tensors):
+  """Create FakeParams for the XLA case."""
+  with func_graph.as_default():
     return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
             for t in template_tensors]
@@ -474,3 +535,38 @@ def _get_output_shapes(true_graph_outputs, false_graph_outputs):
       for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
   ]
   return output_shapes
+
+
+class _CondGradFuncGraph(util.CondBranchFuncGraph):
+  """FuncGraph for the gradient function of the branch of an If op.
+
+  Handles unwrapping optional intermediate values that are captured by the
+  gradient computation.
+  """
+
+  def __init__(self, name, forward_graph):
+    super(_CondGradFuncGraph, self).__init__(name, read_only_collections=False)
+    self._forward_graph = forward_graph
+
+  def _capture_helper(self, tensor, name):
+    if (tensor.graph is not self._forward_graph or
+        tensor in self._forward_graph.inputs or
+        tensor in self._forward_graph.outputs):
+      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
+
+    # 'tensor' is an intermediate in the forward graph. We find the corresponding
+    # optional tensor, which is output from the If op, and capture it as
+    # normal. We then unwrap the captured optional value to get the raw
+    # intermediate value.
+    for consumer in tensor.consumers():
+      if (consumer.type == "OptionalFromValue"
+          and consumer.outputs[0] in self._forward_graph.outputs):
+        optional = consumer.outputs[0]
+        captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
+            optional, name)
+        return gen_dataset_ops.optional_get_value(
+            captured_optional, [tensor.dtype], [tensor.shape])[0]
+    raise ValueError(
+        "Couldn't find OptionalFromValue consumer for tensor '%s'.\n"
+        "This is an internal bug, please report at "
+        "https://github.com/tensorflow/tensorflow/issues." % tensor.name)
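The capture logic above can be read as two small steps: locate the optional that wraps a forward intermediate, then unwrap it in the gradient graph. A distilled sketch with hypothetical helper names (the real code keeps both steps inside `_capture_helper`):

```python
# Hypothetical standalone helpers distilled from _CondGradFuncGraph above.
from tensorflow.python.ops import gen_dataset_ops

def find_wrapping_optional(tensor, forward_graph):
  """Returns the optional output of the forward graph that wraps `tensor`."""
  for consumer in tensor.consumers():
    if (consumer.type == "OptionalFromValue" and
        consumer.outputs[0] in forward_graph.outputs):
      return consumer.outputs[0]
  return None  # the real code treats this case as an internal bug

def unwrap_captured_optional(captured_optional, template):
  """Recovers the raw intermediate value from a captured optional."""
  return gen_dataset_ops.optional_get_value(
      captured_optional, [template.dtype], [template.shape])[0]
```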
@@ -38,6 +38,11 @@ def IsInXLAContext(op):
   return GetContainingXLAContext(ctxt) is not None


+def InXlaContext(graph):
+  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
+  return GetContainingXLAContext(ctxt) is not None
+
+
 def IsInWhileLoop(op):
   ctxt = op._get_control_flow_context()  # pylint: disable=protected-access
   return GetContainingWhileContext(ctxt) is not None
@@ -49,6 +49,7 @@ from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
tensorflow/python/ops/optional_grad.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradient functions for optional ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+@ops.RegisterGradient("OptionalFromValue")
+def _OptionalFromValueGrad(op, grad):
+  return gen_dataset_ops.optional_get_value(
+      grad, [t.dtype for t in op.inputs], [t.shape for t in op.inputs])
+
+
+@ops.RegisterGradient("OptionalGetValue")
+def _OptionalGetValueGrad(unused_op, *grads):
+  return gen_dataset_ops.optional_from_value(grads)
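With these two gradient functions registered, a gradient can flow through an optional round trip. A quick sanity check, not part of the commit, assuming TF 1.x graph mode; the explicit `optional_grad` import is only needed if it isn't already pulled in via `standard_ops`.

```python
# d/dx of 3 * optional_get_value(optional_from_value([x])) should be 3.
import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import

with tf.Graph().as_default(), tf.Session() as sess:
  x = tf.constant(2.0)
  opt = gen_dataset_ops.optional_from_value([x])
  y = 3.0 * gen_dataset_ops.optional_get_value(
      opt, [tf.float32], [tf.TensorShape([])])[0]
  dx, = tf.gradients(y, x)
  print(sess.run(dx))  # 3.0
```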