Adds tf.switch_case, an n-way switch/case statement, to the TF Python API. (C++ support was added previously, in the "Case" op, its kernel, and XLA.)
Adds XLA support for Case with TensorArrays, and propagates compile-time constants into the branch functions of Case in XLA. This enables using ops like StridedSlice, which require compile-time constants, inside the branch functions. PiperOrigin-RevId: 248466828
parent 7bb7116c77
commit e224546ee5
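
Before the diff, a brief usage sketch of the new n-way branch API that the tests below exercise. This is a hedged illustration, not part of the change itself: it assumes the public alias tf.switch_case dispatches the same way as control_flow_ops.switch_case does in the tests (a scalar integer branch index selecting one of several zero-argument callables), and that, as the message above states, XLA can now resolve compile-time constants inside Case branches.

    import tensorflow as tf

    def branch0():
      return tf.constant(5.)

    def branch1():
      return tf.constant(10.)

    def branch2():
      return tf.constant(15.)

    # Runs exactly one branch, selected by a scalar integer branch index.
    # Branches can be given as a list of callables or as an {int: callable} dict.
    result = tf.switch_case(tf.constant(1), {0: branch0, 1: branch1, 2: branch2})
    # result evaluates to 10.0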
@@ -19,12 +19,16 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compiler.xla import xla
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.platform import test
 
@@ -33,6 +37,7 @@ from tensorflow.python.platform import test
 class CondTest(xla_test.XLATestCase):
 
   def testCondAndTensorArrayInDefun(self):
+    # TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988
     with self.session(), self.test_scope():
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
@@ -47,7 +52,7 @@ class CondTest(xla_test.XLATestCase):
         return output.stack()
 
       output_t = f()
-      self.assertAllEqual(self.evaluate(output_t), [5.])
+      self.assertAllEqual([5.], self.evaluate(output_t))
 
       xla_context.Exit()
 
@@ -71,11 +76,178 @@ class CondTest(xla_test.XLATestCase):
       output = control_flow_ops.cond(
           constant_op.constant(True), if_true, if_false)
 
-      self.assertAllEqual(
-          sess.run(output, feed_dict={
-              x: [0., 1., 2.],
-              p: 1
-          }), 1.)
+      self.assertAllEqual(1.,
+                          sess.run(output, feed_dict={
+                              x: [0., 1., 2.],
+                              p: 1
+                          }))
 
+      xla_context.Exit()
+
+  def testCondConstPropagation_xlaCompile(self):
+    self.skipTest("b/132430685")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3])
+      p = constant_op.constant(1)
+
+      def f():
+        # TODO(b/129021699): Wrapping this in a tf.function does not work.
+        def if_true():
+          # This emits a StridedSlice op which expects the index to be a
+          # compile-time const.
+          return x[p]
+
+        def if_false():
+          return 5.
+
+        return control_flow_ops.cond(
+            constant_op.constant(True), if_true, if_false)
+
+      output = xla.compile(f)
+
+      self.assertAllEqual(1., self.evaluate(output))
+
+      xla_context.Exit()
+
+  def testCondConstPropagation_errorMsg(self):
+    self.skipTest("b/132430685")
+    with self.session() as sess, self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder(dtypes.float32)
+      p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
+
+      # TODO(b/129021699): Wrapping this in a tf.function does not work.
+      def if_true():
+        # This emits a StridedSlice op which expects the index to be a
+        # compile-time const.
+        return x[:p]
+
+      def if_false():
+        return array_ops.fill([p], 5.)
+
+      output = control_flow_ops.cond(
+          constant_op.constant(True), if_true, if_false)
+
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "must be a compile-time constant"):
+        sess.run(
+            output, feed_dict={
+                x: [0., 1., 2.],
+            })
+
+      xla_context.Exit()
+
+  def testCondConstPropagation_errorMsg_xlaCompile(self):
+    with self.session() as sess, self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder(dtypes.float32)
+      p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
+      condition = math_ops.cast(
+          random_ops.random_uniform([], minval=0, maxval=2, dtype=dtypes.int32),
+          dtypes.bool)
+
+      def f():
+        # TODO(b/129021699): Wrapping this in a tf.function does not work.
+        def if_true():
+          # This emits a StridedSlice op which expects the index to be a
+          # compile-time const.
+          return x[:p]
+
+        def if_false():
+          return array_ops.fill([p], 5.)
+
+        return control_flow_ops.cond(condition, if_true, if_false)
+
+      output = xla.compile(f)
+
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "must be a compile-time constant"):
+        sess.run(
+            output, feed_dict={
+                x: [0., 1., 2.],
+            })
+
+      xla_context.Exit()
+
+  def testSwitchCaseAndTensorArrayInDefun(self):
+    self.skipTest("b/127846988")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      @function.defun
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.switch_case(
+            constant_op.constant(1), {
+                0: lambda: ta.write(0, 5.),
+                1: lambda: ta.write(0, 10.),
+                2: lambda: ta.write(0, 15.),
+            })
+
+        return output.stack()
+
+      output_t = f()
+      self.assertAllEqual([10.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
+  def testSwitchCaseConstPropagation(self):
+    self.skipTest("b/127846988")
+    with self.session() as sess, self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder(dtypes.float32)
+      p = array_ops.placeholder(dtypes.int32)
+
+      def branch0():
+        return 5.
+
+      def branch1():
+        return 15.
+
+      # TODO(b/129021699): Wrapping this in a tf.function does not work.
+      def branch2():
+        # This emits a StridedSlice op which expects the index to be a
+        # compile-time const.
+        return x[p]
+
+      output = control_flow_ops.switch_case(
+          constant_op.constant(2), {
+              0: branch0,
+              1: branch1,
+              2: branch2,
+          })
+
+      self.assertAllEqual(7.,
+                          sess.run(output, feed_dict={
+                              x: [0., 1., 7.],
+                              p: 2,
+                          }))
+
+      xla_context.Exit()
+
+  def testCondNoInputs(self):
+    """Verifies against `Failed precondition: Expected one input shape`."""
+
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      for pred in True, False:
+        cond_out = control_flow_ops.cond(
+            array_ops.placeholder_with_default(pred, []),
+            lambda: constant_op.constant(2.),
+            lambda: constant_op.constant(1.))
+        self.assertEqual(int(pred) + 1., self.evaluate(cond_out))
+
       xla_context.Exit()
 
@@ -125,6 +125,8 @@ Status BackwardsConstAnalysis(const Graph& g,
   return status;
 }
 
+namespace {
+
 Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node,
                        StringPiece func_attr_name, const FunctionBody** fbody) {
   NameAttrList name_attr_list;
@@ -136,6 +138,50 @@ Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node,
   return Status::OK();
 }
 
+Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, const Node* node,
+                         StringPiece func_list_attr_name,
+                         std::vector<const FunctionBody*>* fbodies) {
+  std::vector<NameAttrList> name_attr_lists;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(node->def(), func_list_attr_name, &name_attr_lists));
+  for (const NameAttrList& name_attr_list : name_attr_lists) {
+    FunctionLibraryRuntime::Handle func_handle;
+    TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
+        name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
+        &func_handle));
+    fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
+  }
+  return Status::OK();
+}
+
+Status CondConstInputIndices(
+    absl::Span<const FunctionBody* const> branch_bodies,
+    std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
+  TF_RET_CHECK(!branch_bodies.empty());
+  TF_RET_CHECK(branch_bodies[0] != nullptr);
+  int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
+  // Stores indices of the "branch function" inputs that are expected to be
+  // compile time constants.
+  std::vector<bool> compile_time_const_arg_indices(num_inputs);
+  for (auto fbody : branch_bodies) {
+    TF_RET_CHECK(fbody != nullptr);
+    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+        *(fbody->graph), &compile_time_const_arg_indices,
+        /*compile_time_const_nodes=*/nullptr, flib_runtime));
+  }
+  for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
+    if (compile_time_const_arg_indices[i]) {
+      // The 0th input is the pred or branch index, which is not passed to the
+      // branches. So the i'th input of a branch function corresponds to the
+      // i + 1'th input of the If/Case op.
+      const_input_idxs->push_back(i + 1);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 Status GetCompileTimeConstInputs(const Node* node,
                                  std::vector<int>* const_input_idxs,
                                  FunctionLibraryRuntime* flib_runtime) {
@@ -179,6 +225,7 @@ Status GetCompileTimeConstInputs(const Node* node,
         }
       }
     }
+    return Status::OK();
   } else if (node->type_string() == "If") {
     const FunctionBody* fthen = nullptr;
     const FunctionBody* felse = nullptr;
@@ -186,31 +233,17 @@ Status GetCompileTimeConstInputs(const Node* node,
         GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
     TF_RETURN_IF_ERROR(
         GetFunctionBody(flib_runtime, node, "else_branch", &felse));
-    TF_RET_CHECK(fthen);
-    TF_RET_CHECK(felse);
-    int num_inputs = fthen->fdef.signature().input_arg_size();
-    // Stores indices of the "branch function" inputs that are expected to be
-    // compile time constants.
-    std::vector<bool> compile_time_const_arg_indices(num_inputs);
-    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-        *(fthen->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime));
-    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-        *(felse->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime));
-    for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
-      if (compile_time_const_arg_indices[i]) {
-        // The 0th input is the loop condition which is not passed to the
-        // branches. So the i'th input of a branch function corresponds to
-        // i + 1'th input of the If op.
-        const_input_idxs->push_back(i + 1);
-      }
-    }
+    return CondConstInputIndices({fthen, felse}, const_input_idxs,
+                                 flib_runtime);
+  } else if (node->type_string() == "Case") {
+    std::vector<const FunctionBody*> branch_bodies;
+    TF_RETURN_IF_ERROR(
+        GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
+    return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
   } else {
     return XlaOpRegistry::CompileTimeConstantInputs(node->def(), node->op_def(),
                                                     const_input_idxs);
   }
-  return Status::OK();
 }
 
 }  // namespace tensorflow
@@ -311,6 +311,7 @@ tf_kernel_library(
     srcs = ["case_op.cc"],
     hdrs = ["case_op.h"],
     deps = [
+        ":if_while_utils",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/kernels/case_op.h"
 
+#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -34,10 +35,41 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
   } else {
     has_token_input_output_ = !token_input_nodes_.empty();
   }
+  if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
+                                     &propagate_compile_time_consts_));
+  }
 }
 
+namespace {
+
+Status ConvertCompileTimeConstArgumentsToConst(
+    XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args) {
+  for (int i = 0; i < args->size(); i++) {
+    XlaCompiler::Argument& arg = (*args)[i];
+    const XlaExpression& expression = ctx->InputExpression(i + 1);
+    // If the input tensor is a compile time constant build a kConstant type
+    // argument.
+    if (arg.kind == XlaCompiler::Argument::kParameter) {
+      // NOTE: We can not simply check that this is Kind::kConstant because
+      // this could be the output of a MetadataOnly op e.g. Size.
+      xla::StatusOr<absl::optional<Tensor>> maybe_constant =
+          expression.ResolveConstant(ctx->compiler()->client());
+      if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) {
+        arg.kind = XlaCompiler::Argument::kConstant;
+        arg.type = expression.dtype();
+        arg.constant_value = std::move(maybe_constant.ValueOrDie().value());
+        arg.shape = expression.GetShape().ValueOrDie();
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 // TODO(b/35949885): There is duplication here with the handling of the
-// while_op. Refactor the common code out/rework.
+// while_op/if_op. Refactor the common code out/rework.
 void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
   xla::XlaBuilder* b = ctx->builder();
   int num_branches = branches_.size();
@@ -84,12 +116,30 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
     } else {
       arg.kind = XlaCompiler::Argument::kParameter;
       arg.type = input_types_[i];
-      arg.shape = ctx->InputShape(i + 1);
+      // Use the xla::Shape for the input instead of ctx->InputShape. This is
+      // necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists.
+      auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1));
+      OP_REQUIRES_OK(ctx, shape_or.status());
+      arg.shape = shape_or.ValueOrDie();
       VLOG(2) << "Arg type: " << DataTypeString(arg.type)
               << " shape: " << arg.HumanString();
     }
   }
 
+  if (propagate_compile_time_consts_) {
+    // Replaces `kParameter` type args in `arguments` with `kConstant` if
+    // the op input corresponding to that arg is a compile-time const. This
+    // is necessary to propagate compile time consts to ops in the branch
+    // functions.
+    // Note: Propagating "all" compile-time constants may not be necessary. We
+    // should ideally only propagate consts which are required to be compile
+    // time constants in the branch functions. But that would require calling
+    // BackwardsConstAnalysis here which would be expensive. However, if we
+    // start hitting memory issues we should revisit this.
+    OP_REQUIRES_OK(ctx,
+                   ConvertCompileTimeConstArgumentsToConst(ctx, &arguments));
+  }
+
   // Compile each branch of the conditional.
   XlaCompiler::CompileOptions options;
   options.use_tuple_arg = true;
|
|||||||
}
|
}
|
||||||
OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
|
OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
|
||||||
errors::FailedPrecondition("Expected tuple shape"));
|
errors::FailedPrecondition("Expected tuple shape"));
|
||||||
OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
|
|
||||||
errors::FailedPrecondition("Expected one input shape"));
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx,
|
ctx,
|
||||||
xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
|
xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
|
||||||
@@ -227,7 +275,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
     } else {
-      inputs[i] = ctx->Input(i + 1);
+      inputs[i] = ctx->Input(input_num);
     }
   }
   auto input_tuple = xla::Tuple(b, inputs);
@@ -292,6 +340,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
   VLOG(1) << "Done building Case";
 }
 
-REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp);
+REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
+                XlaCaseOp);
 
 }  // namespace tensorflow
@@ -55,6 +55,10 @@ class XlaCaseOp : public XlaOpKernel {
   DataTypeVector output_types_;
   bool has_token_input_output_;
   std::vector<string> token_input_nodes_;
+  // Whether to propagate compile time consts into the cond branches.
+  // This is not supported by default now since it may cause HBM memory
+  // overheads.
+  bool propagate_compile_time_consts_ = false;
 };
 
 }  // namespace tensorflow
@@ -776,7 +776,7 @@ Status XlaCompiler::BuildArguments(
     }
   }
 
-  if (input_to_args->empty()) {
+  if (input_to_args->empty() && !use_tuple_arg) {
     return Status::OK();
   }
 
@@ -829,8 +829,9 @@ Status XlaCompiler::BuildArguments(
           xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
           args[input_to_args->at(i)].is_same_data_across_replicas);
     }
-    xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
-                                                           tuple_sharding);
+    xla::XlaScopedShardingAssignment assign_tuple_sharding(
+        builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
+                                        : tuple_sharding);
     tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
                            is_same_across_replicas);
   } else {
@@ -3747,6 +3747,7 @@ cuda_py_test(
         ":while_v2",
         "//tensorflow/python/eager:def_function",
     ],
+    shard_count = 2,
    xla_enable_strict_auto_jit = True,
 )
 
@@ -2593,8 +2593,15 @@ class Operation(object):
     func = attr_value_pb2.NameAttrList(name=func_name)
     self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
 
+  def _set_func_list_attr(self, attr_name, func_names):
+    """Private method used to set a list(function) attribute in the node_def."""
+    funcs = [attr_value_pb2.NameAttrList(name=func_name)
+             for func_name in func_names]
+    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
+    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
+
   def _set_type_list_attr(self, attr_name, types):
-    """Private method used to set a function attribute in the node_def."""
+    """Private method used to set a list(type) attribute in the node_def."""
     if not types:
       return
     if isinstance(types[0], dtypes.DType):
@@ -2603,7 +2610,7 @@ class Operation(object):
     self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
 
   def _set_shape_list_attr(self, attr_name, shapes):
-    """Private method used to set a function attribute in the node_def."""
+    """Private method used to set a list(shape) attribute in the node_def."""
     shapes = [s.as_proto() for s in shapes]
     shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
     self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
@@ -776,6 +776,33 @@ class ControlFlowTest(test.TestCase):
         "Tensor true_branch:0 in true_fn is accessed from false_fn."):
       f()
 
+  def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
+
+    @def_function.function
+    def f():
+      c = constant_op.constant(1.)
+      inputs = {"c": c}
+
+      def br1_fn(inputs):
+        inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity")
+        return inputs["c"]
+
+      def br4_fn(inputs):
+        return array_ops.identity(inputs["c"])
+
+      def other_fn():
+        return array_ops.identity(c)
+
+      return control_flow_ops.switch_case(
+          constant_op.constant(2),
+          [other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
+           lambda: br4_fn(inputs)])
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        "Tensor br1_identity:0 in branch 1 is accessed from branch 4."):
+      f()
+
   def testCondListOutput(self):
     with self.cached_session() as sess:
       x = constant_op.constant(10)
@@ -47,6 +47,9 @@ from tensorflow.python.util import nest
 # readability.
 # pylint: disable=protected-access
 
+_COND = 1
+_CASE = 2
+
 
 def cond_v2(pred, true_fn, false_fn, name="cond"):
   """Like tf.cond, except emits a single If op."""
@@ -80,7 +83,7 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
         add_control_dependencies=add_control_dependencies,
         op_return_value=pred)
 
-    verify_captures(true_graph, false_graph)
+    verify_captures(_COND, [true_graph, false_graph])
     return _build_cond(pred, true_graph, false_graph,
                        true_graph.external_captures,
                        false_graph.external_captures,
@@ -106,8 +109,7 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_graph = _create_grad_func(
       false_graph, grads, util.unique_grad_fn_name(false_graph.name))
 
-  if (true_grad_graph.if_op_needs_rewrite or
-      false_grad_graph.if_op_needs_rewrite):
+  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
     # Modify 'op' to output the intermediates needed by the grad functions. Note
     # that all needed intermediates are wrapped in optionals. Each optional
     # intermediate output will have a value iff its corresponding branch is
|
|||||||
true_intermediates = true_grad_graph.xla_intermediates
|
true_intermediates = true_grad_graph.xla_intermediates
|
||||||
false_intermediates = false_grad_graph.xla_intermediates
|
false_intermediates = false_grad_graph.xla_intermediates
|
||||||
extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
|
extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
|
||||||
true_graph, false_graph, true_intermediates, false_intermediates)
|
[true_graph, false_graph], [true_intermediates, false_intermediates])
|
||||||
else:
|
else:
|
||||||
true_intermediates = true_grad_graph.wrapped_intermediates
|
true_intermediates = true_grad_graph.wrapped_intermediates
|
||||||
false_intermediates = false_grad_graph.wrapped_intermediates
|
false_intermediates = false_grad_graph.wrapped_intermediates
|
||||||
# Make outputs match by adding none optionals.
|
# Make outputs match by adding none optionals.
|
||||||
extra_true_outputs, extra_false_outputs = _make_intermediates_match(
|
extra_true_outputs, extra_false_outputs = _make_intermediates_match(
|
||||||
true_graph, false_graph, true_intermediates, false_intermediates)
|
[true_graph, false_graph], [true_intermediates, false_intermediates])
|
||||||
|
|
||||||
true_graph.outputs.extend(extra_true_outputs)
|
true_graph.outputs.extend(extra_true_outputs)
|
||||||
false_graph.outputs.extend(extra_false_outputs)
|
false_graph.outputs.extend(extra_false_outputs)
|
||||||
# TODO(skyewm): indicate it's an internal bug if this fails.
|
# TODO(skyewm): indicate it's an internal bug if this fails.
|
||||||
_check_same_outputs(true_graph, false_graph)
|
_check_same_outputs(_COND, [true_graph, false_graph])
|
||||||
|
|
||||||
true_graph.name += "_rewritten"
|
true_graph.name += "_rewritten"
|
||||||
false_graph.name += "_rewritten"
|
false_graph.name += "_rewritten"
|
||||||
@@ -153,7 +155,8 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
 
   # This modifies true_grad_graph and false_grad_graph.
-  _make_output_composite_tensors_match(true_grad_graph, false_grad_graph)
+  _make_output_composite_tensors_match(_COND,
+                                       [true_grad_graph, false_grad_graph])
 
   outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                         true_grad_inputs, false_grad_inputs)
@@ -185,13 +188,13 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
     A list of Tensors which are the outputs of the If op. Does not include added
     intermediate outputs.
   """
-  _make_indexed_slices_indices_types_match(true_graph, false_graph)
-  _check_same_outputs(true_graph, false_graph)
+  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
+  _check_same_outputs(_COND, [true_graph, false_graph])
 
   # Add inputs to true_graph and false_graph to make them match. Note that
   # this modifies true_graph and false_graph.
-  cond_inputs = _make_inputs_match(true_graph, false_graph,
-                                   true_inputs, false_inputs)
+  cond_inputs = _make_inputs_match([true_graph, false_graph],
+                                   [true_inputs, false_inputs])
 
   # Create the If op.
   with ops.control_dependencies(
@@ -226,38 +229,45 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
       tensors)
 
 
-def _get_func_graphs(if_op):
+def _get_func_graphs(op):
   """Returns `FuncGraph`s for the input op branches.
 
   Args:
-    if_op: The _If Operation.
+    op: The If or Case Operation.
 
   Returns:
-    A 2-tuple of the `FuncGraph`s of the then_branch and else_branch.
+    A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
+    for Case).
   """
-  def _get_func_graph_for_branch(branch_name):
+
+  def _get_func_graph_for_branch(name_attr_list):
     """Generates and returns a FuncGraph for the given branch."""
-    inputs = if_op.inputs[1:]  # First input is pred.
+    inputs = op.inputs[1:]  # First input is pred.
     input_shapes = [t.shape for t in inputs]
-    func_name = if_op.get_attr(branch_name).name
-    fdef = if_op.graph._get_function(func_name).definition
-    # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
+    fdef = op.graph._get_function(name_attr_list.name).definition
+    # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
     # in the case of nested if ops or when the gradient is being computed
-    # from inside a Defun. We build the `func_graph` with `if_op.graph` as its
+    # from inside a Defun. We build the `func_graph` with `op.graph` as its
     # `outer_graph`. This resembles how the `FuncGraph` was built in the
     # forward pass. We need this so that we can resolve references to tensors
     # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
-    with if_op.graph.as_default():
+    with op.graph.as_default():
       func_graph = function_def_to_graph.function_def_to_graph(
           fdef, input_shapes)
     func_graph.captures = collections.OrderedDict(zip(inputs,
                                                       func_graph.inputs))
-    # Set the if op so that the gradient code can use it.
-    func_graph._if = if_op
+    # Link the op so that the gradient code can use it.
+    func_graph._forward_cond = op
     return func_graph
 
-  return (_get_func_graph_for_branch("then_branch"),
-          _get_func_graph_for_branch("else_branch"))
+  if op.type == "If":
+    return (_get_func_graph_for_branch(op.get_attr("then_branch")),
+            _get_func_graph_for_branch(op.get_attr("else_branch")))
+  elif op.type == "Case":
+    return [_get_func_graph_for_branch(branch_fn)
+            for branch_fn in op.get_attr("branches")]
+  else:
+    raise ValueError("Unsupported op type: {}".format(op.type))
 
 
 def _grad_fn(func_graph, grads):
@@ -348,7 +358,7 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
     # to If op outputs. So we get the outer tensor corresponding to those
     # from the list of `external_captures`.
     try:
-      t = t.graph._if.outputs[t.graph.outputs.index(t)]
+      t = t.graph._forward_cond.outputs[t.graph.outputs.index(t)]
     except ValueError:
       index = t.graph.internal_captures.index(t)
       t = t.graph.external_captures[index]
@@ -373,211 +383,191 @@ def _get_intermediates(func_graph):
   return intermediates
 
 
-def _separate_unique_inputs(true_inputs, false_inputs):
-  """Separates tensors appearing only in true_inputs or false_inputs, or both.
-
-  Args:
-    true_inputs: list of Tensors
-    false_inputs: list of Tensors
-
-  Returns:
-    Three lists of Tensors:
-      1. The tensors that appear in both true_inputs and false_inputs
-      2. The tensors that only appear in true_inputs
-      3. The tensors that only appear in false_inputs
-  """
-  true_inputs = set(true_inputs)
-  false_inputs = set(false_inputs)
-
-  shared_inputs = true_inputs.intersection(false_inputs)
-  true_only_inputs = true_inputs - false_inputs
-  false_only_inputs = false_inputs - true_inputs
-
-  return list(shared_inputs), list(true_only_inputs), list(false_only_inputs)
-
-
-def _make_intermediates_match(true_graph, false_graph,
-                              true_optionals, false_optionals):
+def _make_intermediates_match(branch_graphs, branch_optionals):
   """Returns new optionals lists that have matching signatures.
 
   This is done by mirroring each list in the other using none optionals.
   There is no merging of like optionals.
 
   Args:
-    true_graph: FuncGraph
-    false_graph: FuncGraph
-    true_optionals: a list of optional Tensors from true_graph
-    false_optionals: a list of optional Tensors from false_graph
+    branch_graphs: `list` of `FuncGraph`.
+    branch_optionals: `list` of `list`s of optional `Tensor`s from other
+      branch_graphs
 
   Returns:
-    A new list of Tensors in true_graph and a new list of Tensors in
-    false_graph. The two lists have the same number of Tensors, all of which
-    will be optionals of the same shape/type.
+    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
+    same number of `Tensor`s, all of which will be optionals of the same
+    shape/type.
   """
-  new_true_optionals = (true_optionals +
-                        _create_none_optionals(true_graph, false_optionals))
-  new_false_optionals = (_create_none_optionals(false_graph, true_optionals)
-                         + false_optionals)
-  return new_true_optionals, new_false_optionals
+  new_branch_optionals = []
+  # Since the intermediates are optionals with dtype variant, we only need
+  # enough room for the longest list of intermediates.
+  intermediates_size = max(len(o) for o in branch_optionals)
+  for i, branch_graph in enumerate(branch_graphs):
+    other_optionals = _create_none_optionals(
+        branch_graph, intermediates_size - len(branch_optionals[i]))
+    new_branch_optionals.append(branch_optionals[i] + other_optionals)
+  return new_branch_optionals
 
 
-def _make_intermediates_match_xla(true_graph, false_graph, true_intermediates,
-                                  false_intermediates):
+def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
   """Like _make_intermediates_match but for the XLA case."""
-  new_true_intermediates = (true_intermediates +
-                            _create_fakeparams(true_graph, false_intermediates))
-  new_false_intermediates = (_create_fakeparams(false_graph, true_intermediates)
-                             + false_intermediates)
-  return new_true_intermediates, new_false_intermediates
+  new_branch_intermediates = []
+  for i, branch_graph in enumerate(branch_graphs):
+    other_fakeparams = _create_fakeparams(
+        branch_graph,
+        sum((bi for bi in branch_intermediates
+             if bi is not branch_intermediates[i]), []))
+    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
+    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
+                                    branch_intermediates[i] +
+                                    other_fakeparams[num_preceding:])
+  return new_branch_intermediates
 
 
-def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
-  """Modifies true_graph and false_graph so they have the same input signature.
+def _make_inputs_match(branch_graphs, branch_inputs):
+  """Modifies branch_graphs so they have the same input signature.
 
-  This method reorders and/or adds parameters to true_graph and false_graph so
+  This method reorders and/or adds parameters to each graph in branch_graphs so
   they have the same input signature, and updates the 'inputs' and 'captured'
-  fields of both graphs accordingly. It uses the input tensors from the outer
+  fields of each graph accordingly. It uses the input tensors from the outer
   graph to avoid duplicating shared arguments.
 
   Args:
-    true_graph: FuncGraph
-    false_graph: FuncGraph
-    true_inputs: a list of Tensors in the outer graph. The inputs for
-      true_graph.
-    false_inputs: a list of Tensors in the outer graph. The inputs for
-      false_graph.
+    branch_graphs: a `list` of `FuncGraph`
+    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
+      inputs for the corresponding graph in `branch_graphs`.
 
   Returns:
-    A new list of Tensors from the outer graph that are the new inputs for both
-    true_graph and false_graph. This is a deduped version of true_inputs +
-    false_inputs.
+    A new list of Tensors from the outer graph that are the new inputs for each
+    branch_graph. This is a deduped version of `sum(branch_inputs)`.
   """
-  shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
-      true_inputs, false_inputs)
+  assert len(branch_graphs) == len(branch_inputs)
+  new_inputs = set()
+  for branch_in in branch_inputs:
+    new_inputs |= set(branch_in)
+  new_inputs = list(new_inputs)
 
-  new_inputs = shared_inputs + true_only_inputs + false_only_inputs
-
-  true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
-  false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
-
-  true_graph.inputs = (
-      [true_input_to_param[t] for t in shared_inputs] +
-      [true_input_to_param[t] for t in true_only_inputs] +
-      _create_dummy_inputs(true_graph, false_only_inputs))
-
-  false_graph.inputs = (
-      [false_input_to_param[t] for t in shared_inputs] +
-      _create_dummy_inputs(false_graph, true_only_inputs) +
-      [false_input_to_param[t] for t in false_only_inputs])
+  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
+    branch_input_to_param = dict(zip(branch_in, branch_graph.inputs))
+    input_list = []
+    for in_t in new_inputs:
+      param = branch_input_to_param.get(in_t, None)
+      if param is None:
+        param = _create_dummy_input(branch_graph, in_t)
+      input_list.append(param)
+
+    branch_graph.inputs = input_list
 
     # Rewrite the FuncGraphs' state to reflect the new inputs.
-  true_graph.captures = collections.OrderedDict(zip(new_inputs,
-                                                    true_graph.inputs))
-  false_graph.captures = collections.OrderedDict(zip(new_inputs,
-                                                     false_graph.inputs))
+    branch_graph.captures = collections.OrderedDict(
+        zip(new_inputs, branch_graph.inputs))
 
   return new_inputs
 
 
-def _make_output_composite_tensors_match(true_graph, false_graph):
-  """Modifies true_graph and false_graph so they have the same output signature.
+def _make_output_composite_tensors_match(op_type, branch_graphs):
+  """Modifies each branch_graph's outputs to have the same output signature.
 
   Currently the only transformation implemented is turning a Tensor into an
   equivalent IndexedSlices if the other branch returns an IndexedSlices.
-  Updates {true,false}_graph.{outputs,structured_outputs}.
+  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
+  branch_graphs.
 
   Args:
-    true_graph: FuncGraph
-    false_graph: FuncGraph
+    op_type: _COND or _CASE
+    branch_graphs: `list` of `FuncGraph`
 
   Raises:
-    TypeError: if a pair of outputs cannot be rewritten.
+    TypeError: if a set of outputs cannot be rewritten.
   """
   # Note: since this is only used for gradient graphs, we do not expect the
   # outputs to be structured (e.g. nested lists), and thus do not need to use
   # nest.flatten, etc.
-  true_outputs = list(true_graph.structured_outputs)
-  false_outputs = list(false_graph.structured_outputs)
-  assert len(true_outputs) == len(false_outputs)
+  assert branch_graphs
+  branch_outputs = [g.structured_outputs for g in branch_graphs]
+  outputs_per_branch = list(len(outs) for outs in branch_outputs)
+  assert len(set(outputs_per_branch)) == 1, outputs_per_branch
 
-  for idx, (true_out, false_out) in enumerate(zip(true_outputs, false_outputs)):
-    if type(true_out) == type(false_out):  # pylint: disable=unidiomatic-typecheck
+  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
+    if len(set(type(out) for out in branch_outs)) == 1:
       continue
-    if (isinstance(true_out, ops.IndexedSlices) and
-        isinstance(false_out, ops.Tensor)):
-      with false_graph.as_default():
-        false_outputs[idx] = math_ops._as_indexed_slices(false_out)
-    elif (isinstance(true_out, ops.Tensor) and
-          isinstance(false_out, ops.IndexedSlices)):
-      with true_graph.as_default():
-        true_outputs[idx] = math_ops._as_indexed_slices(true_out)
-    else:
-      raise TypeError(
-          "Cannot reconcile tf.cond %i-th outputs:\n"
-          "  true_fn returned: %s\n"
-          "  false_fn returned: %s" % (idx, true_out, false_out))
+    if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
+      continue
+    for branch_idx, branch_out in enumerate(branch_outs):
+      if isinstance(branch_out, ops.IndexedSlices):
+        continue
+      elif isinstance(branch_out, ops.Tensor):
+        with branch_graphs[branch_idx].as_default():
+          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
+              branch_out)
+      else:
+        raise TypeError(
+            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
+            "  outputs from all branches: {outputs}".format(
+                op_name="tf.cond" if op_type == _COND else "tf.case",
+                output_idx=output_idx,
+                outputs=branch_outs))
 
-  true_graph.structured_outputs = true_outputs
-  true_graph.outputs = func_graph_module.flatten(true_outputs)
-  false_graph.structured_outputs = false_outputs
-  false_graph.outputs = func_graph_module.flatten(false_outputs)
+  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
+    branch_graph.structured_outputs = branch_outs
+    branch_graph.outputs = func_graph_module.flatten(branch_outs)
 
 
-def _make_indexed_slices_indices_types_match(true_graph, false_graph):
-  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
+def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
+  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
+  assert branch_graphs
   indexed_slice_indices = []
   current_index = 0
-  true_outputs_flat_with_composites = nest.flatten(
-      true_graph.structured_outputs, expand_composites=False)
-  false_outputs_flat_with_composites = nest.flatten(
-      false_graph.structured_outputs, expand_composites=False)
+  branch_outputs_flat_with_composites = [
+      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
+      for branch_graph in branch_graphs
+  ]
+  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
+  assert len(set(outs_per_branch)) == 1, outs_per_branch
   # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
-  for idx, (true_out, false_out) in enumerate(
-      zip(true_outputs_flat_with_composites,
-          false_outputs_flat_with_composites)):
-    if isinstance(true_out, ops.IndexedSlices) != isinstance(
-        false_out, ops.IndexedSlices):
-      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
-                      "  true_fn returned: %s\n"
-                      "  false_fn returned: %s" % (idx, true_out, false_out))
-    if isinstance(true_out, ops.IndexedSlices):
+  for output_idx, branch_outs in enumerate(
+      zip(*branch_outputs_flat_with_composites)):
+    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
+      raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
+                      "  branches returned: {outputs}".format(
+                          op_name="tf.cond" if op_type == _COND else "tf.case",
+                          output_idx=output_idx,
+                          outputs=branch_outs))
+    if isinstance(branch_outs[0], ops.IndexedSlices):
       # indices is the second component of the composite tensor.
       indexed_slice_indices.append(current_index + 1)
-    if nest.is_sequence_or_composite(true_out):
-      current_index += len(nest.flatten(true_out, expand_composites=True))
+    if nest.is_sequence_or_composite(branch_outs[0]):
+      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
     else:
       current_index += 1
 
   if not indexed_slice_indices:
     return
 
-  if current_index != len(true_graph.outputs):
-    raise ValueError("Insufficient elements in true_graph.outputs.\n"
+  if current_index != len(branch_graphs[0].outputs):
+    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                      "Expected: %i\n"
-                     "Actual: %i" % (current_index, len(true_graph.outputs)))
+                     "Actual: %i" %
+                     (current_index, len(branch_graphs[0].outputs)))
 
   # Cast indices with mismatching types to int64.
   for index in indexed_slice_indices:
-    if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
+    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
+           for bg in branch_graphs):
       raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
-                      "Found: %s" % str(true_graph.outputs[index].dtype))
-    if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
-      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
-                      "Found: %s" % str(false_graph.outputs[index].dtype))
-    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
-      if false_graph.outputs[index].dtype == dtypes.int32:
-        with false_graph.as_default():
-          false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index],
-                                                     dtypes.int64)
-      else:
-        with true_graph.as_default():
-          true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
-                                                    dtypes.int64)
+                      "Found: %s" %
+                      str([bg.outputs[index].dtype for bg in branch_graphs]))
+    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
+      for branch_graph in branch_graphs:
+        if branch_graph.outputs[index].dtype == dtypes.int32:
+          with branch_graph.as_default():
+            branch_graph.outputs[index] = math_ops.cast(
+                branch_graph.outputs[index], dtypes.int64)
 
-  true_graph.structured_outputs = func_graph_module.pack_sequence_as(
-      true_graph.structured_outputs, true_graph.outputs)
-  false_graph.structured_outputs = func_graph_module.pack_sequence_as(
-      false_graph.structured_outputs, false_graph.outputs)
+  for branch_graph in branch_graphs:
+    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
+        branch_graph.structured_outputs, branch_graph.outputs)
 
 
 def _wrap_intermediates(func_graph, intermediates):
@@ -585,33 +575,33 @@ def _wrap_intermediates(func_graph, intermediates):
   return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
 
 
-def _create_dummy_inputs(func_graph, template_tensors):
+def _create_dummy_input(func_graph, template_tensor):
   """Creates tensors in func_graph to represent template_tensors.
 
   Args:
     func_graph: FuncGraph.
-    template_tensors: a list of tensors in the outer graph.
+    template_tensor: a tensor in the outer graph.
 
   Returns:
-    A list of tensors in func_graph.
+    A tensor in func_graph.
   """
   with func_graph.as_default():
-    return [array_ops.placeholder(t.dtype, shape=t.shape)
-            for t in template_tensors]
+    return array_ops.placeholder(
+        template_tensor.dtype, shape=template_tensor.shape)
 
 
-def _create_none_optionals(func_graph, template_tensors):
-  """Creates none optionals in func_graph to represent template_tensors.
+def _create_none_optionals(func_graph, n):
+  """Creates `n` `None` optionals in func_graph.
 
   Args:
     func_graph: FuncGraph.
-    template_tensors: a list of tensors in func_graph.
+    n: `int` the number of `None` optionals to make.
 
   Returns:
     A list of tensors in func_graph.
   """
   with func_graph.as_default():
-    return [gen_dataset_ops.optional_none() for _ in template_tensors]
+    return [gen_dataset_ops.optional_none() for _ in range(n)]
 
 
 def _create_fakeparams(func_graph, template_tensors):
@ -621,55 +611,69 @@ def _create_fakeparams(func_graph, template_tensors):
|
|||||||
for t in template_tensors]
|
for t in template_tensors]
|
||||||
|
|
||||||
|
|
||||||
-def _check_same_outputs(true_graph, false_graph):
-  """Raises an error if true_graph and false_graph have different outputs."""
-
-  def error(error_detail):
-    raise TypeError(
-        "true_fn and false_fn arguments to tf.cond must have the same number, "
-        "type, and overall structure of return values.\n"
-        "\n"
-        "true_fn output: %s\n"
-        "false_fn output: %s\n"
-        "\n"
-        "Error details:\n"
-        "%s" % (true_graph.structured_outputs, false_graph.structured_outputs,
-                error_detail))
-
-  try:
-    nest.assert_same_structure(true_graph.structured_outputs,
-                               false_graph.structured_outputs,
-                               expand_composites=True)
-  except (ValueError, TypeError) as e:
-    error(str(e))
-
-  assert len(true_graph.outputs) == len(false_graph.outputs)
-  for true_out, false_out in zip(true_graph.outputs, false_graph.outputs):
-    if true_out.dtype != false_out.dtype:
-      error("%s and %s have different types" % (true_out, false_out))
+def _check_same_outputs(op_type, graphs):
+  """Raises an error if `graphs` have different outputs."""
+
+  def error(branch_idx, error_detail):
+    raise TypeError(
+        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
+        "number, type, and overall structure of return values.\n"
+        "\n"
+        "{b0_name} output: {b0_out}\n"
+        "{bn_name} output: {bn_out}\n"
+        "\n"
+        "Error details:\n"
+        "{detail}".format(
+            b0_name="true_fn" if op_type == _COND else "branches[0]",
+            bn_name=("false_fn" if op_type == _COND else
+                     "branches[{}]".format(branch_idx)),
+            op_name="tf.cond" if op_type == _COND else "tf.case",
+            b0_out=graphs[0].structured_outputs,
+            bn_out=graphs[branch_idx].structured_outputs,
+            detail=error_detail))
+
+  for b in range(1, len(graphs)):
+    try:
+      nest.assert_same_structure(
+          graphs[0].structured_outputs,
+          graphs[b].structured_outputs,
+          expand_composites=True)
+    except (ValueError, TypeError) as e:
+      error(b, str(e))
+
+    assert len(graphs[0].outputs) == len(graphs[b].outputs)
+    for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
+      if b0_out.dtype != bn_out.dtype:
+        error(b, "%s and %s have different types" % (b0_out, bn_out))


-def _get_output_shapes(true_graph_outputs, false_graph_outputs):
-  output_shapes = [
-      t_out.shape.most_specific_compatible_shape(f_out.shape)
-      for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
-  ]
+def _get_output_shapes(*branch_graph_outputs):
+  output_shapes = []
+  for out_by_branch in zip(*branch_graph_outputs):
+    shape = out_by_branch[0].shape
+    for other_out in out_by_branch[1:]:
+      shape = shape.most_specific_compatible_shape(other_out.shape)
+    output_shapes.append(shape)
  return output_shapes


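The generalized `_get_output_shapes` folds `most_specific_compatible_shape` across every branch's output instead of making one pairwise call. A small self-contained illustration of that fold on plain `TensorShape` objects (toy shapes, not taken from the change):

```python
import functools
import tensorflow as tf

def most_specific_common_shape(shapes):
  """Folds TensorShape.most_specific_compatible_shape over a list (sketch)."""
  return functools.reduce(
      lambda acc, s: acc.most_specific_compatible_shape(s), shapes)

branch_shapes = [
    tf.TensorShape([None, 3]),
    tf.TensorShape([5, 3]),
    tf.TensorShape([5, None]),
]
# Any dimension that differs (or is unknown) in some branch relaxes to None.
print(most_specific_common_shape(branch_shapes))  # (None, None)
```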
-def verify_captures(true_graph, false_graph):
-  """Verify that a true_fn tensor is not accessed in false_fn and vice-versa."""
-  for t in false_graph.external_captures:
-    if not isinstance(t, ops.EagerTensor) and t.graph is true_graph:
-      raise ValueError("Tensor {} in true_fn is accessed from false_fn.".format(
-          t.name))
-  # Note: This is technically not possible right now because `false_graph`
-  # is built "after" `true_graph` but we add this check for completeness and to
-  # guard against potential future changes.
-  for t in true_graph.external_captures:
-    if not isinstance(t, ops.EagerTensor) and t.graph is false_graph:
-      raise ValueError("Tensor {} in false_fn is accessed from true_fn.".format(
-          t.name))
+def verify_captures(op_type, branch_graphs):
+  """Verify that a branch's tensor is not accessed in another branch fn."""
+  # Note: It is technically not possible for lower-branch_index branches to
+  # capture tensors from higher-branch_index branches, because of the order of
+  # branch graph construction, but we check all for completeness and to
+  # guard against potential future changes.
+  other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
+  for i, branch_graph in enumerate(branch_graphs):
+    for t in branch_graph.external_captures:
+      if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
+        branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
+            "branch {}".format(bi) for bi in range(len(branch_graphs))]
+        raise ValueError(
+            "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
+                tname=t.name,
+                b0name=branch_names[other_branch_graphs[t.graph]],
+                b1name=branch_names[i]))


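The new check builds a reverse lookup from each branch graph to its index so a cross-branch capture can be reported with both branch names. The core of it is just a dict keyed by graph objects; a toy sketch with plain Python stand-ins (hypothetical names, not the TF `FuncGraph` machinery):

```python
class FakeGraph(object):
  """Stand-in for a FuncGraph; only object identity matters here."""

class FakeTensor(object):
  def __init__(self, name, graph):
    self.name, self.graph = name, graph

def check_no_cross_branch_captures(branch_graphs, external_captures):
  other = {g: i for i, g in enumerate(branch_graphs)}  # graph -> branch index
  for i, captures in enumerate(external_captures):
    for t in captures:
      # External captures come from outside the branch, so a hit here means
      # the tensor was created inside some *other* branch's graph.
      if t.graph in other:
        raise ValueError("Tensor {} in branch {} is accessed from branch {}"
                         .format(t.name, other[t.graph], i))

g0, g1 = FakeGraph(), FakeGraph()
bad = FakeTensor("t0", g0)                 # created inside branch 0
try:
  check_no_cross_branch_captures([g0, g1], [[], [bad]])  # branch 1 captures it
except ValueError as e:
  print(e)
```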
class _CondGradFuncGraph(util.CondBranchFuncGraph):
@@ -679,14 +683,14 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
  gradient computation in optionals.

  Attributes:
-    if_op_needs_rewrite: True if any intermediates were captured, meaning the
+    op_needs_rewrite: True if any intermediates were captured, meaning the
      forward If op needs to be written to output the wrapped intermediates.
  """

  def __init__(self, name, forward_graph):
    super(_CondGradFuncGraph, self).__init__(
        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
-    self.if_op_needs_rewrite = False
+    self.op_needs_rewrite = False
    self._forward_graph = forward_graph
    # Maps from forward intermediate tensor -> the unwrapped captured
    # intermediate.
@@ -719,7 +723,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
-        self.if_op_needs_rewrite = True
+        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
@@ -756,7 +760,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
      # 'tensor' hasn't been wrapped, do it now.
      with self._forward_graph.as_default():
        optional = gen_dataset_ops.optional_from_value([tensor])
-        self.if_op_needs_rewrite = True
+        self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
@@ -767,3 +771,182 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):

    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor


+def indexed_case(branch_index, branch_fns, name="indexed_case"):
+  """Like cond_v2, except emits a Case op instead of an If."""
+  if isinstance(branch_index, int):
+    raise TypeError("branch_index must not be a Python int", branch_index)
+
+  with ops.name_scope(name) as scope:
+    branch_names = [
+        util.unique_fn_name(scope, "branch{}".format(b))
+        for b in range(len(branch_fns))
+    ]
+
+    # Automatic control dependencies are added in defuns, but not in v1
+    # graphs. Propagate that behavior here.
+    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
+    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
+
+    branch_graphs = []
+    for branch_name, branch_fn in zip(branch_names, branch_fns):
+      branch_graphs.append(
+          func_graph_module.func_graph_from_py_func(
+              branch_name,
+              branch_fn,
+              [],
+              {},
+              func_graph=util.CondBranchFuncGraph(
+                  branch_name,
+                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
+              add_control_dependencies=add_control_dependencies,
+              op_return_value=branch_index))
+
+    verify_captures(_CASE, branch_graphs)
+    return _build_case(
+        branch_index,
+        branch_graphs, [g.external_captures for g in branch_graphs],
+        name=scope)
+
+
+@ops.RegisterGradient("Case")
+def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
+  """The gradient of a Case op produced (w/ branch_index) by tf.case."""
+  # Get the Case operator (this logic handles the case where op is a MockOp)
+  case_op = op.outputs[0].op
+  branch_graphs = _get_func_graphs(case_op)
+  assert branch_graphs
+  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
+  # of a nested cond.
+  for branch_graph in branch_graphs:
+    assert branch_graph.outer_graph == case_op.graph
+
+  # Create grad functions that compute the gradient of the branch forward
+  # graphs. These functions will capture tensors from the forward pass
+  # functions.
+  branch_grad_graphs = []
+  for branch_graph in branch_graphs:
+    branch_grad_graphs.append(
+        _create_grad_func(branch_graph, grads,
+                          util.unique_grad_fn_name(branch_graph.name)))
+
+  if any(g.op_needs_rewrite for g in branch_grad_graphs):
+    # Modify 'op' to output the intermediates needed by the grad functions. Note
+    # that all needed intermediates are wrapped in optionals. Each optional
+    # intermediate output will have a value iff its corresponding branch is
+    # taken.
+    # NOTE(bjp): if there are any active sessions, this modification to `op`
+    # may make them unrunnable!
+
+    if control_flow_util.InXlaContext(ops.get_default_graph()):
+      # XLA does not yet support optionals, so output intermediates directly and
+      # make them match via FakeParams, which can be converted to zeros in XLA.
+      # TODO(bjp,jpienaar): can XLA support optionals?
+      branches_intermediates = [
+          branch_grad_graph.xla_intermediates
+          for branch_grad_graph in branch_grad_graphs
+      ]
+      extra_branch_outputs = _make_intermediates_match_xla(
+          branch_graphs, branches_intermediates)
+    else:
+      branch_intermediates = [
+          g.wrapped_intermediates for g in branch_grad_graphs
+      ]
+      # Make outputs match by adding none optionals.
+      extra_branch_outputs = _make_intermediates_match(branch_graphs,
+                                                       branch_intermediates)
+
+    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
+      branch_graph.outputs.extend(extra_outputs)
+    _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
+    # TODO(bjp): indicate it's an internal bug if this fails.
+    _check_same_outputs(_CASE, branch_graphs)
+
+    for branch_graph in branch_graphs:
+      branch_graph.name += "_rewritten"
+
+    case_op._set_func_list_attr("branches", [
+        util.create_new_tf_function(branch_graph)
+        for branch_graph in branch_graphs
+    ])
+    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
+    case_op._set_shape_list_attr("output_shapes",
+                                 branch_graphs[0].output_shapes)
+    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
+                         [t.shape for t in extra_branch_outputs[0]])
+
+  # Resolve references to forward graph tensors in grad graphs and ensure
+  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
+  branches_grad_inputs = [
+      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
+      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
+  ]
+
+  # This modifies the graphs in branch_grad_graphs.
+  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)
+
+  outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
+                        branches_grad_inputs)
+
+  # The predicate has no gradient.
+  return [None] + outputs
+
+
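Because the commit registers a gradient for the Case op, gradients flow through `tf.switch_case` much like they do through `tf.cond`, with only the taken branch contributing. A hedged end-to-end sketch using the public API (assumes a TF 2.x eager environment where `tf.switch_case` is available; it mirrors the pattern in the tests below rather than exercising `_CaseGrad` directly):

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.switch_case(
      tf.constant(1),
      branch_fns={0: lambda: x + 1.0, 1: lambda: x * x})
# Only branch 1 was taken, so dy/dx == 2 * x == 6.0.
print(tape.gradient(y, x).numpy())
```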
+def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
+  """Creates a `Case` op from `branch_index`, branch graphs and inputs.
+
+  Note that this modifies `branch_graphs` to make the inputs match, and to
+  output all intermediate values so they're available for the gradient
+  computation.
+
+  `branch_graphs` need not have the same input types, but they must
+  have the same output types.
+
+  Args:
+    branch_index: integer Tensor
+    branch_graphs: List of FuncGraph
+    branch_inputs: List of lists of Tensors to be passed to corresponding
+      branch_graph as input.
+    name: the name for the Case op.
+
+  Returns:
+    A list of Tensors which are the outputs of the Case op. Does not include
+    added intermediate outputs.
+  """
+  _check_same_outputs(_CASE, branch_graphs)
+
+  # Add inputs to branch_graphs to make them match. Note that this modifies the
+  # graphs in `branch_graphs`.
+  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
+
+  # Create the Case op.
+  with ops.control_dependencies(
+      sum((list(bg.control_captures) for bg in branch_graphs), [])):
+    tensors = gen_functional_ops.case(
+        branch_index,
+        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+        [util.create_new_tf_function(g) for g in branch_graphs],
+        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+        name=name)
+
+  # TODO(b/110167197) this requires Case to have at least 1 output
+  case_op = tensors[0].op
+  # TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
+  # util.maybe_set_lowering_attr(case_op)
+  util.maybe_propagate_compile_time_consts_in_xla(case_op)
+
+  # Return identities for each output of the Case op, rather than the output of
+  # the Case op directly. This makes pruning work if the output of select_case()
+  # is fetched: the lowering pass converts the Case outputs into IdentityN
+  # outputs, which if fetched will cause all ops in the taken branch to be run
+  # (since it takes all merge ops as input). After lowering, each output
+  # identity op will end up with only the appropriate merge op as input.
+  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
+  # correct output structure
+  tensors = [array_ops.identity(t) for t in tensors]
+
+  # Prevent fetching since the variant outputs can't be fetched directly.
+  case_op.graph.prevent_fetching(case_op)
+  return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs,
+                                            tensors)
@@ -3931,6 +3931,101 @@ def _case_helper(cond_fn,
    return fn()


+def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
+                                               branch_index):
+  """Verifies input arguments for the case function.
+
+  Args:
+    branch_fns: Dict or list of pairs of an `int` and a callable which
+      returns a list of tensors.
+    default: Optional callable that returns a list of tensors.
+    branch_index: Optional int `Tensor`, which selects for the corresponding
+      pred_fn_pair.
+
+  Raises:
+    TypeError: If `branch_fns` is not a list/dictionary.
+    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
+      callables.
+    TypeError: If `fns[i]` is not callable for any i, or `default` is not
+      callable.
+
+  Returns:
+    branch_fns: validated list of callables for each branch (default last).
+  """
+  if not isinstance(branch_index, ops.Tensor):
+    raise TypeError("branch_index must a Tensor, got {}".format(
+        type(branch_index)))
+  if not branch_index.dtype.is_integer:
+    raise TypeError("branch_index must an integer Tensor, got {}".format(
+        branch_index.dtype))
+
+  if not branch_fns:
+    raise ValueError("Must provide at least one item in branch_fns")
+  if not isinstance(branch_fns, (list, _basetuple, dict)):
+    raise TypeError("branch_fns must be a list, tuple, or dict")
+
+  if isinstance(branch_fns, dict):
+    branch_fns = branch_fns.items()
+
+  if all(callable(fn) for fn in branch_fns):
+    branch_fns = list(enumerate(branch_fns))
+
+  for key_fn_pair in branch_fns:
+    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
+      raise TypeError("Each entry in branch_fns must be a 2-tuple")
+    key, branch_fn = key_fn_pair
+
+    if not isinstance(key, int):
+      raise TypeError("key must be a Python `int`, got {}".format(type(key)))
+
+    if not callable(branch_fn):
+      raise TypeError("fn for key {} must be callable.".format(key))
+
+  keys = [p[0] for p in branch_fns]
+  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
+    raise ValueError(
+        "branch indices (keys) must form contiguous range of [0 to {}) but "
+        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
+  actions = [p[1] for p in sorted(branch_fns)]
+  if default is not None:
+    actions.append(default)
+  return actions
+
+
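The verifier above normalizes every accepted `branch_fns` spelling into one ordered list of callables, with `default` appended last. A small sketch of just that canonicalization step on plain Python callables (simplified, no Tensor checks; names are illustrative):

```python
def canonicalize_branches(branch_fns, default=None):
  """Sketch of the dict / list-of-pairs / list-of-callables normalization."""
  if isinstance(branch_fns, dict):
    branch_fns = list(branch_fns.items())
  if all(callable(fn) for fn in branch_fns):
    branch_fns = list(enumerate(branch_fns))      # implicit keys 0..N-1
  keys = sorted(k for k, _ in branch_fns)
  assert keys == list(range(len(keys))), "keys must be contiguous from 0"
  actions = [fn for _, fn in sorted(branch_fns)]  # order branches by key
  if default is not None:
    actions.append(default)                       # default becomes the last entry
  return actions

# All three spellings normalize to the same ordered list:
fns = [lambda: 0, lambda: 1]
assert len(canonicalize_branches(fns)) == 2
assert len(canonicalize_branches({1: fns[1], 0: fns[0]})) == 2
assert len(canonicalize_branches(list(enumerate(fns)), default=lambda: -1)) == 3
```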
+def _indexed_case_helper(branch_fns, default, branch_index, name):
+  """Implementation of case that emits the n-way indexed Case op.
+
+  Args:
+    branch_fns: Dict or list of pairs of a boolean scalar tensor, and a
+      callable which returns a list of tensors.
+    default: Optional callable that returns a list of tensors.
+    branch_index: Optional int `Tensor`, which selects for the corresponding
+      pred_fn_pair.
+    name: A name for this operation (optional).
+
+  Returns:
+    The tensors returned by the pair whose key matched branch_index, or
+    those returned by `default` if none does.
+
+  Raises:
+    TypeError: If `branch_fns` is not a list/dictionary.
+    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
+      callables.
+    TypeError: If `fns[i]` is not callable for any i, or `default` is not
+      callable.
+  """
+  branch_fns = _indexed_case_verify_and_canonicalize_args(
+      branch_fns, default, branch_index)
+  with ops.name_scope(name, "case", [branch_index]):
+    if context.executing_eagerly():
+      branch_index = array_ops.where(
+          math_ops.less(branch_index, 0)
+          | math_ops.greater_equal(branch_index, len(branch_fns)),
+          len(branch_fns) - 1, branch_index)
+      return branch_fns[int(branch_index)]()
+    return cond_v2.indexed_case(branch_index, branch_fns)
+
+
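In eager mode the helper never builds a Case op: it clamps an out-of-range `branch_index` to the last entry (which is `default` when one was supplied) and simply calls that Python function. A tiny illustration of the clamping rule in plain Python (a hedged sketch of the behavior, not the TF implementation):

```python
def eager_pick(branch_fns, branch_index):
  """Out-of-range indices fall through to the last callable (the default)."""
  if branch_index < 0 or branch_index >= len(branch_fns):
    branch_index = len(branch_fns) - 1
  return branch_fns[branch_index]()

fns = [lambda: 0, lambda: 10, lambda: -1]   # last entry plays the "default" role
print(eager_pick(fns, 1))    # 10
print(eager_pick(fns, 7))    # -1  (clamped to the default)
```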
@tf_export("case")
|
@tf_export("case")
|
||||||
def case(pred_fn_pairs,
|
def case(pred_fn_pairs,
|
||||||
default=None,
|
default=None,
|
||||||
@ -3939,6 +4034,8 @@ def case(pred_fn_pairs,
|
|||||||
name="case"):
|
name="case"):
|
||||||
"""Create a case operation.
|
"""Create a case operation.
|
||||||
|
|
||||||
|
See also `tf.switch_case`.
|
||||||
|
|
||||||
The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
|
The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
|
||||||
Each pair contains a boolean scalar tensor and a python callable that
|
Each pair contains a boolean scalar tensor and a python callable that
|
||||||
creates the tensors to be returned if the boolean evaluates to True.
|
creates the tensors to be returned if the boolean evaluates to True.
|
||||||
@ -4037,6 +4134,82 @@ def case(pred_fn_pairs,
|
|||||||
strict=strict)
|
strict=strict)
|
||||||
|
|
||||||
|
|
||||||
+@tf_export("switch_case")
+def switch_case(branch_index,
+                branch_fns,
+                default=None,
+                name="switch_case"):
+  """Create a switch/case operation, i.e. an integer-indexed conditional.
+
+  See also `tf.case`.
+
+  This op can be substantially more efficient than `tf.case` when exactly one
+  branch will be selected. `tf.switch_case` is more like a C++ switch/case
+  statement than `tf.case`, which is more like an if/elif/elif/else chain.
+
+  The `branch_fns` parameter is either a dict from `int` to callables, or list
+  of (`int`, callable) pairs, or simply a list of callables (in which case the
+  index is implicitly the key). The `branch_index` `Tensor` is used to select an
+  element in `branch_fns` with matching `int` key, falling back to `default`
+  if none match, or `max(keys)` if no `default` is provided. The keys must form
+  a contiguous set from `0` to `len(branch_fns) - 1`.
+
+  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
+  callables must return the same (possibly nested) value structure of lists,
+  tuples, and/or named tuples.
+
+  **Example:**
+
+  Pseudocode:
+
+  ```c++
+  switch (branch_index) {  // c-style switch
+    case 0: return 17;
+    case 1: return 31;
+    default: return -1;
+  }
+  ```
+  or
+  ```python
+  branches = {0: lambda: 17, 1: lambda: 31}
+  branches.get(branch_index, lambda: -1)()
+  ```
+
+  Expressions:
+
+  ```python
+  def f1(): return tf.constant(17)
+  def f2(): return tf.constant(31)
+  def f3(): return tf.constant(-1)
+  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
+  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
+  ```
+
+  Args:
+    branch_index: An int Tensor specifying which of `branch_fns` should be
+      executed.
+    branch_fns: A `dict` mapping `int`s to callables, or a `list` of
+      (`int`, callable) pairs, or simply a list of callables (in which case the
+      index serves as the key). Each callable must return a matching structure
+      of tensors.
+    default: Optional callable that returns a structure of tensors.
+    name: A name for this operation (optional).
+
+  Returns:
+    The tensors returned by the callable identified by `branch_index`, or those
+    returned by `default` if no key matches and `default` was provided, or those
+    returned by the max-keyed `branch_fn` if no `default` is provided.
+
+  Raises:
+    TypeError: If `branch_fns` is not a list/dictionary.
+    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
+      callables.
+    TypeError: If `fns[i]` is not callable for any i, or `default` is not
+      callable.
+  """
+  return _indexed_case_helper(branch_fns, default, branch_index, name)


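Beyond the dict form shown in the docstring, `branch_fns` can also be a plain list of callables, in which case the list position is the key. A short usage sketch (TF 2.x eager assumed):

```python
import tensorflow as tf

branch_fns = [lambda: tf.constant(17),
              lambda: tf.constant(31),
              lambda: tf.constant(-1)]

for i in range(3):
  out = tf.switch_case(tf.constant(i), branch_fns)
  print(int(out))   # 17, 31, -1
```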
class XLAControlFlowContext(ControlFlowContext):
  """Base class for XLA and TPU control flow contexts."""

@@ -24,6 +24,8 @@ import numpy as np
from tensorflow.python import tf2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -417,6 +419,28 @@ class CondTest(test_util.TensorFlowTestCase):
    with self.assertRaises(TypeError):
      control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)

+  @test_util.enable_control_flow_v2
+  @test_util.run_in_graph_and_eager_modes
+  def testCond_gradient(self):
+    true_in, false_in = array_ops.constant(1.), array_ops.constant(5.)
+    with backprop.GradientTape(persistent=True) as tape:
+      tape.watch(true_in)
+      tape.watch(false_in)
+      cond_true = control_flow_ops.cond(
+          array_ops.constant(True), lambda: true_in**2., lambda: false_in**2.)
+      cond_false = control_flow_ops.cond(
+          array_ops.constant(False), lambda: true_in**2., lambda: false_in**2.)
+    grads_true = tape.gradient(
+        cond_true, [true_in, false_in], output_gradients=3.)
+    grads_false = tape.gradient(
+        cond_false, [true_in, false_in], output_gradients=3.)
+    self.assertEqual(3. * 2. * 1., self.evaluate(grads_true[0]))
+    self.assertEqual(None if context.executing_eagerly() else 0.,
+                     self.evaluate(grads_true[1]))
+    self.assertEqual(3. * 2. * 5., self.evaluate(grads_false[1]))
+    self.assertEqual(None if context.executing_eagerly() else 0.,
+                     self.evaluate(grads_false[0]))
+

class ContextTest(test_util.TensorFlowTestCase):

@@ -908,6 +932,179 @@ class DataTypesTest(test_util.TensorFlowTestCase):
    self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))


+@test_util.run_all_in_graph_and_eager_modes
+class IndexedCaseTest(test_util.TensorFlowTestCase):
+
+  def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self):
+    nbranches = 5
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10, name="br{}_out".format(bi))
+
+    branches = [(i, make_func(i)) for i in range(nbranches)]
+    for bi in range(nbranches):
+      branch_index = array_ops.placeholder_with_default(bi, [])
+      case_out = control_flow_ops.switch_case(branch_index, branches)
+      self.assertEqual(bi * 10, self.evaluate(case_out))
+
+  def testCase(self):
+    nbranches = 5
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = [(i, make_func(i)) for i in range(nbranches)]
+    for bi in 0, 2, 3:
+      branch_index = array_ops.placeholder_with_default(bi, [])
+      case_out = control_flow_ops.switch_case(branch_index, branches)
+      self.assertEqual(bi * 10., self.evaluate(case_out))
+
+  def testCase_withDefault(self):
+    nbranches = 5
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = [(i, make_func(i)) for i in range(nbranches)]
+    for bi in -1, 2, nbranches:
+      branch_index = array_ops.placeholder_with_default(bi, [])
+      case_out = control_flow_ops.switch_case(
+          branch_index, branches, default=make_func(6))
+      if bi < 0 or bi >= nbranches:
+        expected = 60.
+      else:
+        expected = bi * 10.
+      self.assertEqual(expected, self.evaluate(case_out))
+
+  def testCase_dictWithDefault(self):
+    nbranches = 5
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = {i: make_func(i) for i in range(nbranches)}
+    for bi in -1, 0, 3, nbranches:
+      branch_index = array_ops.placeholder_with_default(bi, [])
+      case_out = control_flow_ops.switch_case(
+          branch_index, branches, default=make_func(6))
+      if bi < 0 or bi >= nbranches:
+        expected = 60.
+      else:
+        expected = bi * 10.
+      self.assertEqual(expected, self.evaluate(case_out))
+
+  def testCase_gradient(self):
+    nbranches = 5
+    inputs = [
+        array_ops.constant(float(bi), name="br{}_in".format(bi))
+        for bi in range(nbranches)
+    ]
+
+    def make_func(bi):
+      return lambda: inputs[bi]**2.
+
+    branches = {bi: make_func(bi) for bi in range(nbranches)}
+
+    for bi in -1, 1, 4, nbranches:
+      branch_index = array_ops.placeholder_with_default(bi, [])
+      with backprop.GradientTape() as tape:
+        for x in inputs:
+          tape.watch(x)
+        case_out = control_flow_ops.switch_case(branch_index, branches)
+      out_grad = 3.
+      actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
+      expected_grads = [None if context.executing_eagerly() else 0.] * nbranches
+      used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi
+      expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx
+      self.assertEqual(len(expected_grads), len(actual_grads))
+      for expected, actual in zip(expected_grads, actual_grads):
+        self.assertEqual(expected, self.evaluate(actual))
+
+  def testCase_gradient_diffShapedIntermediates(self):
+    nbranches = 5
+    inputs = [
+        array_ops.constant(
+            float(bi), shape=[bi + 1], name="br{}_in".format(bi))
+        for bi in range(nbranches)
+    ]
+
+    def make_func(bi):
+
+      def f():
+        x = inputs[bi]**2 * inputs[bi][:bi + 1, None]
+        return math_ops.reduce_sum(x)
+
+      return f
+
+    branches = {bi: make_func(bi) for bi in range(nbranches)}
+
+    for bi in -1, 2, nbranches:
+      branch_index = array_ops.placeholder_with_default(bi, [])
+      with backprop.GradientTape() as tape:
+        for x in inputs:
+          tape.watch(x)
+        case_out = control_flow_ops.switch_case(branch_index, branches)
+      out_grad = 3.
+      actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad)
+      used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi
+      expected_grads = []
+      for input_idx in range(nbranches):
+        if used_bi == input_idx:
+          with backprop.GradientTape() as tape:
+            tape.watch(inputs[used_bi])
+            y = make_func(used_bi)()
+          expected_grads.append(
+              self.evaluate(
+                  tape.gradient(y, inputs[used_bi], output_gradients=out_grad)))
+        else:
+          expected_grads.append(None if context.executing_eagerly() else [0.] *
+                                (input_idx + 1))
+
+      self.assertEqual(len(expected_grads), len(actual_grads))
+      for expected, actual in zip(expected_grads, actual_grads):
+        if expected is None:
+          self.assertIsNone(actual)
+        else:
+          self.assertAllEqual(expected, self.evaluate(actual))
+
+  def testCase_validateIndicesContiguous(self):
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = {i: make_func(i) for i in range(0, 6, 2)}
+    with self.assertRaisesRegexp(ValueError, "must form contiguous"):
+      control_flow_ops.switch_case(array_ops.constant(0), branches)
+
+  def testCase_validateIndicesDup(self):
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = [(i, make_func(i)) for i in range(0, 6, 2)]
+    branches.append((0, make_func(7)))
+    with self.assertRaisesRegexp(ValueError, "must form contiguous"):
+      control_flow_ops.switch_case(array_ops.constant(0), branches)
+
+  def testCase_validateBranchIndex(self):
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = {i: make_func(i) for i in range(5)}
+    with self.assertRaisesRegexp(TypeError, "branch_index.*Tensor"):
+      control_flow_ops.switch_case(1, branches)
+
+  def testCase_validateNonIntKeys(self):
+
+    def make_func(bi):
+      return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi))
+
+    branches = {array_ops.constant(i): make_func(i) for i in range(5)}
+    with self.assertRaisesRegexp(TypeError, "must be a Python `int`"):
+      control_flow_ops.switch_case(array_ops.constant(1), branches)
+
+
class CaseTest(test_util.TensorFlowTestCase):

  @test_util.run_deprecated_v1
@@ -2256,6 +2256,10 @@ tf_module {
    name: "svd"
    argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
  }
+  member_method {
+    name: "switch_case"
+    argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'switch_case\'], "
+  }
  member_method {
    name: "tables_initializer"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
@@ -988,6 +988,10 @@ tf_module {
    name: "subtract"
    argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
+  member_method {
+    name: "switch_case"
+    argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'switch_case\'], "
+  }
  member_method {
    name: "tan"
    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "