cond_v2: use optional tensors instead of FakeParams.

The purpose of this change is to avoid wasting memory on large FakeParam
allocations, which is especially important on GPU.

This also adds a few other fixes needed to get optional variants
working with cond_v2, including on GPU.

PiperOrigin-RevId: 223260005
Skye Wanderman-Milne 2018-11-28 16:49:14 -08:00 committed by TensorFlower Gardener
parent 22ff3ec66e
commit f6ee54c9b1
9 changed files with 180 additions and 37 deletions
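
For reference, here is a minimal sketch (not part of this commit) of the optional primitives the change builds on, using the internal gen_dataset_ops wrappers exactly as the patch below does; it assumes TF 1.x graph mode. An optional is a scalar DT_VARIANT handle that either holds a value or not, so the untaken branch can emit a cheap "none" instead of a full-size dummy tensor:

# Illustrative sketch only (not part of the patch).
import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops

value = tf.zeros([1000, 1000])                              # stand-in for a large intermediate
with_value = gen_dataset_ops.optional_from_value([value])   # variant holding the tensor
without_value = gen_dataset_ops.optional_none()             # variant holding nothing

# Unwrapping requires the dtype/shape of the wrapped components.
unwrapped = gen_dataset_ops.optional_get_value(
    with_value, [value.dtype], [value.shape])[0]

with tf.Session() as sess:
  print(sess.run(unwrapped).shape)  # (1000, 1000)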

View File

@@ -1002,7 +1002,7 @@ Status Placer::Run() {
int assigned_device = -1;
// Heuristic A application.
if (IsGeneratorNode(node)) {
if (IsGeneratorNode(node) && !node->out_edges().empty()) {
const Node* output = (*node->out_edges().begin())->dst();
int output_device_name = output->assigned_device_name_index();

View File

@@ -78,6 +78,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);
#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH

View File

@@ -1972,6 +1972,15 @@ py_library(
],
)
py_library(
name = "optional_grad",
srcs = ["ops/optional_grad.py"],
srcs_version = "PY2AND3",
deps = [
":framework_ops",
],
)
py_library(
name = "sets",
srcs = [
@@ -2151,6 +2160,7 @@ py_library(
":graph_to_function_def",
":pywrap_tensorflow",
":util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:function",
],
)
@@ -2295,6 +2305,7 @@ py_library(
":manip_ops",
":math_grad",
":math_ops",
":optional_grad",
":platform",
":random_grad",
":resource_variable_ops",

View File

@@ -126,7 +126,7 @@ class CondV2Test(test.TestCase):
self.assertEqual(sess.run(out, {pred: False}), (2.0,))
def _createCond(self, name):
"""Helper function for testDefaultName."""
"""Creates a cond_v2 call and returns the output tensor and the cond op."""
pred = constant_op.constant(True, name="pred")
x = constant_op.constant(1.0, name="x")
@@ -139,11 +139,11 @@ class CondV2Test(test.TestCase):
output = cond_v2.cond_v2(pred, true_fn, false_fn, name=name)
cond_op = output.op.inputs[0].op
self.assertEqual(cond_op.type, "If")
return cond_op
return output, cond_op
def testDefaultName(self):
with ops.Graph().as_default():
cond_op = self._createCond(None)
_, cond_op = self._createCond(None)
self.assertEqual(cond_op.name, "cond")
self.assertRegexpMatches(
cond_op.get_attr("then_branch").name, r"cond_true_\d*")
@@ -152,14 +152,14 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
with ops.name_scope("foo"):
cond1_op = self._createCond("")
_, cond1_op = self._createCond("")
self.assertEqual(cond1_op.name, "foo/cond")
self.assertRegexpMatches(
cond1_op.get_attr("then_branch").name, r"foo_cond_true_\d*")
self.assertRegexpMatches(
cond1_op.get_attr("else_branch").name, r"foo_cond_false_\d*")
cond2_op = self._createCond(None)
_, cond2_op = self._createCond(None)
self.assertEqual(cond2_op.name, "foo/cond_1")
self.assertRegexpMatches(
cond2_op.get_attr("then_branch").name, r"foo_cond_1_true_\d*")
@@ -612,11 +612,11 @@ class CondV2Test(test.TestCase):
def testLowering(self):
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
out_cond = self._createCond("cond")
cond_output, _ = self._createCond("cond")
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
sess.run(out_cond, options=run_options, run_metadata=run_metadata)
sess.run(cond_output, options=run_options, run_metadata=run_metadata)
# If lowering was enabled, there should be a `Switch` node
switch_found = any(
@@ -641,12 +641,12 @@ class CondV2Test(test.TestCase):
# Build the cond_v2 in an XLA context
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
out_cond = self._createCond("cond")
cond_output, _ = self._createCond("cond")
xla_context.Exit()
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
sess.run(out_cond, options=run_options, run_metadata=run_metadata)
sess.run(cond_output, options=run_options, run_metadata=run_metadata)
# Lowering disabled in XLA, there should be no `Switch` node
switch_found = any(

View File

@@ -446,8 +446,7 @@ class ControlFlowTest(test.TestCase):
g = gradients_impl.gradients(y, x)[0]
self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
# TODO(b/119791601): Enable this.
# self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
@test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
@@ -2168,11 +2167,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(512.0, self.evaluate(r))
def testNestedWhileCondWhileGrad(self):
if control_flow_ops.ENABLE_WHILE_V2 and test_util.is_gpu_available():
self.skipTest("b/118459209")
self._testNestedWhileCondWhileGrad(use_gpu=False)
@test_util.disable_control_flow_v2("b/118459209")
def testNestedWhileCondWhileGradGpu(self):
self._testNestedWhileCondWhileGrad(use_gpu=True)

View File

@@ -30,7 +30,9 @@ from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_impl
@@ -149,7 +151,9 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
true_inputs, false_inputs)
# Add all intermediate tensors as function outputs so they're available for
# the gradient computation.
# the gradient computation. Since the outputs of the two functions must match,
# we wrap all the intermediates in optionals. Each intermediate output will
# have a value iff its corresponding branch is taken.
true_intermediates = _get_intermediates(true_graph)
false_intermediates = _get_intermediates(false_graph)
@@ -157,12 +161,28 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
# Save the original number of outputs to return to the caller.
num_cond_outputs = len(true_graph.outputs)
# Make the number/type of new intermediate outputs match.
extra_true_outputs, extra_false_outputs = _pad_params(
true_graph, false_graph, true_intermediates, false_intermediates)
if control_flow_util.InXlaContext(ops.get_default_graph()):
# XLA does not yet support optionals, so output intermediates directly and
# make them match via FakeParams, which can be converted to zeros in XLA.
# TODO(skyewm,jpienaar): can XLA support optionals?
extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
true_graph, false_graph, true_intermediates, false_intermediates)
else:
# Wrap intermediates in optionals.
wrapped_true_intermediates = _wrap_intermediates(true_graph,
true_intermediates)
wrapped_false_intermediates = _wrap_intermediates(false_graph,
false_intermediates)
# Make outputs match by adding none optionals.
extra_true_outputs, extra_false_outputs = _make_intermediates_match(
true_graph, false_graph,
wrapped_true_intermediates, wrapped_false_intermediates)
true_graph.outputs.extend(extra_true_outputs)
false_graph.outputs.extend(extra_false_outputs)
# TODO(skyewm): somehow indicate it's a bug if this fails.
_check_same_outputs(true_graph, false_graph)
# Create the If op.
tensors = gen_functional_ops._if( # pylint: disable=protected-access
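
As a rough illustration of the layout produced above (not from the commit; cond_v2 is called directly and the If op is located via output.op.inputs[0].op, as in the tests in this change): outside XLA the If op carries the real cond outputs first, followed by variant-typed optional outputs for the branches' intermediates.

# Illustrative sketch only (not part of the patch). TF 1.x graph mode.
import tensorflow as tf
from tensorflow.python.ops import cond_v2

pred = tf.placeholder(tf.bool, [])
x = tf.constant(2.0)
out = cond_v2.cond_v2(pred,
                      lambda: tf.square(x) + 1.0,  # Square output is an intermediate
                      lambda: tf.sin(x))
if_op = out.op.inputs[0].op       # each cond output is an Identity of an If output
print(if_op.type)                 # "If"
# Output 0 is the real float32 result; any further outputs are the DT_VARIANT
# optionals added above for the branches' intermediates.
print([o.dtype for o in if_op.outputs])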
@@ -175,7 +195,8 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
name=name)
# TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
util.maybe_set_lowering_attr(tensors[0].op)
if_op = tensors[0].op
util.maybe_set_lowering_attr(if_op)
# Return identities for each output of the If op, rather than the output of
# the If op directly. This makes pruning work if the output of cond() is
@@ -187,6 +208,9 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
# correct output structure
tensors = [array_ops.identity(t) for t in tensors]
# Prevent fetching since the variant outputs can't be fetched directly.
if_op.graph.prevent_fetching(if_op)
return tensors[:num_cond_outputs]
@@ -278,7 +302,7 @@ def _create_grad_func(func_graph, grads, name):
return func_graph_module.func_graph_from_py_func(
name,
lambda: _grad_fn(func_graph, grads), [], {},
func_graph=util.CondBranchFuncGraph(name, read_only_collections=False))
func_graph=_CondGradFuncGraph(name, func_graph))
def _resolve_grad_inputs(cond_graph, grad_graph):
@@ -360,28 +384,39 @@ def _separate_unique_inputs(true_inputs, false_inputs):
return list(shared_inputs), list(true_only_inputs), list(false_only_inputs)
def _pad_params(true_graph, false_graph, true_params, false_params):
"""Returns new param lists that have matching signatures.
def _make_intermediates_match(true_graph, false_graph,
true_optionals, false_optionals):
"""Returns new optionals lists that have matching signatures.
This is done by mirroring each param list in the other using dummy params.
There is no merging of params.
This is done by mirroring each list in the other using none optionals.
There is no merging of like optionals.
Args:
true_graph: FuncGraph
false_graph: FuncGraph
true_params: a list of Tensors from true_graph
false_params: a list of Tensors from false_graph
true_optionals: a list of optional Tensors from true_graph
false_optionals: a list of optional Tensors from false_graph
Returns:
A new list of Tensors in true_graph and a new list of Tensors in
false_graph. The two lists have the same number of Tensors, with matching
types and shapes across the lists.
false_graph. The two lists have the same number of Tensors, all of which
will be optionals of the same shape/type.
"""
new_true_params = (true_params +
_create_dummy_params(true_graph, false_params))
new_false_inputs = (_create_dummy_params(false_graph, true_params)
+ false_params)
return new_true_params, new_false_inputs
new_true_optionals = (true_optionals +
_create_none_optionals(true_graph, false_optionals))
new_false_optionals = (_create_none_optionals(false_graph, true_optionals)
+ false_optionals)
return new_true_optionals, new_false_optionals
def _make_intermediates_match_xla(true_graph, false_graph, true_intermediates,
false_intermediates):
"""Like _make_intermediates_match but for the XLA case."""
new_true_intermediates = (true_intermediates +
_create_fakeparams(true_graph, false_intermediates))
new_false_intermediates = (_create_fakeparams(false_graph, true_intermediates)
+ false_intermediates)
return new_true_intermediates, new_false_intermediates
def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
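
A toy, framework-free sketch of the mirroring performed by _make_intermediates_match (plain strings stand in for the wrapped optionals; the real helper builds its OptionalNone ops inside each branch's FuncGraph):

# Illustrative sketch only (not part of the patch).
def mirror(true_opts, false_opts, none="<OptionalNone>"):
  new_true = true_opts + [none] * len(false_opts)
  new_false = [none] * len(true_opts) + false_opts
  return new_true, new_false

print(mirror(["t0", "t1"], ["f0"]))
# (['t0', 't1', '<OptionalNone>'], ['<OptionalNone>', '<OptionalNone>', 'f0'])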
@@ -416,11 +451,11 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
true_graph.inputs = (
[true_input_to_param[t] for t in shared_inputs] +
[true_input_to_param[t] for t in true_only_inputs] +
_create_dummy_params(true_graph, false_only_inputs))
_create_dummy_inputs(true_graph, false_only_inputs))
false_graph.inputs = (
[false_input_to_param[t] for t in shared_inputs] +
_create_dummy_params(false_graph, true_only_inputs) +
_create_dummy_inputs(false_graph, true_only_inputs) +
[false_input_to_param[t] for t in false_only_inputs])
# Rewrite the FuncGraphs' state to reflect the new inputs.
@@ -432,7 +467,12 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
return new_inputs
def _create_dummy_params(func_graph, template_tensors):
def _wrap_intermediates(func_graph, intermediates):
with func_graph.as_default():
return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
def _create_dummy_inputs(func_graph, template_tensors):
"""Creates tensors in func_graph to represent template_tensors.
Args:
@@ -442,6 +482,27 @@ def _create_dummy_params(func_graph, template_tensors):
Returns:
A list of tensors in func_graph.
"""
with func_graph.as_default():
return [array_ops.placeholder(t.dtype, shape=t.shape)
for t in template_tensors]
def _create_none_optionals(func_graph, template_tensors):
"""Creates none optionals in func_graph to represent template_tensors.
Args:
func_graph: FuncGraph.
template_tensors: a list of tensors in func_graph.
Returns:
A list of tensors in func_graph.
"""
with func_graph.as_default():
return [gen_dataset_ops.optional_none() for _ in template_tensors]
def _create_fakeparams(func_graph, template_tensors):
"""Create FakeParams for the XLA case."""
with func_graph.as_default():
return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
for t in template_tensors]
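
A quick, hedged contrast of the two placeholder styles created by the helpers above (the ops are called as in this file; the shape is made up): outside XLA a missing intermediate costs only a scalar variant handle, while the XLA path still uses a FakeParam of the full dtype/shape, which XLA can convert to zeros.

# Illustrative sketch only (not part of the patch). TF 1.x graph construction.
import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops

cheap = gen_dataset_ops.optional_none()                     # scalar DT_VARIANT handle
bulky = gen_functional_ops.fake_param(dtype=tf.float32,
                                      shape=[1000, 1000])   # full-size dummy
print(cheap.dtype, bulky.shape)   # expected: variant, (1000, 1000)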
@@ -474,3 +535,38 @@ def _get_output_shapes(true_graph_outputs, false_graph_outputs):
for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
]
return output_shapes
class _CondGradFuncGraph(util.CondBranchFuncGraph):
"""FuncGraph for the gradient function of the branch of an If op.
Handles unwrapping optional intermediate values that are captured by the
gradient computation.
"""
def __init__(self, name, forward_graph):
super(_CondGradFuncGraph, self).__init__(name, read_only_collections=False)
self._forward_graph = forward_graph
def _capture_helper(self, tensor, name):
if (tensor.graph is not self._forward_graph or
tensor in self._forward_graph.inputs or
tensor in self._forward_graph.outputs):
return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
# 'tensor' is an intermediate in the forward graph. We find the corresponding
# optional tensor, which is output from the If op, and capture it as
# normal. We then unwrap the captured optional value to get the raw
# intermediate value.
for consumer in tensor.consumers():
if (consumer.type == "OptionalFromValue"
and consumer.outputs[0] in self._forward_graph.outputs):
optional = consumer.outputs[0]
captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
optional, name)
return gen_dataset_ops.optional_get_value(
captured_optional, [tensor.dtype], [tensor.shape])[0]
raise ValueError(
"Couldn't find OptionalFromValue consumer for tensor '%s'.\n"
"This is an internal bug, please report at "
"https://github.com/tensorflow/tensorflow/issues." % tensor.name)
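
To see this capture path end to end, here is a hedged sketch (not from the commit; TF 1.x graph mode, cond_v2 called directly as in the tests): the gradient of t * t needs the intermediate t, which the class above retrieves by capturing the corresponding optional output of the forward If op and unwrapping it.

# Illustrative sketch only (not part of the patch).
import tensorflow as tf
from tensorflow.python.ops import cond_v2

pred = tf.placeholder(tf.bool, [])
x = tf.constant(3.0)

def true_fn():
  t = tf.square(x)   # intermediate: the gradient of t * t needs t itself
  return t * t

out = cond_v2.cond_v2(pred, true_fn, lambda: x)
grad = tf.gradients(out, x)[0]

with tf.Session() as sess:
  print(sess.run(grad, {pred: True}))   # 4 * x**3 = 108.0
  print(sess.run(grad, {pred: False}))  # 1.0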

View File

@@ -38,6 +38,11 @@ def IsInXLAContext(op):
return GetContainingXLAContext(ctxt) is not None
def InXlaContext(graph):
ctxt = graph._get_control_flow_context() # pylint: disable=protected-access
return GetContainingXLAContext(ctxt) is not None
def IsInWhileLoop(op):
ctxt = op._get_control_flow_context() # pylint: disable=protected-access
return GetContainingWhileContext(ctxt) is not None

View File

@@ -49,6 +49,7 @@ from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import optional_grad # pylint: disable=unused-import
from tensorflow.python.ops import random_grad # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops

View File

@@ -0,0 +1,33 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient functions for optional ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
@ops.RegisterGradient("OptionalFromValue")
def _OptionalFromValueGrad(op, grad):
return gen_dataset_ops.optional_get_value(
grad, [t.dtype for t in op.inputs], [t.shape for t in op.inputs])
@ops.RegisterGradient("OptionalGetValue")
def _OptionalGetValueGrad(unused_op, *grads):
return gen_dataset_ops.optional_from_value(grads)
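
A minimal sketch (not part of the commit) of how the two registrations compose: wrapping a tensor in an optional and unwrapping it again is differentiable end to end, with the intermediate gradient itself flowing through an optional.

# Illustrative sketch only (not part of the patch). TF 1.x graph mode; the
# registrations above are pulled in automatically via gradients_impl.
import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops

x = tf.constant([1.0, 2.0])
opt = gen_dataset_ops.optional_from_value([x])
y = gen_dataset_ops.optional_get_value(opt, [x.dtype], [x.shape])[0]
loss = tf.reduce_sum(y * y)

# Backprop re-wraps dloss/dy in an optional at OptionalGetValue and unwraps it
# again at OptionalFromValue.
grad = tf.gradients(loss, x)[0]
with tf.Session() as sess:
  print(sess.run(grad))  # [2. 4.]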