Adds tf.switch_case, an n-way switch/case statement, to TF Python. (C++ support was added previously, in the "Case" op + kernel and in XLA.)

Adds XLA support for Case with TensorArrays.

Propagates compile-time constants into the branch functions of Case in XLA.
This enables using ops like StridedSlice, which require compile-time constant inputs, inside the branch functions.

PiperOrigin-RevId: 248466828
Author: Brian Patton
Date: 2019-05-15 22:10:50 -07:00
Committed by: TensorFlower Gardener
Parent: 7bb7116c77
Commit: e224546ee5
14 changed files with 1112 additions and 256 deletions
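
For orientation, a minimal usage sketch of the new Python API (illustrative only, not part of this diff; assumes TF 1.x-style graph mode with a Session, matching the tests below):

# Sketch: n-way dispatch with tf.switch_case. `branch_fns` may be a list of
# callables or a dict keyed by branch index; exactly one branch is executed.
import tensorflow as tf

def branch0():
  return tf.constant(5.)

def branch1():
  return tf.constant(10.)

output = tf.switch_case(tf.constant(1), branch_fns={0: branch0, 1: branch1})

with tf.Session() as sess:
  print(sess.run(output))  # 10.0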


@ -19,12 +19,16 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -33,6 +37,7 @@ from tensorflow.python.platform import test
class CondTest(xla_test.XLATestCase):
def testCondAndTensorArrayInDefun(self):
# TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@ -47,7 +52,7 @@ class CondTest(xla_test.XLATestCase):
return output.stack()
output_t = f()
-self.assertAllEqual(self.evaluate(output_t), [5.])
+self.assertAllEqual([5.], self.evaluate(output_t))
xla_context.Exit()
@ -71,11 +76,178 @@ class CondTest(xla_test.XLATestCase):
output = control_flow_ops.cond(
constant_op.constant(True), if_true, if_false)
-self.assertAllEqual(
+self.assertAllEqual(1.,
sess.run(output, feed_dict={
x: [0., 1., 2.],
p: 1
-}), 1.)
+}))
xla_context.Exit()
def testCondConstPropagation_xlaCompile(self):
self.skipTest("b/132430685")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3])
p = constant_op.constant(1)
def f():
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[p]
def if_false():
return 5.
return control_flow_ops.cond(
constant_op.constant(True), if_true, if_false)
output = xla.compile(f)
self.assertAllEqual(1., self.evaluate(output))
xla_context.Exit()
def testCondConstPropagation_errorMsg(self):
self.skipTest("b/132430685")
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[:p]
def if_false():
return array_ops.fill([p], 5.)
output = control_flow_ops.cond(
constant_op.constant(True), if_true, if_false)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must be a compile-time constant"):
sess.run(
output, feed_dict={
x: [0., 1., 2.],
})
xla_context.Exit()
def testCondConstPropagation_errorMsg_xlaCompile(self):
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
condition = math_ops.cast(
random_ops.random_uniform([], minval=0, maxval=2, dtype=dtypes.int32),
dtypes.bool)
def f():
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def if_true():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[:p]
def if_false():
return array_ops.fill([p], 5.)
return control_flow_ops.cond(condition, if_true, if_false)
output = xla.compile(f)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must be a compile-time constant"):
sess.run(
output, feed_dict={
x: [0., 1., 2.],
})
xla_context.Exit()
def testSwitchCaseAndTensorArrayInDefun(self):
self.skipTest("b/127846988")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@function.defun
def f():
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
output = control_flow_ops.switch_case(
constant_op.constant(1), {
0: lambda: ta.write(0, 5.),
1: lambda: ta.write(0, 10.),
2: lambda: ta.write(0, 15.),
})
return output.stack()
output_t = f()
self.assertAllEqual([10.], self.evaluate(output_t))
xla_context.Exit()
def testSwitchCaseConstPropagation(self):
self.skipTest("b/127846988")
with self.session() as sess, self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
x = array_ops.placeholder(dtypes.float32)
p = array_ops.placeholder(dtypes.int32)
def branch0():
return 5.
def branch1():
return 15.
# TODO(b/129021699): Wrapping this in a tf.function does not work.
def branch2():
# This emits a StridedSlice op which expects the index to be a
# compile-time const.
return x[p]
output = control_flow_ops.switch_case(
constant_op.constant(2), {
0: branch0,
1: branch1,
2: branch2,
})
self.assertAllEqual(7.,
sess.run(output, feed_dict={
x: [0., 1., 7.],
p: 2,
}))
xla_context.Exit()
def testCondNoInputs(self):
"""Verifies against `Failed precondition: Expected one input shape`."""
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
for pred in True, False:
cond_out = control_flow_ops.cond(
array_ops.placeholder_with_default(pred, []),
lambda: constant_op.constant(2.),
lambda: constant_op.constant(1.))
self.assertEqual(int(pred) + 1., self.evaluate(cond_out))
xla_context.Exit()
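
Illustrative sketch (not part of this diff) of taking gradients through the new op; it relies on the "Case" gradient registered in cond_v2.py later in this commit and assumes TF 1.x graph mode:

# Sketch: the gradient flows through the branch selected by branch_index.
import tensorflow as tf

x = tf.constant(3.)
y = tf.switch_case(tf.constant(1),
                   branch_fns=[lambda: 2. * x, lambda: x * x, lambda: x + 7.])
dy_dx, = tf.gradients(y, x)

with tf.Session() as sess:
  print(sess.run([y, dy_dx]))  # [9.0, 6.0]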


@ -125,6 +125,8 @@ Status BackwardsConstAnalysis(const Graph& g,
return status;
}
namespace {
Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node,
StringPiece func_attr_name, const FunctionBody** fbody) {
NameAttrList name_attr_list;
@ -136,6 +138,50 @@ Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node,
return Status::OK();
}
Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, const Node* node,
StringPiece func_list_attr_name,
std::vector<const FunctionBody*>* fbodies) {
std::vector<NameAttrList> name_attr_lists;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->def(), func_list_attr_name, &name_attr_lists));
for (const NameAttrList& name_attr_list : name_attr_lists) {
FunctionLibraryRuntime::Handle func_handle;
TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
&func_handle));
fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
}
return Status::OK();
}
Status CondConstInputIndices(
absl::Span<const FunctionBody* const> branch_bodies,
std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
TF_RET_CHECK(!branch_bodies.empty());
TF_RET_CHECK(branch_bodies[0] != nullptr);
int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
// Stores indices of the "branch function" inputs that are expected to be
// compile time constants.
std::vector<bool> compile_time_const_arg_indices(num_inputs);
for (auto fbody : branch_bodies) {
TF_RET_CHECK(fbody != nullptr);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*(fbody->graph), &compile_time_const_arg_indices,
/*compile_time_const_nodes=*/nullptr, flib_runtime));
}
for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
if (compile_time_const_arg_indices[i]) {
// The 0th input is the pred or branch index, which is not passed to the
// branches. So the i'th input of a branch function corresponds to the
// i + 1'th input of the If/Case op.
const_input_idxs->push_back(i + 1);
}
}
return Status::OK();
}
} // namespace
Status GetCompileTimeConstInputs(const Node* node,
std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime) {
@ -179,6 +225,7 @@ Status GetCompileTimeConstInputs(const Node* node,
}
}
}
return Status::OK();
} else if (node->type_string() == "If") {
const FunctionBody* fthen = nullptr;
const FunctionBody* felse = nullptr;
@ -186,31 +233,17 @@ Status GetCompileTimeConstInputs(const Node* node,
GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
TF_RETURN_IF_ERROR(
GetFunctionBody(flib_runtime, node, "else_branch", &felse));
-TF_RET_CHECK(fthen);
-TF_RET_CHECK(felse);
-int num_inputs = fthen->fdef.signature().input_arg_size();
-// Stores indices of the "branch function" inputs that are expected to be
-// compile time constants.
-std::vector<bool> compile_time_const_arg_indices(num_inputs);
-TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-    *(fthen->graph), &compile_time_const_arg_indices,
-    /*compile_time_const_nodes=*/nullptr, flib_runtime));
-TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-    *(felse->graph), &compile_time_const_arg_indices,
-    /*compile_time_const_nodes=*/nullptr, flib_runtime));
-for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
-  if (compile_time_const_arg_indices[i]) {
-    // The 0th input is the loop condition which is not passed to the
-    // branches. So the i'th input of a branch function corresponds to
-    // i + 1'th input of the If op.
-    const_input_idxs->push_back(i + 1);
-  }
-}
+return CondConstInputIndices({fthen, felse}, const_input_idxs,
+                             flib_runtime);
+} else if (node->type_string() == "Case") {
+std::vector<const FunctionBody*> branch_bodies;
+TF_RETURN_IF_ERROR(
+    GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
+return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
} else {
return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(),
const_input_idxs);
}
-return Status::OK();
}
} // namespace tensorflow
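
To make the index bookkeeping in CondConstInputIndices above concrete, a small illustrative Python sketch (the helper name is made up; this is not TF code): if a branch function's argument i must be a compile-time constant, the corresponding If/Case op input index is i + 1, because op input 0 is the predicate or branch index and is not passed to the branches.

# Sketch of the i -> i + 1 shift described in the comments above.
# `needs_const[i]` stands in for compile_time_const_arg_indices[i].
def const_input_indices(needs_const):
  # Op input 0 is the pred (If) or branch_index (Case), so branch argument i
  # corresponds to op input i + 1.
  return [i + 1 for i, needed in enumerate(needs_const) if needed]

print(const_input_indices([False, True, True]))  # prints [2, 3]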


@ -311,6 +311,7 @@ tf_kernel_library(
srcs = ["case_op.cc"], srcs = ["case_op.cc"],
hdrs = ["case_op.h"], hdrs = ["case_op.h"],
deps = [ deps = [
":if_while_utils",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/case_op.h" #include "tensorflow/compiler/tf2xla/kernels/case_op.h"
#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
@ -34,10 +35,41 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
} else { } else {
has_token_input_output_ = !token_input_nodes_.empty(); has_token_input_output_ = !token_input_nodes_.empty();
} }
if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
&propagate_compile_time_consts_));
}
} }
namespace {
Status ConvertCompileTimeConstArgumentsToConst(
XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args) {
for (int i = 0; i < args->size(); i++) {
XlaCompiler::Argument& arg = (*args)[i];
const XlaExpression& expression = ctx->InputExpression(i + 1);
// If the input tensor is a compile time constant build a kConstant type
// argument.
if (arg.kind == XlaCompiler::Argument::kParameter) {
// NOTE: We can not simply check that this is Kind::kConstant because
// this could be the output of a MetadataOnly op e.g. Size.
xla::StatusOr<absl::optional<Tensor>> maybe_constant =
expression.ResolveConstant(ctx->compiler()->client());
if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) {
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = expression.dtype();
arg.constant_value = std::move(maybe_constant.ValueOrDie().value());
arg.shape = expression.GetShape().ValueOrDie();
}
}
}
return Status::OK();
}
} // namespace
// TODO(b/35949885): There is duplication here with the handling of the
-// while_op. Refactor the common code out/rework.
+// while_op/if_op. Refactor the common code out/rework.
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaBuilder* b = ctx->builder();
int num_branches = branches_.size();
@ -84,12 +116,30 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_types_[i];
-arg.shape = ctx->InputShape(i + 1);
+// Use the xla::Shape for the input instead of ctx->InputShape. This is
+// necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists.
+auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1));
+OP_REQUIRES_OK(ctx, shape_or.status());
+arg.shape = shape_or.ValueOrDie();
VLOG(2) << "Arg type: " << DataTypeString(arg.type)
<< " shape: " << arg.HumanString();
}
}
if (propagate_compile_time_consts_) {
// Replaces `kParameter` type args in `arguments` with `kConstant` if
// the op input corresponding to that arg is a compile-time const. This
// is necessary to propagate compile time consts to ops in the branch
// functions.
// Note: Propagating "all" compile-time constants may not be necessary. We
// should ideally only propagate consts which are required to be compile
// time constants in the branch functions. But that would require calling
// BackwardsConstAnalysis here which would be expensive. However, if we
// start hitting memory issues we should revisit this.
OP_REQUIRES_OK(ctx,
ConvertCompileTimeConstArgumentsToConst(ctx, &arguments));
}
// Compile each branch of the conditional.
XlaCompiler::CompileOptions options;
options.use_tuple_arg = true;
@ -158,8 +208,6 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
}
OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
errors::FailedPrecondition("Expected tuple shape"));
-OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
-            errors::FailedPrecondition("Expected one input shape"));
OP_REQUIRES(
ctx,
xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
@ -227,7 +275,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
} else {
-inputs[i] = ctx->Input(i + 1);
+inputs[i] = ctx->Input(input_num);
}
}
auto input_tuple = xla::Tuple(b, inputs);
@ -292,6 +340,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Done building Case"; VLOG(1) << "Done building Case";
} }
REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp); REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
XlaCaseOp);
} // namespace tensorflow } // namespace tensorflow


@ -55,6 +55,10 @@ class XlaCaseOp : public XlaOpKernel {
DataTypeVector output_types_;
bool has_token_input_output_;
std::vector<string> token_input_nodes_;
// Whether to propagate compile time consts into the cond branches.
// This is not supported by default now since it may cause HBM memory
// overheads.
bool propagate_compile_time_consts_ = false;
};
} // namespace tensorflow


@ -776,7 +776,7 @@ Status XlaCompiler::BuildArguments(
}
}
-if (input_to_args->empty()) {
+if (input_to_args->empty() && !use_tuple_arg) {
return Status::OK();
}
@ -829,8 +829,9 @@ Status XlaCompiler::BuildArguments(
xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
args[input_to_args->at(i)].is_same_data_across_replicas);
}
-xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
-                                                       tuple_sharding);
+xla::XlaScopedShardingAssignment assign_tuple_sharding(
+    builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
+                                    : tuple_sharding);
tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
is_same_across_replicas);
} else {


@ -3747,6 +3747,7 @@ cuda_py_test(
":while_v2", ":while_v2",
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:def_function",
], ],
shard_count = 2,
xla_enable_strict_auto_jit = True, xla_enable_strict_auto_jit = True,
) )


@ -2593,8 +2593,15 @@ class Operation(object):
func = attr_value_pb2.NameAttrList(name=func_name)
self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
def _set_func_list_attr(self, attr_name, func_names):
"""Private method used to set a list(function) attribute in the node_def."""
funcs = [attr_value_pb2.NameAttrList(name=func_name)
for func_name in func_names]
funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
def _set_type_list_attr(self, attr_name, types):
-"""Private method used to set a function attribute in the node_def."""
+"""Private method used to set a list(type) attribute in the node_def."""
if not types:
return
if isinstance(types[0], dtypes.DType):
@ -2603,7 +2610,7 @@ class Operation(object):
self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
def _set_shape_list_attr(self, attr_name, shapes):
-"""Private method used to set a function attribute in the node_def."""
+"""Private method used to set a list(shape) attribute in the node_def."""
shapes = [s.as_proto() for s in shapes]
shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))


@ -776,6 +776,33 @@ class ControlFlowTest(test.TestCase):
"Tensor true_branch:0 in true_fn is accessed from false_fn."): "Tensor true_branch:0 in true_fn is accessed from false_fn."):
f() f()
def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
@def_function.function
def f():
c = constant_op.constant(1.)
inputs = {"c": c}
def br1_fn(inputs):
inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity")
return inputs["c"]
def br4_fn(inputs):
return array_ops.identity(inputs["c"])
def other_fn():
return array_ops.identity(c)
return control_flow_ops.switch_case(
constant_op.constant(2),
[other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
lambda: br4_fn(inputs)])
with self.assertRaisesRegexp(
ValueError,
"Tensor br1_identity:0 in branch 1 is accessed from branch 4."):
f()
def testCondListOutput(self):
with self.cached_session() as sess:
x = constant_op.constant(10)


@ -47,6 +47,9 @@ from tensorflow.python.util import nest
# readability.
# pylint: disable=protected-access
_COND = 1
_CASE = 2
def cond_v2(pred, true_fn, false_fn, name="cond"):
"""Like tf.cond, except emits a single If op."""
@ -80,7 +83,7 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
add_control_dependencies=add_control_dependencies,
op_return_value=pred)
-verify_captures(true_graph, false_graph)
+verify_captures(_COND, [true_graph, false_graph])
return _build_cond(pred, true_graph, false_graph,
true_graph.external_captures,
false_graph.external_captures,
@ -106,8 +109,7 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
false_grad_graph = _create_grad_func(
false_graph, grads, util.unique_grad_fn_name(false_graph.name))
-if (true_grad_graph.if_op_needs_rewrite or
-    false_grad_graph.if_op_needs_rewrite):
+if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
# Modify 'op' to output the intermediates needed by the grad functions. Note
# that all needed intermediates are wrapped in optionals. Each optional
# intermediate output will have a value iff its corresponding branch is
@ -122,18 +124,18 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
true_intermediates = true_grad_graph.xla_intermediates
false_intermediates = false_grad_graph.xla_intermediates
extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
-true_graph, false_graph, true_intermediates, false_intermediates)
+[true_graph, false_graph], [true_intermediates, false_intermediates])
else:
true_intermediates = true_grad_graph.wrapped_intermediates
false_intermediates = false_grad_graph.wrapped_intermediates
# Make outputs match by adding none optionals.
extra_true_outputs, extra_false_outputs = _make_intermediates_match(
-true_graph, false_graph, true_intermediates, false_intermediates)
+[true_graph, false_graph], [true_intermediates, false_intermediates])
true_graph.outputs.extend(extra_true_outputs)
false_graph.outputs.extend(extra_false_outputs)
# TODO(skyewm): indicate it's an internal bug if this fails.
-_check_same_outputs(true_graph, false_graph)
+_check_same_outputs(_COND, [true_graph, false_graph])
true_graph.name += "_rewritten"
false_graph.name += "_rewritten"
@ -153,7 +155,8 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
# This modifies true_grad_graph and false_grad_graph.
-_make_output_composite_tensors_match(true_grad_graph, false_grad_graph)
+_make_output_composite_tensors_match(_COND,
+                                     [true_grad_graph, false_grad_graph])
outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
true_grad_inputs, false_grad_inputs)
@ -185,13 +188,13 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
A list of Tensors which are the outputs of the If op. Does not include added
intermediate outputs.
"""
-_make_indexed_slices_indices_types_match(true_graph, false_graph)
-_check_same_outputs(true_graph, false_graph)
+_make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
+_check_same_outputs(_COND, [true_graph, false_graph])
# Add inputs to true_graph and false_graph to make them match. Note that
# this modifies true_graph and false_graph.
-cond_inputs = _make_inputs_match(true_graph, false_graph,
-                                 true_inputs, false_inputs)
+cond_inputs = _make_inputs_match([true_graph, false_graph],
+                                 [true_inputs, false_inputs])
# Create the If op.
with ops.control_dependencies(
@ -226,38 +229,45 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
tensors)
def _get_func_graphs(if_op): def _get_func_graphs(op):
"""Returns `FuncGraph`s for the input op branches. """Returns `FuncGraph`s for the input op branches.
Args: Args:
if_op: The _If Operation. op: The If or Case Operation.
Returns: Returns:
A 2-tuple of the `FuncGraph`s of the then_branch and else_branch. A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
for Case).
""" """
def _get_func_graph_for_branch(branch_name):
def _get_func_graph_for_branch(name_attr_list):
"""Generates and returns a FuncGraph for the given branch.""" """Generates and returns a FuncGraph for the given branch."""
inputs = if_op.inputs[1:] # First input is pred. inputs = op.inputs[1:] # First input is pred.
input_shapes = [t.shape for t in inputs] input_shapes = [t.shape for t in inputs]
func_name = if_op.get_attr(branch_name).name fdef = op.graph._get_function(name_attr_list.name).definition
fdef = if_op.graph._get_function(func_name).definition # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
# `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
# in the case of nested if ops or when the gradient is being computed # in the case of nested if ops or when the gradient is being computed
# from inside a Defun. We build the `func_graph` with `if_op.graph` as its # from inside a Defun. We build the `func_graph` with `op.graph` as its
# `outer_graph`. This resembles how the `FuncGraph` was built in the # `outer_graph`. This resembles how the `FuncGraph` was built in the
# forward pass. We need this so that we can resolve references to tensors # forward pass. We need this so that we can resolve references to tensors
# in `func_graph` from its gradient graph in `_resolve_grad_inputs`. # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
with if_op.graph.as_default(): with op.graph.as_default():
func_graph = function_def_to_graph.function_def_to_graph( func_graph = function_def_to_graph.function_def_to_graph(
fdef, input_shapes) fdef, input_shapes)
func_graph.captures = collections.OrderedDict(zip(inputs, func_graph.captures = collections.OrderedDict(zip(inputs,
func_graph.inputs)) func_graph.inputs))
# Set the if op so that the gradient code can use it. # Link the op so that the gradient code can use it.
func_graph._if = if_op func_graph._forward_cond = op
return func_graph return func_graph
return (_get_func_graph_for_branch("then_branch"), if op.type == "If":
_get_func_graph_for_branch("else_branch")) return (_get_func_graph_for_branch(op.get_attr("then_branch")),
_get_func_graph_for_branch(op.get_attr("else_branch")))
elif op.type == "Case":
return [_get_func_graph_for_branch(branch_fn)
for branch_fn in op.get_attr("branches")]
else:
raise ValueError("Unsupported op type: {}".format(op.type))
def _grad_fn(func_graph, grads): def _grad_fn(func_graph, grads):
@ -348,7 +358,7 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
# to If op outputs. So we get the outer tensor corresponding to those # to If op outputs. So we get the outer tensor corresponding to those
# from the list of `external_captures`. # from the list of `external_captures`.
try: try:
t = t.graph._if.outputs[t.graph.outputs.index(t)] t = t.graph._forward_cond.outputs[t.graph.outputs.index(t)]
except ValueError: except ValueError:
index = t.graph.internal_captures.index(t) index = t.graph.internal_captures.index(t)
t = t.graph.external_captures[index] t = t.graph.external_captures[index]
@ -373,211 +383,191 @@ def _get_intermediates(func_graph):
return intermediates return intermediates
def _separate_unique_inputs(true_inputs, false_inputs): def _make_intermediates_match(branch_graphs, branch_optionals):
"""Separates tensors appearing only in true_inputs or false_inputs, or both.
Args:
true_inputs: list of Tensors
false_inputs: list of Tensors
Returns:
Three lists of Tensors:
1. The tensors that appear in both true_inputs and false_inputs
2. The tensors that only appear in true_inputs
3. The tensors that only appear in false_inputs
"""
true_inputs = set(true_inputs)
false_inputs = set(false_inputs)
shared_inputs = true_inputs.intersection(false_inputs)
true_only_inputs = true_inputs - false_inputs
false_only_inputs = false_inputs - true_inputs
return list(shared_inputs), list(true_only_inputs), list(false_only_inputs)
def _make_intermediates_match(true_graph, false_graph,
true_optionals, false_optionals):
"""Returns new optionals lists that have matching signatures. """Returns new optionals lists that have matching signatures.
This is done by mirroring each list in the other using none optionals. This is done by mirroring each list in the other using none optionals.
There is no merging of like optionals. There is no merging of like optionals.
Args: Args:
true_graph: FuncGraph branch_graphs: `list` of `FuncGraph`.
false_graph: FuncGraph branch_optionals: `list` of `list`s of optional `Tensor`s from other
true_optionals: a list of optional Tensors from true_graph branch_graphs
false_optionals: a list of optional Tensors from false_graph
Returns: Returns:
A new list of Tensors in true_graph and a new list of Tensors in A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
false_graph. The two lists have the same number of Tensors, all of which same number of `Tensor`s, all of which will be optionals of the same
will be optionals of the same shape/type. shape/type.
""" """
new_true_optionals = (true_optionals + new_branch_optionals = []
_create_none_optionals(true_graph, false_optionals)) # Since the intermediates are optionals with dtype variant, we only need
new_false_optionals = (_create_none_optionals(false_graph, true_optionals) # enough room for the longest list of intermediates.
+ false_optionals) intermediates_size = max(len(o) for o in branch_optionals)
return new_true_optionals, new_false_optionals for i, branch_graph in enumerate(branch_graphs):
other_optionals = _create_none_optionals(
branch_graph, intermediates_size - len(branch_optionals[i]))
new_branch_optionals.append(branch_optionals[i] + other_optionals)
return new_branch_optionals
def _make_intermediates_match_xla(true_graph, false_graph, true_intermediates, def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
false_intermediates):
"""Like _make_intermediates_match but for the XLA case.""" """Like _make_intermediates_match but for the XLA case."""
new_true_intermediates = (true_intermediates + new_branch_intermediates = []
_create_fakeparams(true_graph, false_intermediates)) for i, branch_graph in enumerate(branch_graphs):
new_false_intermediates = (_create_fakeparams(false_graph, true_intermediates) other_fakeparams = _create_fakeparams(
+ false_intermediates) branch_graph,
return new_true_intermediates, new_false_intermediates sum((bi for bi in branch_intermediates
if bi is not branch_intermediates[i]), []))
num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
new_branch_intermediates.append(other_fakeparams[:num_preceding] +
branch_intermediates[i] +
other_fakeparams[num_preceding:])
return new_branch_intermediates
def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): def _make_inputs_match(branch_graphs, branch_inputs):
"""Modifies true_graph and false_graph so they have the same input signature. """Modifies branch_graphs so they have the same input signature.
This method reorders and/or adds parameters to true_graph and false_graph so This method reorders and/or adds parameters to each graph in branch_graphs so
they have the same input signature, and updates the 'inputs' and 'captured' they have the same input signature, and updates the 'inputs' and 'captured'
fields of both graphs accordingly. It uses the input tensors from the outer fields of each graph accordingly. It uses the input tensors from the outer
graph to avoid duplicating shared arguments. graph to avoid duplicating shared arguments.
Args: Args:
true_graph: FuncGraph branch_graphs: a `list` of `FuncGraph`
false_graph: FuncGraph branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
true_inputs: a list of Tensors in the outer graph. The inputs for inputs for the corresponding graph in `branch_graphs`.
true_graph.
false_inputs: a list of Tensors in the outer graph. The inputs for
false_graph.
Returns: Returns:
A new list of Tensors from the outer graph that are the new inputs for both A new list of Tensors from the outer graph that are the new inputs for each
true_graph and false_graph. This is a deduped version of true_inputs + branch_graph. This is a deduped version of `sum(branch_inputs)`.
false_inputs.
""" """
shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( assert len(branch_graphs) == len(branch_inputs)
true_inputs, false_inputs) new_inputs = set()
for branch_in in branch_inputs:
new_inputs |= set(branch_in)
new_inputs = list(new_inputs)
new_inputs = shared_inputs + true_only_inputs + false_only_inputs for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
branch_input_to_param = dict(zip(branch_in, branch_graph.inputs))
input_list = []
for in_t in new_inputs:
param = branch_input_to_param.get(in_t, None)
if param is None:
param = _create_dummy_input(branch_graph, in_t)
input_list.append(param)
true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) branch_graph.inputs = input_list
false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
true_graph.inputs = ( # Rewrite the FuncGraphs' state to reflect the new inputs.
[true_input_to_param[t] for t in shared_inputs] + branch_graph.captures = collections.OrderedDict(
[true_input_to_param[t] for t in true_only_inputs] + zip(new_inputs, branch_graph.inputs))
_create_dummy_inputs(true_graph, false_only_inputs))
false_graph.inputs = (
[false_input_to_param[t] for t in shared_inputs] +
_create_dummy_inputs(false_graph, true_only_inputs) +
[false_input_to_param[t] for t in false_only_inputs])
# Rewrite the FuncGraphs' state to reflect the new inputs.
true_graph.captures = collections.OrderedDict(zip(new_inputs,
true_graph.inputs))
false_graph.captures = collections.OrderedDict(zip(new_inputs,
false_graph.inputs))
return new_inputs return new_inputs
def _make_output_composite_tensors_match(true_graph, false_graph): def _make_output_composite_tensors_match(op_type, branch_graphs):
"""Modifies true_graph and false_graph so they have the same output signature. """Modifies each branch_graph's outputs to have the same output signature.
Currently the only transformation implemented is turning a Tensor into an Currently the only transformation implemented is turning a Tensor into an
equivalent IndexedSlices if the other branch returns an IndexedSlices. equivalent IndexedSlices if the other branch returns an IndexedSlices.
Updates {true,false}_graph.{outputs,structured_outputs}. Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
branch_graphs.
Args: Args:
true_graph: FuncGraph op_type: _COND or _CASE
false_graph: FuncGraph branch_graphs: `list` of `FuncGraph`
Raises: Raises:
TypeError: if a pair of outputs cannot be rewritten. TypeError: if a set of outputs cannot be rewritten.
""" """
# Note: since this is only used for gradient graphs, we do not expect the # Note: since this is only used for gradient graphs, we do not expect the
# outputs to be structured (e.g. nested lists), and thus do not need to use # outputs to be structured (e.g. nested lists), and thus do not need to use
# nest.flatten, etc. # nest.flatten, etc.
true_outputs = list(true_graph.structured_outputs) assert branch_graphs
false_outputs = list(false_graph.structured_outputs) branch_outputs = [g.structured_outputs for g in branch_graphs]
assert len(true_outputs) == len(false_outputs) outputs_per_branch = list(len(outs) for outs in branch_outputs)
assert len(set(outputs_per_branch)) == 1, outputs_per_branch
for idx, (true_out, false_out) in enumerate(zip(true_outputs, false_outputs)): for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
if type(true_out) == type(false_out): # pylint: disable=unidiomatic-typecheck if len(set(type(out) for out in branch_outs)) == 1:
continue continue
if (isinstance(true_out, ops.IndexedSlices) and if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
isinstance(false_out, ops.Tensor)): continue
with false_graph.as_default(): for branch_idx, branch_out in enumerate(branch_outs):
false_outputs[idx] = math_ops._as_indexed_slices(false_out) if isinstance(branch_out, ops.IndexedSlices):
elif (isinstance(true_out, ops.Tensor) and continue
isinstance(false_out, ops.IndexedSlices)): elif isinstance(branch_out, ops.Tensor):
with true_graph.as_default(): with branch_graphs[branch_idx].as_default():
true_outputs[idx] = math_ops._as_indexed_slices(true_out) branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
else: branch_out)
raise TypeError( else:
"Cannot reconcile tf.cond %i-th outputs:\n" raise TypeError(
" true_fn returned: %s\n" "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
" false_fn returned: %s" % (idx, true_out, false_out)) " outputs from all branches: {outputs}".format(
op_name="tf.cond" if op_type == _COND else "tf.case",
output_idx=output_idx,
outputs=branch_outs))
true_graph.structured_outputs = true_outputs for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
true_graph.outputs = func_graph_module.flatten(true_outputs) branch_graph.structured_outputs = branch_outs
false_graph.structured_outputs = false_outputs branch_graph.outputs = func_graph_module.flatten(branch_outs)
false_graph.outputs = func_graph_module.flatten(false_outputs)
def _make_indexed_slices_indices_types_match(true_graph, false_graph): def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
"""Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs.""" """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
assert branch_graphs
indexed_slice_indices = [] indexed_slice_indices = []
current_index = 0 current_index = 0
true_outputs_flat_with_composites = nest.flatten( branch_outputs_flat_with_composites = [
true_graph.structured_outputs, expand_composites=False) nest.flatten(branch_graph.structured_outputs, expand_composites=False)
false_outputs_flat_with_composites = nest.flatten( for branch_graph in branch_graphs
false_graph.structured_outputs, expand_composites=False) ]
outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
assert len(set(outs_per_branch)) == 1, outs_per_branch
# Store indices of IndexedSlices.indices in `indexed_slice_indices`. # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
for idx, (true_out, false_out) in enumerate( for output_idx, branch_outs in enumerate(
zip(true_outputs_flat_with_composites, zip(*branch_outputs_flat_with_composites)):
false_outputs_flat_with_composites)): if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
if isinstance(true_out, ops.IndexedSlices) != isinstance( raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
false_out, ops.IndexedSlices): " branches returned: {outputs}".format(
raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n" op_name="tf.cond" if op_type == _COND else "tf.case",
" true_fn returned: %s\n" output_idx=output_idx,
" false_fn returned: %s" % (idx, true_out, false_out)) outputs=branch_outs))
if isinstance(true_out, ops.IndexedSlices): if isinstance(branch_outs[0], ops.IndexedSlices):
# indices is the second component of the composite tensor. # indices is the second component of the composite tensor.
indexed_slice_indices.append(current_index + 1) indexed_slice_indices.append(current_index + 1)
if nest.is_sequence_or_composite(true_out): if nest.is_sequence_or_composite(branch_outs[0]):
current_index += len(nest.flatten(true_out, expand_composites=True)) current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
else: else:
current_index += 1 current_index += 1
if not indexed_slice_indices: if not indexed_slice_indices:
return return
if current_index != len(true_graph.outputs): if current_index != len(branch_graphs[0].outputs):
raise ValueError("Insufficient elements in true_graph.outputs.\n" raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
"Expected: %i\n" "Expected: %i\n"
"Actual: %i" % (current_index, len(true_graph.outputs))) "Actual: %i" %
(current_index, len(branch_graphs[0].outputs)))
# Cast indices with mismatching types to int64. # Cast indices with mismatching types to int64.
for index in indexed_slice_indices: for index in indexed_slice_indices:
if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64): if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
for bg in branch_graphs):
raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
"Found: %s" % str(true_graph.outputs[index].dtype)) "Found: %s" %
if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64): str([bg.outputs[index].dtype for bg in branch_graphs]))
raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
"Found: %s" % str(false_graph.outputs[index].dtype)) for branch_graph in branch_graphs:
if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype: if branch_graph.outputs[index].dtype == dtypes.int32:
if false_graph.outputs[index].dtype == dtypes.int32: with branch_graph.as_default():
with false_graph.as_default(): branch_graph.outputs[index] = math_ops.cast(
false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index], branch_graph.outputs[index], dtypes.int64)
dtypes.int64)
else:
with true_graph.as_default():
true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
dtypes.int64)
true_graph.structured_outputs = func_graph_module.pack_sequence_as( for branch_graph in branch_graphs:
true_graph.structured_outputs, true_graph.outputs) branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
false_graph.structured_outputs = func_graph_module.pack_sequence_as( branch_graph.structured_outputs, branch_graph.outputs)
false_graph.structured_outputs, false_graph.outputs)
def _wrap_intermediates(func_graph, intermediates): def _wrap_intermediates(func_graph, intermediates):
@ -585,33 +575,33 @@ def _wrap_intermediates(func_graph, intermediates):
return [gen_dataset_ops.optional_from_value([t]) for t in intermediates] return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
def _create_dummy_inputs(func_graph, template_tensors): def _create_dummy_input(func_graph, template_tensor):
"""Creates tensors in func_graph to represent template_tensors. """Creates tensors in func_graph to represent template_tensors.
Args: Args:
func_graph: FuncGraph. func_graph: FuncGraph.
template_tensors: a list of tensors in the outer graph. template_tensor: a tensor in the outer graph.
Returns: Returns:
A list of tensors in func_graph. A tensor in func_graph.
""" """
with func_graph.as_default(): with func_graph.as_default():
return [array_ops.placeholder(t.dtype, shape=t.shape) return array_ops.placeholder(
for t in template_tensors] template_tensor.dtype, shape=template_tensor.shape)
def _create_none_optionals(func_graph, template_tensors): def _create_none_optionals(func_graph, n):
"""Creates none optionals in func_graph to represent template_tensors. """Creates `n` `None` optionals in func_graph.
Args: Args:
func_graph: FuncGraph. func_graph: FuncGraph.
template_tensors: a list of tensors in func_graph. n: `int` the number of `None` optionals to make.
Returns: Returns:
A list of tensors in func_graph. A list of tensors in func_graph.
""" """
with func_graph.as_default(): with func_graph.as_default():
return [gen_dataset_ops.optional_none() for _ in template_tensors] return [gen_dataset_ops.optional_none() for _ in range(n)]
def _create_fakeparams(func_graph, template_tensors): def _create_fakeparams(func_graph, template_tensors):
@ -621,55 +611,69 @@ def _create_fakeparams(func_graph, template_tensors):
for t in template_tensors] for t in template_tensors]
def _check_same_outputs(true_graph, false_graph): def _check_same_outputs(op_type, graphs):
"""Raises an error if true_graph and false_graph have different outputs.""" """Raises an error if `graphs` have different outputs."""
def error(error_detail): def error(branch_idx, error_detail):
raise TypeError( raise TypeError(
"true_fn and false_fn arguments to tf.cond must have the same number, " "{b0_name} and {bn_name} arguments to {op_name} must have the same "
"type, and overall structure of return values.\n" "number, type, and overall structure of return values.\n"
"\n" "\n"
"true_fn output: %s\n" "{b0_name} output: {b0_out}\n"
"false_fn output: %s\n" "{bn_name} output: {bn_out}\n"
"\n" "\n"
"Error details:\n" "Error details:\n"
"%s" % (true_graph.structured_outputs, false_graph.structured_outputs, "{detail}".format(
error_detail)) b0_name="true_fn" if op_type == _COND else "branches[0]",
bn_name=("false_fn" if op_type == _COND else
"branches[{}]".format(branch_idx)),
op_name="tf.cond" if op_type == _COND else "tf.case",
b0_out=graphs[0].structured_outputs,
bn_out=graphs[branch_idx].structured_outputs,
detail=error_detail))
try: for b in range(1, len(graphs)):
nest.assert_same_structure(true_graph.structured_outputs, try:
false_graph.structured_outputs, nest.assert_same_structure(
expand_composites=True) graphs[0].structured_outputs,
except (ValueError, TypeError) as e: graphs[b].structured_outputs,
error(str(e)) expand_composites=True)
except (ValueError, TypeError) as e:
error(b, str(e))
assert len(true_graph.outputs) == len(false_graph.outputs) assert len(graphs[0].outputs) == len(graphs[b].outputs)
for true_out, false_out in zip(true_graph.outputs, false_graph.outputs): for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
if true_out.dtype != false_out.dtype: if b0_out.dtype != bn_out.dtype:
error("%s and %s have different types" % (true_out, false_out)) error(b, "%s and %s have different types" % (b0_out, bn_out))
def _get_output_shapes(true_graph_outputs, false_graph_outputs): def _get_output_shapes(*branch_graph_outputs):
output_shapes = [ output_shapes = []
t_out.shape.most_specific_compatible_shape(f_out.shape) for out_by_branch in zip(*branch_graph_outputs):
for t_out, f_out in zip(true_graph_outputs, false_graph_outputs) shape = out_by_branch[0].shape
] for other_out in out_by_branch[1:]:
shape = shape.most_specific_compatible_shape(other_out.shape)
output_shapes.append(shape)
return output_shapes return output_shapes
def verify_captures(true_graph, false_graph): def verify_captures(op_type, branch_graphs):
"""Verify that a true_fn tensor is not accessed in false_fn and vice-versa.""" """Verify that a branch's tensor is not accessed in another branch fn."""
for t in false_graph.external_captures: # Note: It is technically not possible for lower-branch_index branches to
if not isinstance(t, ops.EagerTensor) and t.graph is true_graph: # capture tensors from higher-branch_index branches, because of the order of
raise ValueError("Tensor {} in true_fn is accessed from false_fn.".format( # branch graph construction, but we check all for completeness and to
t.name))
# Note: This is technically not possible right now because `false_graph`
# is built "after" `true_graph` but we add this check for completeness and to
# guard against potential future changes. # guard against potential future changes.
for t in true_graph.external_captures: other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
if not isinstance(t, ops.EagerTensor) and t.graph is false_graph: for i, branch_graph in enumerate(branch_graphs):
raise ValueError("Tensor {} in false_fn is accessed from true_fn.".format( for t in branch_graph.external_captures:
t.name)) if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
"branch {}".format(bi) for bi in range(len(branch_graphs))]
raise ValueError(
"Tensor {tname} in {b0name} is accessed from {b1name}.".format(
tname=t.name,
b0name=branch_names[other_branch_graphs[t.graph]],
b1name=branch_names[i]))
class _CondGradFuncGraph(util.CondBranchFuncGraph): class _CondGradFuncGraph(util.CondBranchFuncGraph):
@ -679,14 +683,14 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
gradient computation in optionals. gradient computation in optionals.
Attributes: Attributes:
if_op_needs_rewrite: True if any intermediates were captured, meaning the op_needs_rewrite: True if any intermediates were captured, meaning the
forward If op needs to be written to output the wrapped intermediates. forward If op needs to be written to output the wrapped intermediates.
""" """
def __init__(self, name, forward_graph): def __init__(self, name, forward_graph):
super(_CondGradFuncGraph, self).__init__( super(_CondGradFuncGraph, self).__init__(
name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access
self.if_op_needs_rewrite = False self.op_needs_rewrite = False
self._forward_graph = forward_graph self._forward_graph = forward_graph
# Maps from forward intermediate tensor -> the unwrapped captured # Maps from forward intermediate tensor -> the unwrapped captured
# intermediate. # intermediate.
@ -719,7 +723,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
# TODO(skyewm,jpienaar): can XLA support optionals? # TODO(skyewm,jpienaar): can XLA support optionals?
if tensor not in self.captures: if tensor not in self.captures:
self.xla_intermediates.append(tensor) self.xla_intermediates.append(tensor)
self.if_op_needs_rewrite = True self.op_needs_rewrite = True
return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
captured_tensor = self._indirect_captures.get(tensor) captured_tensor = self._indirect_captures.get(tensor)
@ -756,7 +760,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
# 'tensor' hasn't been wrapped, do it now. # 'tensor' hasn't been wrapped, do it now.
with self._forward_graph.as_default(): with self._forward_graph.as_default():
optional = gen_dataset_ops.optional_from_value([tensor]) optional = gen_dataset_ops.optional_from_value([tensor])
self.if_op_needs_rewrite = True self.op_needs_rewrite = True
self._wrapped_intermediates[tensor] = optional self._wrapped_intermediates[tensor] = optional
optional = self._wrapped_intermediates[tensor] optional = self._wrapped_intermediates[tensor]
@ -767,3 +771,182 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
self._indirect_captures[tensor] = captured_tensor self._indirect_captures[tensor] = captured_tensor
return captured_tensor return captured_tensor
def indexed_case(branch_index, branch_fns, name="indexed_case"):
"""Like conv_v2, except emits a Case op instead of an If."""
if isinstance(branch_index, int):
raise TypeError("branch_index must not be a Python int", branch_index)
with ops.name_scope(name) as scope:
branch_names = [
util.unique_fn_name(scope, "branch{}".format(b))
for b in range(len(branch_fns))
]
# Automatic control dependencies are added in defuns, but not in v1
# graphs. Propagate that behavior here.
add_control_dependencies = ops.get_default_graph()._add_control_dependencies
branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
branch_graphs = []
for branch_name, branch_fn in zip(branch_names, branch_fns):
branch_graphs.append(
func_graph_module.func_graph_from_py_func(
branch_name,
branch_fn,
[],
{},
func_graph=util.CondBranchFuncGraph(
branch_name,
collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
add_control_dependencies=add_control_dependencies,
op_return_value=branch_index))
verify_captures(_CASE, branch_graphs)
return _build_case(
branch_index,
branch_graphs, [g.external_captures for g in branch_graphs],
name=scope)
@ops.RegisterGradient("Case")
def _CaseGrad(op, *grads): # pylint: disable=invalid-name
"""The gradient of a Case op produced (w/ branch_index) by tf.case."""
# Get the if operator (this logic handles the case where op is a MockOp)
case_op = op.outputs[0].op
branch_graphs = _get_func_graphs(case_op)
assert branch_graphs
# Note: op.graph != ops.get_default_graph() when we are computing the gradient
# of a nested cond.
for branch_graph in branch_graphs:
assert branch_graph.outer_graph == case_op.graph
# Create grad functions that compute the gradient of the branch forward
# graphs. These functions will capture tensors from the forward pass
# functions.
branch_grad_graphs = []
for branch_graph in branch_graphs:
branch_grad_graphs.append(
_create_grad_func(branch_graph, grads,
util.unique_grad_fn_name(branch_graph.name)))
if any(g.op_needs_rewrite for g in branch_grad_graphs):
# Modify 'op' to output the intermediates needed by the grad functions. Note
# that all needed intermediates are wrapped in optionals. Each optional
# intermediate output will have a value iff its corresponding branch is
# taken.
# NOTE(bjp): if there are any active sessions, this modification to `op`
# may make them unrunnable!
if control_flow_util.InXlaContext(ops.get_default_graph()):
# XLA does not yet support optionals, so output intermediates directly and
# make them match via FakeParams, which can be converted to zeros in XLA.
# TODO(bjp,jpienaar): can XLA support optionals?
branches_intermediates = [
branch_grad_graph.xla_intermediates
for branch_grad_graph in branch_grad_graphs
]
extra_branch_outputs = _make_intermediates_match_xla(
branch_graphs, branches_intermediates)
else:
branch_intermediates = [
g.wrapped_intermediates for g in branch_grad_graphs
]
# Make outputs match by adding none optionals.
extra_branch_outputs = _make_intermediates_match(branch_graphs,
branch_intermediates)
for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
branch_graph.outputs.extend(extra_outputs)
_make_indexed_slices_indices_types_match(_CASE, branch_graphs)
# TODO(bjp): indicate it's an internal bug if this fails.
_check_same_outputs(_CASE, branch_graphs)
for branch_graph in branch_graphs:
branch_graph.name += "_rewritten"
case_op._set_func_list_attr("branches", [
util.create_new_tf_function(branch_graph)
for branch_graph in branch_graphs
])
case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
case_op._set_shape_list_attr("output_shapes",
branch_graphs[0].output_shapes)
case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
[t.shape for t in extra_branch_outputs[0]])
# Resolve references to forward graph tensors in grad graphs and ensure
# they are in-scope, i.e., belong to one of outer graphs of the grad graph.
branches_grad_inputs = [
_resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
]
# This modifies the graphs in branch_grad_graphs.
_make_output_composite_tensors_match(_CASE, branch_grad_graphs)
outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
branches_grad_inputs)
# The predicate has no gradient.
return [None] + outputs
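A minimal sketch (not from the diff; eager execution with `tf.GradientTape`, illustrative values) of the user-visible behaviour this gradient provides: only the taken branch contributes a gradient, and the integer branch index itself has none. In graph mode, which is what actually exercises `_CaseGrad`, the untaken branch comes back as a zero gradient rather than `None`, as the tests later in this diff check.

```python
import tensorflow as tf

x = tf.constant(3.)
y = tf.constant(4.)

with tf.GradientTape() as tape:
  tape.watch([x, y])
  # Branch 1 is taken, so only y participates in the output.
  out = tf.switch_case(tf.constant(1),
                       branch_fns={0: lambda: x * x, 1: lambda: y * y})

dx, dy = tape.gradient(out, [x, y])
print(dx, dy)  # dx: None (branch not taken), dy: 8.0 (= d(y*y)/dy at y=4)
```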
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
"""Creates an `Case` op from `branch_index`, branch graphs and inputs.
Note that this modifies `branch_graphs` to make the inputs match, and to
output all intermediates values so they're available for the gradient
computation.
`branch_graphs` need not have the same input types, but they must
have the same outpute types.
Args:
branch_index: integer Tensor
branch_graphs: List of FuncGraph
branch_inputs: List of lists of Tensors to be passed to corresponding
branch_graph as input.
name: the name for the Case op.
Returns:
A list of Tensors which are the outputs of the Case op. Does not include
added intermediate outputs.
"""
_check_same_outputs(_CASE, branch_graphs)
# Add inputs to branch_graphs to make them match. Note that this modifies the
# graphs in `branch_graphs`.
case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
# Create the Case op.
with ops.control_dependencies(
sum((list(bg.control_captures) for bg in branch_graphs), [])):
tensors = gen_functional_ops.case(
branch_index,
case_inputs, [t.dtype for t in branch_graphs[0].outputs],
[util.create_new_tf_function(g) for g in branch_graphs],
output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
name=name)
# TODO(b/110167197) this requires Case to have at least 1 output
case_op = tensors[0].op
# TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
# util.maybe_set_lowering_attr(case_op)
util.maybe_propagate_compile_time_consts_in_xla(case_op)
# Return identities for each output of the Case op, rather than the output of
# the Case op directly. This makes pruning work if the output of switch_case()
# is fetched: the lowering pass converts the Case outputs into IdentityN
# outputs, which if fetched will cause all ops in the taken branch to be run
# (since it takes all merge ops as input). After lowering, each output
# identity op will end up with only the appropriate merge op as input.
# TODO(b/79984175): this doesn't have to be a tuple once we convert to the
# correct output structure
tensors = [array_ops.identity(t) for t in tensors]
# Prevent fetching since the variant outputs can't be fetched directly.
case_op.graph.prevent_fetching(case_op)
return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs,
tensors)
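A small sketch (not part of the diff; shown via the public API, eager values illustrative) of the structural contract `_build_case` enforces: every branch must return the same nested structure, and the Case outputs are repacked into that structure via `pack_sequence_as`.

```python
import tensorflow as tf

def small():
  return (tf.constant(1.0), tf.constant(0.0))   # (scale, bias)

def large():
  return (tf.constant(2.0), tf.constant(0.5))

# Both branches return a 2-tuple, so the result unpacks as a 2-tuple.
scale, bias = tf.switch_case(tf.constant(1), branch_fns=[small, large])
print(scale.numpy(), bias.numpy())  # 2.0 0.5
```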

View File

@ -3931,6 +3931,101 @@ def _case_helper(cond_fn,
return fn()
def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
branch_index):
"""Verifies input arguments for the case function.
Args:
branch_fns: Dict or list of pairs of an `int` and a callable which
returns a list of tensors.
default: Optional callable that returns a list of tensors.
branch_index: An int `Tensor` that selects which entry in `branch_fns` to
execute.
Raises:
TypeError: If `branch_fns` is not a list/dictionary.
TypeError: If `branch_fns` is a list but does not contain 2-tuples or
callables.
TypeError: If `fns[i]` is not callable for any i, or `default` is not
callable.
Returns:
branch_fns: validated list of callables for each branch (default last).
"""
if not isinstance(branch_index, ops.Tensor):
raise TypeError("branch_index must a Tensor, got {}".format(
type(branch_index)))
if not branch_index.dtype.is_integer:
raise TypeError("branch_index must an integer Tensor, got {}".format(
branch_index.dtype))
if not branch_fns:
raise ValueError("Must provide at least one item in branch_fns")
if not isinstance(branch_fns, (list, _basetuple, dict)):
raise TypeError("branch_fns must be a list, tuple, or dict")
if isinstance(branch_fns, dict):
branch_fns = branch_fns.items()
if all(callable(fn) for fn in branch_fns):
branch_fns = list(enumerate(branch_fns))
for key_fn_pair in branch_fns:
if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
raise TypeError("Each entry in branch_fns must be a 2-tuple")
key, branch_fn = key_fn_pair
if not isinstance(key, int):
raise TypeError("key must be a Python `int`, got {}".format(type(key)))
if not callable(branch_fn):
raise TypeError("fn for key {} must be callable.".format(key))
keys = [p[0] for p in branch_fns]
if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
raise ValueError(
"branch indices (keys) must form contiguous range of [0 to {}) but "
"found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
actions = [p[1] for p in sorted(branch_fns)]
if default is not None:
actions.append(default)
return actions
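A minimal sketch (illustrative values) of the three equivalent `branch_fns` forms accepted by the canonicalization above; all reduce to the same ordered list of callables, and the keys must cover exactly `[0, len(branch_fns))`.

```python
import tensorflow as tf

f0 = lambda: tf.constant(0.)
f1 = lambda: tf.constant(10.)
f2 = lambda: tf.constant(20.)
idx = tf.constant(2)

a = tf.switch_case(idx, branch_fns={0: f0, 1: f1, 2: f2})         # dict form
b = tf.switch_case(idx, branch_fns=[(0, f0), (1, f1), (2, f2)])   # (key, fn) pairs
c = tf.switch_case(idx, branch_fns=[f0, f1, f2])                  # plain callables
print(a.numpy(), b.numpy(), c.numpy())  # 20.0 20.0 20.0
```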
def _indexed_case_helper(branch_fns, default, branch_index, name):
"""Implementation of case that emits the n-way indexed Case op.
Args:
branch_fns: Dict or list of pairs of an `int` and a callable which
returns a list of tensors.
default: Optional callable that returns a list of tensors.
branch_index: An int `Tensor` that selects which entry in `branch_fns` to
execute.
name: A name for this operation (optional).
Returns:
The tensors returned by the pair whose key matched branch_index, or
those returned by `default` if none does.
Raises:
TypeError: If `branch_fns` is not a list/dictionary.
TypeError: If `branch_fns` is a list but does not contain 2-tuples or
callables.
TypeError: If `fns[i]` is not callable for any i, or `default` is not
callable.
"""
branch_fns = _indexed_case_verify_and_canonicalize_args(
branch_fns, default, branch_index)
with ops.name_scope(name, "case", [branch_index]):
if context.executing_eagerly():
branch_index = array_ops.where(
math_ops.less(branch_index, 0)
| math_ops.greater_equal(branch_index, len(branch_fns)),
len(branch_fns) - 1, branch_index)
return branch_fns[int(branch_index)]()
return cond_v2.indexed_case(branch_index, branch_fns)
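A short sketch (eager execution assumed, illustrative values) of the clamping in the eager fast path above: an out-of-range `branch_index` falls through to the last entry, which is the `default` callable when one is supplied.

```python
import tensorflow as tf

branches = {0: lambda: tf.constant(0.), 1: lambda: tf.constant(10.)}

out = tf.switch_case(tf.constant(7), branches,
                     default=lambda: tf.constant(-1.))
print(out.numpy())  # -1.0: index 7 is out of range, so the default branch runs
```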
@tf_export("case") @tf_export("case")
def case(pred_fn_pairs, def case(pred_fn_pairs,
default=None, default=None,
@ -3939,6 +4034,8 @@ def case(pred_fn_pairs,
name="case"): name="case"):
"""Create a case operation. """Create a case operation.
See also `tf.switch_case`.
The `pred_fn_pairs` parameter is a dict or list of pairs of size N. The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
Each pair contains a boolean scalar tensor and a python callable that Each pair contains a boolean scalar tensor and a python callable that
creates the tensors to be returned if the boolean evaluates to True. creates the tensors to be returned if the boolean evaluates to True.
@ -4037,6 +4134,82 @@ def case(pred_fn_pairs,
strict=strict) strict=strict)
@tf_export("switch_case")
def switch_case(branch_index,
branch_fns,
default=None,
name="switch_case"):
"""Create a switch/case operation, i.e. an integer-indexed conditional.
See also `tf.case`.
This op can be substantially more efficient than `tf.case` when exactly one
branch will be selected. `tf.switch_case` is more like a C++ switch/case
statement than `tf.case`, which is more like an if/elif/elif/else chain.
The `branch_fns` parameter is either a dict from `int` to callables, or a
list of (`int`, callable) pairs, or simply a list of callables (in which case
the index is implicitly the key). The `branch_index` `Tensor` is used to
select an element in `branch_fns` with matching `int` key, falling back to
`default` if none match, or to `max(keys)` if no `default` is provided. The
keys must form a contiguous set from `0` to `len(branch_fns) - 1`.
`tf.switch_case` supports nested structures as implemented in `tf.nest`. All
callables must return the same (possibly nested) value structure of lists,
tuples, and/or named tuples.
**Example:**
Pseudocode:
```c++
switch (branch_index) { // c-style switch
case 0: return 17;
case 1: return 31;
default: return -1;
}
```
or
```python
branches = {0: lambda: 17, 1: lambda: 31}
branches.get(branch_index, lambda: -1)()
```
Expressions:
```python
def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)
r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
# Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
```
Args:
branch_index: An int Tensor specifying which of `branch_fns` should be
executed.
branch_fns: A `dict` mapping `int`s to callables, or a `list` of
(`int`, callable) pairs, or simply a list of callables (in which case the
index serves as the key). Each callable must return a matching structure
of tensors.
default: Optional callable that returns a structure of tensors.
name: A name for this operation (optional).
Returns:
The tensors returned by the callable identified by `branch_index`, or those
returned by `default` if no key matches and `default` was provided, or those
returned by the max-keyed `branch_fn` if no `default` is provided.
Raises:
TypeError: If `branch_fns` is not a list/dictionary.
TypeError: If `branch_fns` is a list but does not contain 2-tuples or
callables.
TypeError: If `fns[i]` is not callable for any i, or `default` is not
callable.
"""
return _indexed_case_helper(branch_fns, default, branch_index, name)
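As a follow-up to the docstring example, a sketch (illustrative values) of the documented fallback when no `default` is supplied: an out-of-range index selects the branch with the largest key.

```python
import tensorflow as tf

r = tf.switch_case(tf.constant(9),
                   branch_fns={0: lambda: tf.constant(17),
                               1: lambda: tf.constant(31)})
print(r.numpy())  # 31: with no default, the max-keyed branch is used
```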
class XLAControlFlowContext(ControlFlowContext):
"""Base class for XLA and TPU control flow contexts."""

View File

@ -24,6 +24,8 @@ import numpy as np
from tensorflow.python import tf2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -417,6 +419,28 @@ class CondTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError):
control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
@test_util.enable_control_flow_v2
@test_util.run_in_graph_and_eager_modes
def testCond_gradient(self):
true_in, false_in = array_ops.constant(1.), array_ops.constant(5.)
with backprop.GradientTape(persistent=True) as tape:
tape.watch(true_in)
tape.watch(false_in)
cond_true = control_flow_ops.cond(
array_ops.constant(True), lambda: true_in**2., lambda: false_in**2.)
cond_false = control_flow_ops.cond(
array_ops.constant(False), lambda: true_in**2., lambda: false_in**2.)
grads_true = tape.gradient(
cond_true, [true_in, false_in], output_gradients=3.)
grads_false = tape.gradient(
cond_false, [true_in, false_in], output_gradients=3.)
self.assertEqual(3. * 2. * 1., self.evaluate(grads_true[0]))
self.assertEqual(None if context.executing_eagerly() else 0.,
self.evaluate(grads_true[1]))
self.assertEqual(3. * 2. * 5., self.evaluate(grads_false[1]))
self.assertEqual(None if context.executing_eagerly() else 0.,
self.evaluate(grads_false[0]))
class ContextTest(test_util.TensorFlowTestCase):
@ -908,6 +932,179 @@ class DataTypesTest(test_util.TensorFlowTestCase):
self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
@test_util.run_all_in_graph_and_eager_modes
class IndexedCaseTest(test_util.TensorFlowTestCase):
def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10, name="br{}_out".format(bi))
branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in range(nbranches):
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(branch_index, branches)
self.assertEqual(bi * 10, self.evaluate(case_out))
def testCase(self):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in 0, 2, 3:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(branch_index, branches)
self.assertEqual(bi * 10., self.evaluate(case_out))
def testCase_withDefault(self):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = [(i, make_func(i)) for i in range(nbranches)]
for bi in -1, 2, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6))
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
def testCase_dictWithDefault(self):
nbranches = 5
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = {i: make_func(i) for i in range(nbranches)}
for bi in -1, 0, 3, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
case_out = control_flow_ops.switch_case(
branch_index, branches, default=make_func(6))
if bi < 0 or bi >= nbranches:
expected = 60.
else:
expected = bi * 10.
self.assertEqual(expected, self.evaluate(case_out))
def testCase_gradient(self):
nbranches = 5
inputs = [
array_ops.constant(float(bi), name="br{}_in".format(bi))
for bi in range(nbranches)
]
def make_func(bi):
return lambda: inputs[bi]**2.
branches = {bi: make_func(bi) for bi in range(nbranches)}
for bi in -1, 1, 4, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
self.assertEqual(expected, self.evaluate(actual))
def testCase_gradient_diffShapedIntermediates(self):
nbranches = 5
inputs = [
array_ops.constant(
float(bi), shape=[bi + 1], name="br{}_in".format(bi))
for bi in range(nbranches)
]
def make_func(bi):
def f():
x = inputs[bi]**2 * inputs[bi][:bi + 1, None]
return math_ops.reduce_sum(x)
return f
branches = {bi: make_func(bi) for bi in range(nbranches)}
for bi in -1, 2, nbranches:
branch_index = array_ops.placeholder_with_default(bi, [])
with backprop.GradientTape() as tape:
for x in inputs:
tape.watch(x)
case_out = control_flow_ops.switch_case(branch_index, branches)
out_grad = 3.
actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
expected_grads = []
for input_idx in range(nbranches):
if used_bi == input_idx:
with backprop.GradientTape() as tape:
tape.watch(inputs[used_bi])
y = make_func(used_bi)()
expected_grads.append(
self.evaluate(
tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
else:
expected_grads.append(None if context.executing_eagerly() else [0.] *
(input_idx + 1))
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
if expected is None:
self.assertIsNone(actual)
else:
self.assertAllEqual(expected, self.evaluate(actual))
def testCase_validateIndicesContiguous(self):
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = {i: make_func(i) for i in range(0, 6, 2)}
with self.assertRaisesRegexp(ValueError, "must form contiguous"):
control_flow_ops.switch_case(array_ops.constant(0), branches)
def testCase_validateIndicesDup(self):
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = [(i, make_func(i)) for i in range(0, 6, 2)]
branches.append((0, make_func(7)))
with self.assertRaisesRegexp(ValueError, "must form contiguous"):
control_flow_ops.switch_case(array_ops.constant(0), branches)
def testCase_validateBranchIndex(self):
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = {i: make_func(i) for i in range(5)}
with self.assertRaisesRegexp(TypeError, "branch_index.*Tensor"):
control_flow_ops.switch_case(1, branches)
def testCase_validateNonIntKeys(self):
def make_func(bi):
return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
branches = {array_ops.constant(i): make_func(i) for i in range(5)}
with self.assertRaisesRegexp(TypeError, "must be a Python `int`"):
control_flow_ops.switch_case(array_ops.constant(1), branches)
class CaseTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1

View File

@ -2256,6 +2256,10 @@ tf_module {
name: "svd" name: "svd"
argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
} }
member_method {
name: "switch_case"
argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'switch_case\'], "
}
member_method { member_method {
name: "tables_initializer" name: "tables_initializer"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "

View File

@ -988,6 +988,10 @@ tf_module {
name: "subtract" name: "subtract"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "switch_case"
argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'switch_case\'], "
}
member_method { member_method {
name: "tan" name: "tan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "