diff --git a/tensorflow/compiler/tests/cond_test.py b/tensorflow/compiler/tests/cond_test.py
index 53e02212058..5963020bbb7 100644
--- a/tensorflow/compiler/tests/cond_test.py
+++ b/tensorflow/compiler/tests/cond_test.py
@@ -19,12 +19,16 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compiler.xla import xla
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.platform import test
 
@@ -33,6 +37,7 @@ from tensorflow.python.platform import test
 class CondTest(xla_test.XLATestCase):
 
   def testCondAndTensorArrayInDefun(self):
+    # TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988
     with self.session(), self.test_scope():
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
@@ -47,7 +52,7 @@ class CondTest(xla_test.XLATestCase):
         return output.stack()
 
       output_t = f()
-      self.assertAllEqual(self.evaluate(output_t), [5.])
+      self.assertAllEqual([5.], self.evaluate(output_t))
 
       xla_context.Exit()
 
@@ -71,11 +76,178 @@ class CondTest(xla_test.XLATestCase):
       output = control_flow_ops.cond(
           constant_op.constant(True), if_true, if_false)
 
-      self.assertAllEqual(
-          sess.run(output, feed_dict={
-              x: [0., 1., 2.],
-              p: 1
-          }), 1.)
+      self.assertAllEqual(1.,
+                          sess.run(output, feed_dict={
+                              x: [0., 1., 2.],
+                              p: 1
+                          }))
+
+      xla_context.Exit()
+
+  def testCondConstPropagation_xlaCompile(self):
+    self.skipTest("b/132430685")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3])
+      p = constant_op.constant(1)
+
+      def f():
+        # TODO(b/129021699): Wrapping this in a tf.function does not work.
+        def if_true():
+          # This emits a StridedSlice op which expects the index to be a
+          # compile-time const.
+          return x[p]
+
+        def if_false():
+          return 5.
+
+        return control_flow_ops.cond(
+            constant_op.constant(True), if_true, if_false)
+
+      output = xla.compile(f)
+
+      self.assertAllEqual(1., self.evaluate(output))
+
+      xla_context.Exit()
+
+  def testCondConstPropagation_errorMsg(self):
+    self.skipTest("b/132430685")
+    with self.session() as sess, self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder(dtypes.float32)
+      p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
+
+      # TODO(b/129021699): Wrapping this in a tf.function does not work.
+      def if_true():
+        # This emits a StridedSlice op which expects the index to be a
+        # compile-time const.
+        return x[:p]
+
+      def if_false():
+        return array_ops.fill([p], 5.)
+
+      output = control_flow_ops.cond(
+          constant_op.constant(True), if_true, if_false)
+
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "must be a compile-time constant"):
+        sess.run(
+            output, feed_dict={
+                x: [0., 1., 2.],
+            })
+
+      xla_context.Exit()
+
+  def testCondConstPropagation_errorMsg_xlaCompile(self):
+    with self.session() as sess, self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder(dtypes.float32)
+      p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)
+      condition = math_ops.cast(
+          random_ops.random_uniform([], minval=0, maxval=2, dtype=dtypes.int32),
+          dtypes.bool)
+
+      def f():
+        # TODO(b/129021699): Wrapping this in a tf.function does not work.
+        def if_true():
+          # This emits a StridedSlice op which expects the index to be a
+          # compile-time const.
+          return x[:p]
+
+        def if_false():
+          return array_ops.fill([p], 5.)
+
+        return control_flow_ops.cond(condition, if_true, if_false)
+
+      output = xla.compile(f)
+
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "must be a compile-time constant"):
+        sess.run(
+            output, feed_dict={
+                x: [0., 1., 2.],
+            })
+
+      xla_context.Exit()
+
+  def testSwitchCaseAndTensorArrayInDefun(self):
+    self.skipTest("b/127846988")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      @function.defun
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.switch_case(
+            constant_op.constant(1), {
+                0: lambda: ta.write(0, 5.),
+                1: lambda: ta.write(0, 10.),
+                2: lambda: ta.write(0, 15.),
+            })
+
+        return output.stack()
+
+      output_t = f()
+      self.assertAllEqual([10.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
+  def testSwitchCaseConstPropagation(self):
+    self.skipTest("b/127846988")
+    with self.session() as sess, self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      x = array_ops.placeholder(dtypes.float32)
+      p = array_ops.placeholder(dtypes.int32)
+
+      def branch0():
+        return 5.
+
+      def branch1():
+        return 15.
+
+      # TODO(b/129021699): Wrapping this in a tf.function does not work.
+      def branch2():
+        # This emits a StridedSlice op which expects the index to be a
+        # compile-time const.
+        return x[p]
+
+      output = control_flow_ops.switch_case(
+          constant_op.constant(2), {
+              0: branch0,
+              1: branch1,
+              2: branch2,
+          })
+
+      self.assertAllEqual(7.,
+                          sess.run(output, feed_dict={
+                              x: [0., 1., 7.],
+                              p: 2,
+                          }))
+
+      xla_context.Exit()
+
+  def testCondNoInputs(self):
+    """Verifies against `Failed precondition: Expected one input shape`."""
+
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      for pred in True, False:
+        cond_out = control_flow_ops.cond(
+            array_ops.placeholder_with_default(pred, []),
+            lambda: constant_op.constant(2.),
+            lambda: constant_op.constant(1.))
+        self.assertEqual(int(pred) + 1., self.evaluate(cond_out))
 
       xla_context.Exit()
 
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 5b5a7e0863d..1c94f38e06d 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -125,6 +125,8 @@ Status BackwardsConstAnalysis(const Graph& g,
   return status;
 }
 
+namespace {
+
 Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node,
                        StringPiece func_attr_name, const FunctionBody** fbody) {
   NameAttrList name_attr_list;
@@ -136,6 +138,50 @@ Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, const Node* node,
   return Status::OK();
 }
 
+Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, const Node* node,
+                         StringPiece func_list_attr_name,
+                         std::vector<const FunctionBody*>* fbodies) {
+  std::vector<NameAttrList> name_attr_lists;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(node->def(), func_list_attr_name, &name_attr_lists));
+  for (const NameAttrList& name_attr_list : name_attr_lists) {
+    FunctionLibraryRuntime::Handle func_handle;
+    TF_RETURN_IF_ERROR(flib_runtime->Instantiate(
+        name_attr_list.name(), AttrSlice(&name_attr_list.attr()),
+        &func_handle));
+    fbodies->push_back(flib_runtime->GetFunctionBody(func_handle));
+  }
+  return Status::OK();
+}
+
+Status CondConstInputIndices(
+    absl::Span<const FunctionBody* const> branch_bodies,
+    std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime) {
+  TF_RET_CHECK(!branch_bodies.empty());
+  TF_RET_CHECK(branch_bodies[0] != nullptr);
+  int num_inputs = branch_bodies[0]->fdef.signature().input_arg_size();
+  // Stores indices of the "branch function" inputs that are expected to be
+  // compile time constants.
+  std::vector<bool> compile_time_const_arg_indices(num_inputs);
+  for (auto fbody : branch_bodies) {
+    TF_RET_CHECK(fbody != nullptr);
+    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+        *(fbody->graph), &compile_time_const_arg_indices,
+        /*compile_time_const_nodes=*/nullptr, flib_runtime));
+  }
+  for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
+    if (compile_time_const_arg_indices[i]) {
+      // The 0th input is the pred or branch index, which is not passed to the
+      // branches. So the i'th input of a branch function corresponds to the
+      // i + 1'th input of the If/Case op.
+      const_input_idxs->push_back(i + 1);
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
 Status GetCompileTimeConstInputs(const Node* node,
                                  std::vector<int>* const_input_idxs,
                                  FunctionLibraryRuntime* flib_runtime) {
@@ -179,6 +225,7 @@ Status GetCompileTimeConstInputs(const Node* node,
         }
       }
     }
+    return Status::OK();
   } else if (node->type_string() == "If") {
     const FunctionBody* fthen = nullptr;
     const FunctionBody* felse = nullptr;
@@ -186,31 +233,17 @@ Status GetCompileTimeConstInputs(const Node* node,
         GetFunctionBody(flib_runtime, node, "then_branch", &fthen));
     TF_RETURN_IF_ERROR(
         GetFunctionBody(flib_runtime, node, "else_branch", &felse));
-    TF_RET_CHECK(fthen);
-    TF_RET_CHECK(felse);
-    int num_inputs = fthen->fdef.signature().input_arg_size();
-    // Stores indices of the "branch function" inputs that are expected to be
-    // compile time constants.
-    std::vector<bool> compile_time_const_arg_indices(num_inputs);
-    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-        *(fthen->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime));
-    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-        *(felse->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime));
-    for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
-      if (compile_time_const_arg_indices[i]) {
-        // The 0th input is the loop condition which is not passed to the
-        // branches. So the i'th input of a branch function corresponds to
-        // i + 1'th input of the If op.
-        const_input_idxs->push_back(i + 1);
-      }
-    }
+    return CondConstInputIndices({fthen, felse}, const_input_idxs,
+                                 flib_runtime);
+  } else if (node->type_string() == "Case") {
+    std::vector<const FunctionBody*> branch_bodies;
+    TF_RETURN_IF_ERROR(
+        GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
+    return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
   } else {
     return XlaOpRegistry::CompileTimeConstantInputs(node->def(),
                                                     node->op_def(),
                                                     const_input_idxs);
   }
-  return Status::OK();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index fcc1ea2575b..d6dfa39e658 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -311,6 +311,7 @@ tf_kernel_library(
     srcs = ["case_op.cc"],
     hdrs = ["case_op.h"],
     deps = [
+        ":if_while_utils",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc
index 24623768f38..5ba844e10bd 100644
--- a/tensorflow/compiler/tf2xla/kernels/case_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/case_op.h" +#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -34,10 +35,41 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } else { has_token_input_output_ = !token_input_nodes_.empty(); } + if (ctx->HasAttr(kPropagateCompileTimeConsts)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts, + &propagate_compile_time_consts_)); + } } +namespace { + +Status ConvertCompileTimeConstArgumentsToConst( + XlaOpKernelContext* ctx, std::vector* args) { + for (int i = 0; i < args->size(); i++) { + XlaCompiler::Argument& arg = (*args)[i]; + const XlaExpression& expression = ctx->InputExpression(i + 1); + // If the input tensor is a compile time constant build a kConstant type + // argument. + if (arg.kind == XlaCompiler::Argument::kParameter) { + // NOTE: We can not simply check that this is Kind::kConstant because + // this could be the output of a MetadataOnly op e.g. Size. + xla::StatusOr> maybe_constant = + expression.ResolveConstant(ctx->compiler()->client()); + if (maybe_constant.ok() && maybe_constant.ValueOrDie().has_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = expression.dtype(); + arg.constant_value = std::move(maybe_constant.ValueOrDie().value()); + arg.shape = expression.GetShape().ValueOrDie(); + } + } + } + return Status::OK(); +} + +} // namespace + // TODO(b/35949885): There is duplication here with the handling of the -// while_op. Refactor the common code out/rework. +// while_op/if_op. Refactor the common code out/rework. void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { xla::XlaBuilder* b = ctx->builder(); int num_branches = branches_.size(); @@ -84,12 +116,30 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; - arg.shape = ctx->InputShape(i + 1); + // Use the xla::Shape for the input instead of ctx->InputShape. This is + // necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists. + auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1)); + OP_REQUIRES_OK(ctx, shape_or.status()); + arg.shape = shape_or.ValueOrDie(); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.HumanString(); } } + if (propagate_compile_time_consts_) { + // Replaces `kParameter` type args in `arguments` with `kConstant` if + // the op input corresponding to that arg is a compile-time const. This + // is necessary to propagate compile time consts to ops in the branch + // functions. + // Note: Propagating "all" compile-time constants may not be necessary. We + // should ideally only propagate consts which are required to be compile + // time constants in the branch functions. But that would require calling + // BackwardsConstAnalysis here which would be expensive. However, if we + // start hitting memory issues we should revisit this. + OP_REQUIRES_OK(ctx, + ConvertCompileTimeConstArgumentsToConst(ctx, &arguments)); + } + // Compile each branch of the conditional. 
   XlaCompiler::CompileOptions options;
   options.use_tuple_arg = true;
@@ -158,8 +208,6 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
     }
     OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
                 errors::FailedPrecondition("Expected tuple shape"));
-    OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
-                errors::FailedPrecondition("Expected one input shape"));
     OP_REQUIRES(
         ctx,
         xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
@@ -227,7 +275,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
     } else {
-      inputs[i] = ctx->Input(i + 1);
+      inputs[i] = ctx->Input(input_num);
     }
   }
   auto input_tuple = xla::Tuple(b, inputs);
@@ -292,6 +340,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
   VLOG(1) << "Done building Case";
 }
 
-REGISTER_XLA_OP(Name("Case").AllowResourceTypes(), XlaCaseOp);
+REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
+                XlaCaseOp);
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h
index ea14b18149c..4a61707864e 100644
--- a/tensorflow/compiler/tf2xla/kernels/case_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/case_op.h
@@ -55,6 +55,10 @@ class XlaCaseOp : public XlaOpKernel {
   DataTypeVector output_types_;
   bool has_token_input_output_;
   std::vector<string> token_input_nodes_;
+  // Whether to propagate compile time consts into the cond branches.
+  // This is not supported by default now since it may cause HBM memory
+  // overheads.
+  bool propagate_compile_time_consts_ = false;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 7548442d1ad..b8eda1de94a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -776,7 +776,7 @@ Status XlaCompiler::BuildArguments(
     }
   }
 
-  if (input_to_args->empty()) {
+  if (input_to_args->empty() && !use_tuple_arg) {
     return Status::OK();
   }
 
@@ -829,8 +829,9 @@ Status XlaCompiler::BuildArguments(
           xla::ShapeUtil::GetLeafCount(arg_shapes[i]),
           args[input_to_args->at(i)].is_same_data_across_replicas);
     }
-    xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
-                                                           tuple_sharding);
+    xla::XlaScopedShardingAssignment assign_tuple_sharding(
+        builder, input_to_args->empty() ? absl::optional<xla::OpSharding>()
+                                        : tuple_sharding);
     tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple",
                            is_same_across_replicas);
   } else {
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 40de1513fa0..85e08249057 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3747,6 +3747,7 @@ cuda_py_test(
         ":while_v2",
        "//tensorflow/python/eager:def_function",
     ],
+    shard_count = 2,
     xla_enable_strict_auto_jit = True,
 )
 
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 62eb993fe28..30c820a6020 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2593,8 +2593,15 @@ class Operation(object):
     func = attr_value_pb2.NameAttrList(name=func_name)
     self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
 
+  def _set_func_list_attr(self, attr_name, func_names):
+    """Private method used to set a list(function) attribute in the node_def."""
+    funcs = [attr_value_pb2.NameAttrList(name=func_name)
+             for func_name in func_names]
+    funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs)
+    self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list))
+
   def _set_type_list_attr(self, attr_name, types):
-    """Private method used to set a function attribute in the node_def."""
+    """Private method used to set a list(type) attribute in the node_def."""
     if not types:
       return
     if isinstance(types[0], dtypes.DType):
@@ -2603,7 +2610,7 @@ class Operation(object):
     self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
 
   def _set_shape_list_attr(self, attr_name, shapes):
-    """Private method used to set a function attribute in the node_def."""
+    """Private method used to set a list(shape) attribute in the node_def."""
     shapes = [s.as_proto() for s in shapes]
     shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
     self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index feb10431d40..ae1181803b7 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -776,6 +776,33 @@ class ControlFlowTest(test.TestCase):
         "Tensor true_branch:0 in true_fn is accessed from false_fn."):
       f()
 
+  def testSwitchCaseAccessBranch1TensorInBranch4Raises(self):
+
+    @def_function.function
+    def f():
+      c = constant_op.constant(1.)
+      inputs = {"c": c}
+
+      def br1_fn(inputs):
+        inputs["c"] = array_ops.identity(inputs["c"], name="br1_identity")
+        return inputs["c"]
+
+      def br4_fn(inputs):
+        return array_ops.identity(inputs["c"])
+
+      def other_fn():
+        return array_ops.identity(c)
+
+      return control_flow_ops.switch_case(
+          constant_op.constant(2),
+          [other_fn, lambda: br1_fn(inputs), other_fn, other_fn,
+           lambda: br4_fn(inputs)])
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        "Tensor br1_identity:0 in branch 1 is accessed from branch 4."):
+      f()
+
   def testCondListOutput(self):
     with self.cached_session() as sess:
       x = constant_op.constant(10)
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 5d661397b3d..ed5869f7d34 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -47,6 +47,9 @@ from tensorflow.python.util import nest
 # readability.
 # pylint: disable=protected-access
 
+_COND = 1
+_CASE = 2
+
 
 def cond_v2(pred, true_fn, false_fn, name="cond"):
   """Like tf.cond, except emits a single If op."""
@@ -80,7 +83,7 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
         add_control_dependencies=add_control_dependencies,
         op_return_value=pred)
 
-    verify_captures(true_graph, false_graph)
+    verify_captures(_COND, [true_graph, false_graph])
     return _build_cond(pred, true_graph, false_graph,
                        true_graph.external_captures,
                        false_graph.external_captures,
@@ -106,8 +109,7 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_graph = _create_grad_func(
       false_graph, grads, util.unique_grad_fn_name(false_graph.name))
 
-  if (true_grad_graph.if_op_needs_rewrite or
-      false_grad_graph.if_op_needs_rewrite):
+  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
     # Modify 'op' to output the intermediates needed by the grad functions. Note
     # that all needed intermediates are wrapped in optionals. Each optional
     # intermediate output will have a value iff its corresponding branch is
@@ -122,18 +124,18 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
       true_intermediates = true_grad_graph.xla_intermediates
       false_intermediates = false_grad_graph.xla_intermediates
       extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
-          true_graph, false_graph, true_intermediates, false_intermediates)
+          [true_graph, false_graph], [true_intermediates, false_intermediates])
     else:
       true_intermediates = true_grad_graph.wrapped_intermediates
       false_intermediates = false_grad_graph.wrapped_intermediates
       # Make outputs match by adding none optionals.
       extra_true_outputs, extra_false_outputs = _make_intermediates_match(
-          true_graph, false_graph, true_intermediates, false_intermediates)
+          [true_graph, false_graph], [true_intermediates, false_intermediates])
 
     true_graph.outputs.extend(extra_true_outputs)
     false_graph.outputs.extend(extra_false_outputs)
     # TODO(skyewm): indicate it's an internal bug if this fails.
-    _check_same_outputs(true_graph, false_graph)
+    _check_same_outputs(_COND, [true_graph, false_graph])
 
     true_graph.name += "_rewritten"
     false_graph.name += "_rewritten"
@@ -153,7 +155,8 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
 
   # This modifies true_grad_graph and false_grad_graph.
-  _make_output_composite_tensors_match(true_grad_graph, false_grad_graph)
+  _make_output_composite_tensors_match(_COND,
+                                       [true_grad_graph, false_grad_graph])
 
   outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                         true_grad_inputs, false_grad_inputs)
@@ -185,13 +188,13 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
     A list of Tensors which are the outputs of the If op. Does not include
     added intermediate outputs.
   """
-  _make_indexed_slices_indices_types_match(true_graph, false_graph)
-  _check_same_outputs(true_graph, false_graph)
+  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
+  _check_same_outputs(_COND, [true_graph, false_graph])
 
   # Add inputs to true_graph and false_graph to make them match. Note that
   # this modifies true_graph and false_graph.
-  cond_inputs = _make_inputs_match(true_graph, false_graph,
-                                   true_inputs, false_inputs)
+  cond_inputs = _make_inputs_match([true_graph, false_graph],
+                                   [true_inputs, false_inputs])
 
   # Create the If op.
   with ops.control_dependencies(
@@ -226,38 +229,45 @@ def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
       tensors)
 
 
-def _get_func_graphs(if_op):
+def _get_func_graphs(op):
   """Returns `FuncGraph`s for the input op branches.
 
   Args:
-    if_op: The _If Operation.
+    op: The If or Case Operation.
 
   Returns:
-    A 2-tuple of the `FuncGraph`s of the then_branch and else_branch.
+    A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
+    for Case).
   """
-  def _get_func_graph_for_branch(branch_name):
+
+  def _get_func_graph_for_branch(name_attr_list):
     """Generates and returns a FuncGraph for the given branch."""
-    inputs = if_op.inputs[1:]  # First input is pred.
+    inputs = op.inputs[1:]  # First input is pred.
     input_shapes = [t.shape for t in inputs]
-    func_name = if_op.get_attr(branch_name).name
-    fdef = if_op.graph._get_function(func_name).definition
-    # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
+    fdef = op.graph._get_function(name_attr_list.name).definition
+    # `op.graph` may not be the same as `ops.get_default_graph()` e.g.
     # in the case of nested if ops or when the gradient is being computed
-    # from inside a Defun. We build the `func_graph` with `if_op.graph` as its
+    # from inside a Defun. We build the `func_graph` with `op.graph` as its
     # `outer_graph`. This resembles how the `FuncGraph` was built in the
     # forward pass. We need this so that we can resolve references to tensors
     # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
-    with if_op.graph.as_default():
+    with op.graph.as_default():
      func_graph = function_def_to_graph.function_def_to_graph(
          fdef, input_shapes)
     func_graph.captures = collections.OrderedDict(zip(inputs,
                                                       func_graph.inputs))
-    # Set the if op so that the gradient code can use it.
-    func_graph._if = if_op
+    # Link the op so that the gradient code can use it.
+    func_graph._forward_cond = op
     return func_graph
 
-  return (_get_func_graph_for_branch("then_branch"),
-          _get_func_graph_for_branch("else_branch"))
+  if op.type == "If":
+    return (_get_func_graph_for_branch(op.get_attr("then_branch")),
+            _get_func_graph_for_branch(op.get_attr("else_branch")))
+  elif op.type == "Case":
+    return [_get_func_graph_for_branch(branch_fn)
+            for branch_fn in op.get_attr("branches")]
+  else:
+    raise ValueError("Unsupported op type: {}".format(op.type))
 
 
 def _grad_fn(func_graph, grads):
@@ -348,7 +358,7 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
       # to If op outputs. So we get the outer tensor corresponding to those
      # from the list of `external_captures`.
       try:
-        t = t.graph._if.outputs[t.graph.outputs.index(t)]
+        t = t.graph._forward_cond.outputs[t.graph.outputs.index(t)]
       except ValueError:
         index = t.graph.internal_captures.index(t)
         t = t.graph.external_captures[index]
@@ -373,211 +383,191 @@ def _get_intermediates(func_graph):
   return intermediates
 
 
-def _separate_unique_inputs(true_inputs, false_inputs):
-  """Separates tensors appearing only in true_inputs or false_inputs, or both.
-
-  Args:
-    true_inputs: list of Tensors
-    false_inputs: list of Tensors
-
-  Returns:
-    Three lists of Tensors:
-      1. The tensors that appear in both true_inputs and false_inputs
-      2. The tensors that only appear in true_inputs
-      3. The tensors that only appear in false_inputs
-  """
-  true_inputs = set(true_inputs)
-  false_inputs = set(false_inputs)
-
-  shared_inputs = true_inputs.intersection(false_inputs)
-  true_only_inputs = true_inputs - false_inputs
-  false_only_inputs = false_inputs - true_inputs
-
-  return list(shared_inputs), list(true_only_inputs), list(false_only_inputs)
-
-
-def _make_intermediates_match(true_graph, false_graph,
-                              true_optionals, false_optionals):
+def _make_intermediates_match(branch_graphs, branch_optionals):
   """Returns new optionals lists that have matching signatures.
 
   This is done by mirroring each list in the other using none optionals.
   There is no merging of like optionals.
 
   Args:
-    true_graph: FuncGraph
-    false_graph: FuncGraph
-    true_optionals: a list of optional Tensors from true_graph
-    false_optionals: a list of optional Tensors from false_graph
+    branch_graphs: `list` of `FuncGraph`.
+    branch_optionals: `list` of `list`s of optional `Tensor`s from other
+      branch_graphs
 
   Returns:
-    A new list of Tensors in true_graph and a new list of Tensors in
-    false_graph. The two lists have the same number of Tensors, all of which
-    will be optionals of the same shape/type.
+    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
+    same number of `Tensor`s, all of which will be optionals of the same
+    shape/type.
   """
-  new_true_optionals = (true_optionals +
-                        _create_none_optionals(true_graph, false_optionals))
-  new_false_optionals = (_create_none_optionals(false_graph, true_optionals)
-                         + false_optionals)
-  return new_true_optionals, new_false_optionals
+  new_branch_optionals = []
+  # Since the intermediates are optionals with dtype variant, we only need
+  # enough room for the longest list of intermediates.
+  intermediates_size = max(len(o) for o in branch_optionals)
+  for i, branch_graph in enumerate(branch_graphs):
+    other_optionals = _create_none_optionals(
+        branch_graph, intermediates_size - len(branch_optionals[i]))
+    new_branch_optionals.append(branch_optionals[i] + other_optionals)
+  return new_branch_optionals
 
 
-def _make_intermediates_match_xla(true_graph, false_graph, true_intermediates,
-                                  false_intermediates):
+def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
   """Like _make_intermediates_match but for the XLA case."""
-  new_true_intermediates = (true_intermediates +
-                            _create_fakeparams(true_graph, false_intermediates))
-  new_false_intermediates = (_create_fakeparams(false_graph, true_intermediates)
-                             + false_intermediates)
-  return new_true_intermediates, new_false_intermediates
+  new_branch_intermediates = []
+  for i, branch_graph in enumerate(branch_graphs):
+    other_fakeparams = _create_fakeparams(
+        branch_graph,
+        sum((bi for bi in branch_intermediates
+             if bi is not branch_intermediates[i]), []))
+    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
+    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
+                                    branch_intermediates[i] +
+                                    other_fakeparams[num_preceding:])
+  return new_branch_intermediates
 
 
-def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
-  """Modifies true_graph and false_graph so they have the same input signature.
+def _make_inputs_match(branch_graphs, branch_inputs):
+  """Modifies branch_graphs so they have the same input signature.
 
-  This method reorders and/or adds parameters to true_graph and false_graph so
+  This method reorders and/or adds parameters to each graph in branch_graphs so
   they have the same input signature, and updates the 'inputs' and 'captured'
-  fields of both graphs accordingly. It uses the input tensors from the outer
+  fields of each graph accordingly. It uses the input tensors from the outer
   graph to avoid duplicating shared arguments.
 
   Args:
-    true_graph: FuncGraph
-    false_graph: FuncGraph
-    true_inputs: a list of Tensors in the outer graph. The inputs for
-      true_graph.
-    false_inputs: a list of Tensors in the outer graph. The inputs for
-      false_graph.
+    branch_graphs: a `list` of `FuncGraph`
+    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
+      inputs for the corresponding graph in `branch_graphs`.
 
   Returns:
-    A new list of Tensors from the outer graph that are the new inputs for both
-    true_graph and false_graph. This is a deduped version of true_inputs +
-    false_inputs.
+    A new list of Tensors from the outer graph that are the new inputs for each
+    branch_graph. This is a deduped version of `sum(branch_inputs)`.
   """
-  shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
-      true_inputs, false_inputs)
+  assert len(branch_graphs) == len(branch_inputs)
+  new_inputs = set()
+  for branch_in in branch_inputs:
+    new_inputs |= set(branch_in)
+  new_inputs = list(new_inputs)
 
-  new_inputs = shared_inputs + true_only_inputs + false_only_inputs
+  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
+    branch_input_to_param = dict(zip(branch_in, branch_graph.inputs))
+    input_list = []
+    for in_t in new_inputs:
+      param = branch_input_to_param.get(in_t, None)
+      if param is None:
+        param = _create_dummy_input(branch_graph, in_t)
+      input_list.append(param)
 
-  true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
-  false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
+    branch_graph.inputs = input_list
 
-  true_graph.inputs = (
-      [true_input_to_param[t] for t in shared_inputs] +
-      [true_input_to_param[t] for t in true_only_inputs] +
-      _create_dummy_inputs(true_graph, false_only_inputs))
-
-  false_graph.inputs = (
-      [false_input_to_param[t] for t in shared_inputs] +
-      _create_dummy_inputs(false_graph, true_only_inputs) +
-      [false_input_to_param[t] for t in false_only_inputs])
-
-  # Rewrite the FuncGraphs' state to reflect the new inputs.
-  true_graph.captures = collections.OrderedDict(zip(new_inputs,
-                                                    true_graph.inputs))
-  false_graph.captures = collections.OrderedDict(zip(new_inputs,
-                                                     false_graph.inputs))
+    # Rewrite the FuncGraphs' state to reflect the new inputs.
+    branch_graph.captures = collections.OrderedDict(
+        zip(new_inputs, branch_graph.inputs))
 
   return new_inputs
 
 
-def _make_output_composite_tensors_match(true_graph, false_graph):
-  """Modifies true_graph and false_graph so they have the same output signature.
+def _make_output_composite_tensors_match(op_type, branch_graphs):
+  """Modifies each branch_graph's outputs to have the same output signature.
 
   Currently the only transformation implemented is turning a Tensor into an
   equivalent IndexedSlices if the other branch returns an IndexedSlices.
-  Updates {true,false}_graph.{outputs,structured_outputs}.
+  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
+  branch_graphs.
 
   Args:
-    true_graph: FuncGraph
-    false_graph: FuncGraph
+    op_type: _COND or _CASE
+    branch_graphs: `list` of `FuncGraph`
 
   Raises:
-    TypeError: if a pair of outputs cannot be rewritten.
+    TypeError: if a set of outputs cannot be rewritten.
   """
   # Note: since this is only used for gradient graphs, we do not expect the
   # outputs to be structured (e.g. nested lists), and thus do not need to use
   # nest.flatten, etc.
-  true_outputs = list(true_graph.structured_outputs)
-  false_outputs = list(false_graph.structured_outputs)
-  assert len(true_outputs) == len(false_outputs)
+  assert branch_graphs
+  branch_outputs = [g.structured_outputs for g in branch_graphs]
+  outputs_per_branch = list(len(outs) for outs in branch_outputs)
+  assert len(set(outputs_per_branch)) == 1, outputs_per_branch
 
-  for idx, (true_out, false_out) in enumerate(zip(true_outputs, false_outputs)):
-    if type(true_out) == type(false_out):  # pylint: disable=unidiomatic-typecheck
+  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
+    if len(set(type(out) for out in branch_outs)) == 1:
       continue
-    if (isinstance(true_out, ops.IndexedSlices) and
-        isinstance(false_out, ops.Tensor)):
-      with false_graph.as_default():
-        false_outputs[idx] = math_ops._as_indexed_slices(false_out)
-    elif (isinstance(true_out, ops.Tensor) and
-          isinstance(false_out, ops.IndexedSlices)):
-      with true_graph.as_default():
-        true_outputs[idx] = math_ops._as_indexed_slices(true_out)
-    else:
-      raise TypeError(
-          "Cannot reconcile tf.cond %i-th outputs:\n"
-          "  true_fn returned:  %s\n"
-          "  false_fn returned: %s" % (idx, true_out, false_out))
+    if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
+      continue
+    for branch_idx, branch_out in enumerate(branch_outs):
+      if isinstance(branch_out, ops.IndexedSlices):
+        continue
+      elif isinstance(branch_out, ops.Tensor):
+        with branch_graphs[branch_idx].as_default():
+          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
+              branch_out)
+      else:
+        raise TypeError(
+            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
+            "  outputs from all branches: {outputs}".format(
+                op_name="tf.cond" if op_type == _COND else "tf.case",
+                output_idx=output_idx,
+                outputs=branch_outs))
 
-  true_graph.structured_outputs = true_outputs
-  true_graph.outputs = func_graph_module.flatten(true_outputs)
-  false_graph.structured_outputs = false_outputs
-  false_graph.outputs = func_graph_module.flatten(false_outputs)
+  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
+    branch_graph.structured_outputs = branch_outs
+    branch_graph.outputs = func_graph_module.flatten(branch_outs)
 
 
-def _make_indexed_slices_indices_types_match(true_graph, false_graph):
-  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
+def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
+  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
+  assert branch_graphs
   indexed_slice_indices = []
   current_index = 0
-  true_outputs_flat_with_composites = nest.flatten(
-      true_graph.structured_outputs, expand_composites=False)
-  false_outputs_flat_with_composites = nest.flatten(
-      false_graph.structured_outputs, expand_composites=False)
+  branch_outputs_flat_with_composites = [
+      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
+      for branch_graph in branch_graphs
+  ]
+  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
+  assert len(set(outs_per_branch)) == 1, outs_per_branch
   # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
-  for idx, (true_out, false_out) in enumerate(
-      zip(true_outputs_flat_with_composites,
-          false_outputs_flat_with_composites)):
-    if isinstance(true_out, ops.IndexedSlices) != isinstance(
-        false_out, ops.IndexedSlices):
-      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
-                      "  true_fn returned:  %s\n"
-                      "  false_fn returned: %s" % (idx, true_out, false_out))
-    if isinstance(true_out, ops.IndexedSlices):
+  for output_idx, branch_outs in enumerate(
+      zip(*branch_outputs_flat_with_composites)):
+    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
+      raise TypeError("Cannot reconcile {op_name} {output_idx}-th outputs:\n"
+                      "  branches returned: {outputs}".format(
+                          op_name="tf.cond" if op_type == _COND else "tf.case",
+                          output_idx=output_idx,
+                          outputs=branch_outs))
+    if isinstance(branch_outs[0], ops.IndexedSlices):
      # indices is the second component of the composite tensor.
       indexed_slice_indices.append(current_index + 1)
-    if nest.is_sequence_or_composite(true_out):
-      current_index += len(nest.flatten(true_out, expand_composites=True))
+    if nest.is_sequence_or_composite(branch_outs[0]):
+      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
     else:
       current_index += 1
 
   if not indexed_slice_indices:
     return
 
-  if current_index != len(true_graph.outputs):
-    raise ValueError("Insufficient elements in true_graph.outputs.\n"
+  if current_index != len(branch_graphs[0].outputs):
+    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                      "Expected: %i\n"
-                     "Actual: %i" % (current_index, len(true_graph.outputs)))
+                     "Actual: %i" %
+                     (current_index, len(branch_graphs[0].outputs)))
 
   # Cast indices with mismatching types to int64.
   for index in indexed_slice_indices:
-    if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
+    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
+           for bg in branch_graphs):
       raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
-                      "Found: %s" % str(true_graph.outputs[index].dtype))
-    if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
-      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
-                      "Found: %s" % str(false_graph.outputs[index].dtype))
-    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
-      if false_graph.outputs[index].dtype == dtypes.int32:
-        with false_graph.as_default():
-          false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index],
-                                                     dtypes.int64)
-      else:
-        with true_graph.as_default():
-          true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
-                                                    dtypes.int64)
+                      "Found: %s" %
+                      str([bg.outputs[index].dtype for bg in branch_graphs]))
+    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
+      for branch_graph in branch_graphs:
+        if branch_graph.outputs[index].dtype == dtypes.int32:
+          with branch_graph.as_default():
+            branch_graph.outputs[index] = math_ops.cast(
+                branch_graph.outputs[index], dtypes.int64)
 
-  true_graph.structured_outputs = func_graph_module.pack_sequence_as(
-      true_graph.structured_outputs, true_graph.outputs)
-  false_graph.structured_outputs = func_graph_module.pack_sequence_as(
-      false_graph.structured_outputs, false_graph.outputs)
+  for branch_graph in branch_graphs:
+    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
+        branch_graph.structured_outputs, branch_graph.outputs)
 
 
 def _wrap_intermediates(func_graph, intermediates):
@@ -585,33 +575,33 @@ def _wrap_intermediates(func_graph, intermediates):
     return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
 
 
-def _create_dummy_inputs(func_graph, template_tensors):
+def _create_dummy_input(func_graph, template_tensor):
   """Creates tensors in func_graph to represent template_tensors.
 
   Args:
     func_graph: FuncGraph.
-    template_tensors: a list of tensors in the outer graph.
+    template_tensor: a tensor in the outer graph.
 
   Returns:
-    A list of tensors in func_graph.
+    A tensor in func_graph.
   """
   with func_graph.as_default():
-    return [array_ops.placeholder(t.dtype, shape=t.shape)
-            for t in template_tensors]
+    return array_ops.placeholder(
+        template_tensor.dtype, shape=template_tensor.shape)
 
 
-def _create_none_optionals(func_graph, template_tensors):
-  """Creates none optionals in func_graph to represent template_tensors.
+def _create_none_optionals(func_graph, n):
+  """Creates `n` `None` optionals in func_graph.
 
   Args:
     func_graph: FuncGraph.
-    template_tensors: a list of tensors in func_graph.
+    n: `int` the number of `None` optionals to make.
 
   Returns:
     A list of tensors in func_graph.
""" with func_graph.as_default(): - return [gen_dataset_ops.optional_none() for _ in template_tensors] + return [gen_dataset_ops.optional_none() for _ in range(n)] def _create_fakeparams(func_graph, template_tensors): @@ -621,55 +611,69 @@ def _create_fakeparams(func_graph, template_tensors): for t in template_tensors] -def _check_same_outputs(true_graph, false_graph): - """Raises an error if true_graph and false_graph have different outputs.""" +def _check_same_outputs(op_type, graphs): + """Raises an error if `graphs` have different outputs.""" - def error(error_detail): + def error(branch_idx, error_detail): raise TypeError( - "true_fn and false_fn arguments to tf.cond must have the same number, " - "type, and overall structure of return values.\n" + "{b0_name} and {bn_name} arguments to {op_name} must have the same " + "number, type, and overall structure of return values.\n" "\n" - "true_fn output: %s\n" - "false_fn output: %s\n" + "{b0_name} output: {b0_out}\n" + "{bn_name} output: {bn_out}\n" "\n" "Error details:\n" - "%s" % (true_graph.structured_outputs, false_graph.structured_outputs, - error_detail)) + "{detail}".format( + b0_name="true_fn" if op_type == _COND else "branches[0]", + bn_name=("false_fn" if op_type == _COND else + "branches[{}]".format(branch_idx)), + op_name="tf.cond" if op_type == _COND else "tf.case", + b0_out=graphs[0].structured_outputs, + bn_out=graphs[branch_idx].structured_outputs, + detail=error_detail)) - try: - nest.assert_same_structure(true_graph.structured_outputs, - false_graph.structured_outputs, - expand_composites=True) - except (ValueError, TypeError) as e: - error(str(e)) + for b in range(1, len(graphs)): + try: + nest.assert_same_structure( + graphs[0].structured_outputs, + graphs[b].structured_outputs, + expand_composites=True) + except (ValueError, TypeError) as e: + error(b, str(e)) - assert len(true_graph.outputs) == len(false_graph.outputs) - for true_out, false_out in zip(true_graph.outputs, false_graph.outputs): - if true_out.dtype != false_out.dtype: - error("%s and %s have different types" % (true_out, false_out)) + assert len(graphs[0].outputs) == len(graphs[b].outputs) + for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs): + if b0_out.dtype != bn_out.dtype: + error(b, "%s and %s have different types" % (b0_out, bn_out)) -def _get_output_shapes(true_graph_outputs, false_graph_outputs): - output_shapes = [ - t_out.shape.most_specific_compatible_shape(f_out.shape) - for t_out, f_out in zip(true_graph_outputs, false_graph_outputs) - ] +def _get_output_shapes(*branch_graph_outputs): + output_shapes = [] + for out_by_branch in zip(*branch_graph_outputs): + shape = out_by_branch[0].shape + for other_out in out_by_branch[1:]: + shape = shape.most_specific_compatible_shape(other_out.shape) + output_shapes.append(shape) return output_shapes -def verify_captures(true_graph, false_graph): - """Verify that a true_fn tensor is not accessed in false_fn and vice-versa.""" - for t in false_graph.external_captures: - if not isinstance(t, ops.EagerTensor) and t.graph is true_graph: - raise ValueError("Tensor {} in true_fn is accessed from false_fn.".format( - t.name)) - # Note: This is technically not possible right now because `false_graph` - # is built "after" `true_graph` but we add this check for completeness and to +def verify_captures(op_type, branch_graphs): + """Verify that a branch's tensor is not accessed in another branch fn.""" + # Note: It is technically not possible for lower-branch_index branches to + # capture tensors from 
higher-branch_index branches, because of the order of + # branch graph construction, but we check all for completeness and to # guard against potential future changes. - for t in true_graph.external_captures: - if not isinstance(t, ops.EagerTensor) and t.graph is false_graph: - raise ValueError("Tensor {} in false_fn is accessed from true_fn.".format( - t.name)) + other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)} + for i, branch_graph in enumerate(branch_graphs): + for t in branch_graph.external_captures: + if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs: + branch_names = ["true_fn", "false_fn"] if op_type == _COND else [ + "branch {}".format(bi) for bi in range(len(branch_graphs))] + raise ValueError( + "Tensor {tname} in {b0name} is accessed from {b1name}.".format( + tname=t.name, + b0name=branch_names[other_branch_graphs[t.graph]], + b1name=branch_names[i])) class _CondGradFuncGraph(util.CondBranchFuncGraph): @@ -679,14 +683,14 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph): gradient computation in optionals. Attributes: - if_op_needs_rewrite: True if any intermediates were captured, meaning the + op_needs_rewrite: True if any intermediates were captured, meaning the forward If op needs to be written to output the wrapped intermediates. """ def __init__(self, name, forward_graph): super(_CondGradFuncGraph, self).__init__( name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access - self.if_op_needs_rewrite = False + self.op_needs_rewrite = False self._forward_graph = forward_graph # Maps from forward intermediate tensor -> the unwrapped captured # intermediate. @@ -719,7 +723,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph): # TODO(skyewm,jpienaar): can XLA support optionals? if tensor not in self.captures: self.xla_intermediates.append(tensor) - self.if_op_needs_rewrite = True + self.op_needs_rewrite = True return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) captured_tensor = self._indirect_captures.get(tensor) @@ -756,7 +760,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph): # 'tensor' hasn't been wrapped, do it now. with self._forward_graph.as_default(): optional = gen_dataset_ops.optional_from_value([tensor]) - self.if_op_needs_rewrite = True + self.op_needs_rewrite = True self._wrapped_intermediates[tensor] = optional optional = self._wrapped_intermediates[tensor] @@ -767,3 +771,182 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph): self._indirect_captures[tensor] = captured_tensor return captured_tensor + + +def indexed_case(branch_index, branch_fns, name="indexed_case"): + """Like conv_v2, except emits a Case op instead of an If.""" + if isinstance(branch_index, int): + raise TypeError("branch_index must not be a Python int", branch_index) + + with ops.name_scope(name) as scope: + branch_names = [ + util.unique_fn_name(scope, "branch{}".format(b)) + for b in range(len(branch_fns)) + ] + + # Automatic control dependencies are added in defuns, but not in v1 + # graphs. Propagate that behavior here. 
+    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
+    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
+
+    branch_graphs = []
+    for branch_name, branch_fn in zip(branch_names, branch_fns):
+      branch_graphs.append(
+          func_graph_module.func_graph_from_py_func(
+              branch_name,
+              branch_fn,
+              [],
+              {},
+              func_graph=util.CondBranchFuncGraph(
+                  branch_name,
+                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
+              add_control_dependencies=add_control_dependencies,
+              op_return_value=branch_index))
+
+    verify_captures(_CASE, branch_graphs)
+    return _build_case(
+        branch_index,
+        branch_graphs, [g.external_captures for g in branch_graphs],
+        name=scope)
+
+
+@ops.RegisterGradient("Case")
+def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
+  """The gradient of a Case op produced (w/ branch_index) by tf.case."""
+  # Get the Case operator (this logic handles the case where op is a MockOp)
+  case_op = op.outputs[0].op
+  branch_graphs = _get_func_graphs(case_op)
+  assert branch_graphs
+  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
+  # of a nested cond.
+  for branch_graph in branch_graphs:
+    assert branch_graph.outer_graph == case_op.graph
+
+  # Create grad functions that compute the gradient of the branch forward
+  # graphs. These functions will capture tensors from the forward pass
+  # functions.
+  branch_grad_graphs = []
+  for branch_graph in branch_graphs:
+    branch_grad_graphs.append(
+        _create_grad_func(branch_graph, grads,
+                          util.unique_grad_fn_name(branch_graph.name)))
+
+  if any(g.op_needs_rewrite for g in branch_grad_graphs):
+    # Modify 'op' to output the intermediates needed by the grad functions. Note
+    # that all needed intermediates are wrapped in optionals. Each optional
+    # intermediate output will have a value iff its corresponding branch is
+    # taken.
+    # NOTE(bjp): if there are any active sessions, this modification to `op`
+    # may make them unrunnable!
+
+    if control_flow_util.InXlaContext(ops.get_default_graph()):
+      # XLA does not yet support optionals, so output intermediates directly and
+      # make them match via FakeParams, which can be converted to zeros in XLA.
+      # TODO(bjp,jpienaar): can XLA support optionals?
+      branches_intermediates = [
+          branch_grad_graph.xla_intermediates
+          for branch_grad_graph in branch_grad_graphs
+      ]
+      extra_branch_outputs = _make_intermediates_match_xla(
+          branch_graphs, branches_intermediates)
+    else:
+      branch_intermediates = [
+          g.wrapped_intermediates for g in branch_grad_graphs
+      ]
+      # Make outputs match by adding none optionals.
+      extra_branch_outputs = _make_intermediates_match(branch_graphs,
+                                                       branch_intermediates)
+
+    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
+      branch_graph.outputs.extend(extra_outputs)
+    _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
+    # TODO(bjp): indicate it's an internal bug if this fails.
+    _check_same_outputs(_CASE, branch_graphs)
+
+    for branch_graph in branch_graphs:
+      branch_graph.name += "_rewritten"
+
+    case_op._set_func_list_attr("branches", [
+        util.create_new_tf_function(branch_graph)
+        for branch_graph in branch_graphs
+    ])
+    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
+    case_op._set_shape_list_attr("output_shapes",
+                                 branch_graphs[0].output_shapes)
+    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
+                         [t.shape for t in extra_branch_outputs[0]])
+
+  # Resolve references to forward graph tensors in grad graphs and ensure
+  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
+  branches_grad_inputs = [
+      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
+      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
+  ]
+
+  # This modifies the graphs in branch_grad_graphs.
+  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)
+
+  outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
+                        branches_grad_inputs)
+
+  # The predicate has no gradient.
+  return [None] + outputs
+
+
+def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
+  """Creates a `Case` op from `branch_index`, branch graphs and inputs.
+
+  Note that this modifies `branch_graphs` to make the inputs match, and to
+  output all intermediate values so they're available for the gradient
+  computation.
+
+  `branch_graphs` need not have the same input types, but they must
+  have the same output types.
+
+  Args:
+    branch_index: integer Tensor
+    branch_graphs: List of FuncGraph
+    branch_inputs: List of lists of Tensors to be passed to corresponding
+      branch_graph as input.
+    name: the name for the Case op.
+
+  Returns:
+    A list of Tensors which are the outputs of the Case op. Does not include
+    added intermediate outputs.
+  """
+  _check_same_outputs(_CASE, branch_graphs)
+
+  # Add inputs to branch_graphs to make them match. Note that this modifies the
+  # graphs in `branch_graphs`.
+  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)
+
+  # Create the Case op.
+  with ops.control_dependencies(
+      sum((list(bg.control_captures) for bg in branch_graphs), [])):
+    tensors = gen_functional_ops.case(
+        branch_index,
+        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
+        [util.create_new_tf_function(g) for g in branch_graphs],
+        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
+        name=name)
+
+  # TODO(b/110167197) this requires Case to have at least 1 output
+  case_op = tensors[0].op
+  # TODO(b/131304144): Enable lowering Case to SwitchN/Merge for graph mode.
+  # util.maybe_set_lowering_attr(case_op)
+  util.maybe_propagate_compile_time_consts_in_xla(case_op)
+
+  # Return identities for each output of the Case op, rather than the output of
+  # the Case op directly. This makes pruning work if the output of switch_case()
+  # is fetched: the lowering pass converts the Case outputs into IdentityN
+  # outputs, which if fetched will cause all ops in the taken branch to be run
+  # (since it takes all merge ops as input). After lowering, each output
+  # identity op will end up with only the appropriate merge op as input.
+  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
+  # correct output structure
+  tensors = [array_ops.identity(t) for t in tensors]
+
+  # Prevent fetching since the variant outputs can't be fetched directly.
+  case_op.graph.prevent_fetching(case_op)
+  return func_graph_module.pack_sequence_as(branch_graphs[0].structured_outputs,
+                                            tensors)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 4ad443e4b84..171e57e85ed 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3931,6 +3931,101 @@ def _case_helper(cond_fn,
     return fn()
 
 
+def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
+                                               branch_index):
+  """Verifies input arguments for the case function.
+
+  Args:
+    branch_fns: Dict or list of pairs of an `int` and a callable which
+      returns a list of tensors.
+    default: Optional callable that returns a list of tensors.
+    branch_index: Optional int `Tensor`, which selects for the corresponding
+      pred_fn_pair.
+
+  Raises:
+    TypeError: If `branch_fns` is not a list/dictionary.
+    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
+      callables.
+    TypeError: If `fns[i]` is not callable for any i, or `default` is not
+      callable.
+
+  Returns:
+    branch_fns: validated list of callables for each branch (default last).
+  """
+  if not isinstance(branch_index, ops.Tensor):
+    raise TypeError("branch_index must be a Tensor, got {}".format(
+        type(branch_index)))
+  if not branch_index.dtype.is_integer:
+    raise TypeError("branch_index must be an integer Tensor, got {}".format(
+        branch_index.dtype))
+
+  if not branch_fns:
+    raise ValueError("Must provide at least one item in branch_fns")
+  if not isinstance(branch_fns, (list, _basetuple, dict)):
+    raise TypeError("branch_fns must be a list, tuple, or dict")
+
+  if isinstance(branch_fns, dict):
+    branch_fns = branch_fns.items()
+
+  if all(callable(fn) for fn in branch_fns):
+    branch_fns = list(enumerate(branch_fns))
+
+  for key_fn_pair in branch_fns:
+    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
+      raise TypeError("Each entry in branch_fns must be a 2-tuple")
+    key, branch_fn = key_fn_pair
+
+    if not isinstance(key, int):
+      raise TypeError("key must be a Python `int`, got {}".format(type(key)))
+
+    if not callable(branch_fn):
+      raise TypeError("fn for key {} must be callable.".format(key))
+
+  keys = [p[0] for p in branch_fns]
+  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
+    raise ValueError(
+        "branch indices (keys) must form contiguous range of [0 to {}) but "
+        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
+  actions = [p[1] for p in sorted(branch_fns)]
+  if default is not None:
+    actions.append(default)
+  return actions
+
+
+def _indexed_case_helper(branch_fns, default, branch_index, name):
+  """Implementation of case that emits the n-way indexed Case op.
+
+  Args:
+    branch_fns: Dict or list of pairs of an `int` and a
+      callable which returns a list of tensors.
+    default: Optional callable that returns a list of tensors.
+    branch_index: Optional int `Tensor`, which selects for the corresponding
+      pred_fn_pair.
+    name: A name for this operation (optional).
+
+  Returns:
+    The tensors returned by the pair whose key matched branch_index, or
+    those returned by `default` if none does.
+
+  Raises:
+    TypeError: If `branch_fns` is not a list/dictionary.
+    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
+      callables.
+    TypeError: If `fns[i]` is not callable for any i, or `default` is not
+      callable.
+ """ + branch_fns = _indexed_case_verify_and_canonicalize_args( + branch_fns, default, branch_index) + with ops.name_scope(name, "case", [branch_index]): + if context.executing_eagerly(): + branch_index = array_ops.where( + math_ops.less(branch_index, 0) + | math_ops.greater_equal(branch_index, len(branch_fns)), + len(branch_fns) - 1, branch_index) + return branch_fns[int(branch_index)]() + return cond_v2.indexed_case(branch_index, branch_fns) + + @tf_export("case") def case(pred_fn_pairs, default=None, @@ -3939,6 +4034,8 @@ def case(pred_fn_pairs, name="case"): """Create a case operation. + See also `tf.switch_case`. + The `pred_fn_pairs` parameter is a dict or list of pairs of size N. Each pair contains a boolean scalar tensor and a python callable that creates the tensors to be returned if the boolean evaluates to True. @@ -4037,6 +4134,82 @@ def case(pred_fn_pairs, strict=strict) +@tf_export("switch_case") +def switch_case(branch_index, + branch_fns, + default=None, + name="switch_case"): + """Create a switch/case operation, i.e. an integer-indexed conditional. + + See also `tf.case`. + + This op can be substantially more efficient than `tf.case` when exactly one + branch will be selected. `tf.switch_case` is more like a C++ switch/case + statement than `tf.case`, which is more like an if/elif/elif/else chain. + + The `branch_fns` parameter is either a dict from `int` to callables, or list + of (`int, callable) pairs, or simply a list of callables (in which case the + index is implicitly the key). The `branch_index` `Tensor` is used to select an + element in `branch_fns` with matching `int` key, falling back to `default` + if none match, or `max(keys)` if no `default` is provided. The keys must form + a contiguous set from `0` to `len(branch_fns) - 1`. + + `tf.switch_case` supports nested structures as implemented in `tf.nest`. All + callables must return the same (possibly nested) value structure of lists, + tuples, and/or named tuples. + + **Example:** + + Pseudocode: + + ```c++ + switch (branch_index) { // c-style switch + case 0: return 17; + case 1: return 31; + default: return -1; + } + ``` + or + ```python + branches = {0: lambda: 17, 1: lambda: 31} + branches.get(branch_index, lambda: -1)() + ``` + + Expressions: + + ```python + def f1(): return tf.constant(17) + def f2(): return tf.constant(31) + def f3(): return tf.constant(-1) + r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3) + # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3}) + ``` + + Args: + branch_index: An int Tensor specifying which of `branch_fns` should be + executed. + branch_fns: A `dict` mapping `int`s to callables, or a `list` of + (`int, callable) pairs, or simply a list of callables (in which case the + index serves as the key). Each callable must return a matching structure + of tensors. + default: Optional callable that returns a structure of tensors. + name: A name for this operation (optional). + + Returns: + The tensors returned by the callable identified by `branch_index`, or those + returned by `default` if no key matches and `default` was provided, or those + returned by the max-keyed `branch_fn` if no `default` is provided. + + Raises: + TypeError: If `branch_fns` is not a list/dictionary. + TypeError: If `branch_fns` is a list but does not contain 2-tuples or + callables. + TypeError: If `fns[i]` is not callable for any i, or `default` is not + callable. 
+ """ + return _indexed_case_helper(branch_fns, default, branch_index, name) + + class XLAControlFlowContext(ControlFlowContext): """Base class for XLA and TPU control flow contexts.""" diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index f1dd4f529fc..f67d785fc0e 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -24,6 +24,8 @@ import numpy as np from tensorflow.python import tf2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -417,6 +419,28 @@ class CondTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) + @test_util.enable_control_flow_v2 + @test_util.run_in_graph_and_eager_modes + def testCond_gradient(self): + true_in, false_in = array_ops.constant(1.), array_ops.constant(5.) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(true_in) + tape.watch(false_in) + cond_true = control_flow_ops.cond( + array_ops.constant(True), lambda: true_in**2., lambda: false_in**2.) + cond_false = control_flow_ops.cond( + array_ops.constant(False), lambda: true_in**2., lambda: false_in**2.) + grads_true = tape.gradient( + cond_true, [true_in, false_in], output_gradients=3.) + grads_false = tape.gradient( + cond_false, [true_in, false_in], output_gradients=3.) + self.assertEqual(3. * 2. * 1., self.evaluate(grads_true[0])) + self.assertEqual(None if context.executing_eagerly() else 0., + self.evaluate(grads_true[1])) + self.assertEqual(3. * 2. 
* 5., self.evaluate(grads_false[1])) + self.assertEqual(None if context.executing_eagerly() else 0., + self.evaluate(grads_false[0])) + class ContextTest(test_util.TensorFlowTestCase): @@ -908,6 +932,179 @@ class DataTypesTest(test_util.TensorFlowTestCase): self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2])) +@test_util.run_all_in_graph_and_eager_modes +class IndexedCaseTest(test_util.TensorFlowTestCase): + + def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self): + nbranches = 5 + + def make_func(bi): + return lambda: array_ops.constant(bi * 10, name="br{}_out".format(bi)) + + branches = [(i, make_func(i)) for i in range(nbranches)] + for bi in range(nbranches): + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case(branch_index, branches) + self.assertEqual(bi * 10, self.evaluate(case_out)) + + def testCase(self): + nbranches = 5 + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = [(i, make_func(i)) for i in range(nbranches)] + for bi in 0, 2, 3: + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case(branch_index, branches) + self.assertEqual(bi * 10., self.evaluate(case_out)) + + def testCase_withDefault(self): + nbranches = 5 + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = [(i, make_func(i)) for i in range(nbranches)] + for bi in -1, 2, nbranches: + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case( + branch_index, branches, default=make_func(6)) + if bi < 0 or bi >= nbranches: + expected = 60. + else: + expected = bi * 10. + self.assertEqual(expected, self.evaluate(case_out)) + + def testCase_dictWithDefault(self): + nbranches = 5 + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = {i: make_func(i) for i in range(nbranches)} + for bi in -1, 0, 3, nbranches: + branch_index = array_ops.placeholder_with_default(bi, []) + case_out = control_flow_ops.switch_case( + branch_index, branches, default=make_func(6)) + if bi < 0 or bi >= nbranches: + expected = 60. + else: + expected = bi * 10. + self.assertEqual(expected, self.evaluate(case_out)) + + def testCase_gradient(self): + nbranches = 5 + inputs = [ + array_ops.constant(float(bi), name="br{}_in".format(bi)) + for bi in range(nbranches) + ] + + def make_func(bi): + return lambda: inputs[bi]**2. + + branches = {bi: make_func(bi) for bi in range(nbranches)} + + for bi in -1, 1, 4, nbranches: + branch_index = array_ops.placeholder_with_default(bi, []) + with backprop.GradientTape() as tape: + for x in inputs: + tape.watch(x) + case_out = control_flow_ops.switch_case(branch_index, branches) + out_grad = 3. + actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) + expected_grads = [None if context.executing_eagerly() else 0.] * nbranches + used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi + expected_grads[used_branch_idx] = out_grad * 2. 
* used_branch_idx + self.assertEqual(len(expected_grads), len(actual_grads)) + for expected, actual in zip(expected_grads, actual_grads): + self.assertEqual(expected, self.evaluate(actual)) + + def testCase_gradient_diffShapedIntermediates(self): + nbranches = 5 + inputs = [ + array_ops.constant( + float(bi), shape=[bi + 1], name="br{}_in".format(bi)) + for bi in range(nbranches) + ] + + def make_func(bi): + + def f(): + x = inputs[bi]**2 * inputs[bi][:bi + 1, None] + return math_ops.reduce_sum(x) + + return f + + branches = {bi: make_func(bi) for bi in range(nbranches)} + + for bi in -1, 2, nbranches: + branch_index = array_ops.placeholder_with_default(bi, []) + with backprop.GradientTape() as tape: + for x in inputs: + tape.watch(x) + case_out = control_flow_ops.switch_case(branch_index, branches) + out_grad = 3. + actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) + used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi + expected_grads = [] + for input_idx in range(nbranches): + if used_bi == input_idx: + with backprop.GradientTape() as tape: + tape.watch(inputs[used_bi]) + y = make_func(used_bi)() + expected_grads.append( + self.evaluate( + tape.gradient(y, inputs[used_bi], output_gradients=out_grad))) + else: + expected_grads.append(None if context.executing_eagerly() else [0.] * + (input_idx + 1)) + + self.assertEqual(len(expected_grads), len(actual_grads)) + for expected, actual in zip(expected_grads, actual_grads): + if expected is None: + self.assertIsNone(actual) + else: + self.assertAllEqual(expected, self.evaluate(actual)) + + def testCase_validateIndicesContiguous(self): + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = {i: make_func(i) for i in range(0, 6, 2)} + with self.assertRaisesRegexp(ValueError, "must form contiguous"): + control_flow_ops.switch_case(array_ops.constant(0), branches) + + def testCase_validateIndicesDup(self): + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = [(i, make_func(i)) for i in range(0, 6, 2)] + branches.append((0, make_func(7))) + with self.assertRaisesRegexp(ValueError, "must form contiguous"): + control_flow_ops.switch_case(array_ops.constant(0), branches) + + def testCase_validateBranchIndex(self): + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = {i: make_func(i) for i in range(5)} + with self.assertRaisesRegexp(TypeError, "branch_index.*Tensor"): + control_flow_ops.switch_case(1, branches) + + def testCase_validateNonIntKeys(self): + + def make_func(bi): + return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) + + branches = {array_ops.constant(i): make_func(i) for i in range(5)} + with self.assertRaisesRegexp(TypeError, "must be a Python `int`"): + control_flow_ops.switch_case(array_ops.constant(1), branches) + + class CaseTest(test_util.TensorFlowTestCase): @test_util.run_deprecated_v1 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 091cc04357e..ae66ee8febd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -2256,6 +2256,10 @@ tf_module { name: "svd" argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " } + member_method { + name: "switch_case" + argspec: 
"args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'switch_case\'], " + } member_method { name: "tables_initializer" argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 656d026cb63..00eb4ddc75d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -988,6 +988,10 @@ tf_module { name: "subtract" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "switch_case" + argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'switch_case\'], " + } member_method { name: "tan" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "