diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 034ec82de10..42353451408 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1453,6 +1453,26 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "case_test", + size = "small", + srcs = ["case_test.py"], + disabled_backends = ["cpu_ondemand"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + use_xla_device = False, # Uses tf.function(experimental_compile=True) + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "gather_test", size = "medium", diff --git a/tensorflow/compiler/tests/case_test.py b/tensorflow/compiler/tests/case_test.py new file mode 100644 index 00000000000..3b2dff537da --- /dev/null +++ b/tensorflow/compiler/tests/case_test.py @@ -0,0 +1,87 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for case statements in XLA.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import def_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import image_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.platform import test + + +class CaseTest(xla_test.XLATestCase): + + def testCaseBasic(self): + + @def_function.function(experimental_compile=True) + def switch_case_test(branch_index): + + def f1(): + return array_ops.constant(17) + + def f2(): + return array_ops.constant(31) + + def f3(): + return array_ops.constant(-1) + + return control_flow_ops.switch_case( + branch_index, branch_fns={ + 0: f1, + 1: f2 + }, default=f3) + + with ops.device(self.device): + self.assertEqual(switch_case_test(array_ops.constant(0)).numpy(), 17) + self.assertEqual(switch_case_test(array_ops.constant(1)).numpy(), 31) + self.assertEqual(switch_case_test(array_ops.constant(2)).numpy(), -1) + self.assertEqual(switch_case_test(array_ops.constant(3)).numpy(), -1) + + def testBranchIsPruned(self): + + @def_function.function(experimental_compile=True) + def switch_case_test(): + branch_index = array_ops.constant(0) + + def f1(): + return array_ops.constant(17) + + def f2(): + # Some operations that XLA cannot compile. + image_ops.decode_image(io_ops.read_file('/tmp/bmp')) + return array_ops.constant(31) + + # This tests that we do not try to compile all branches if the branch + # index is trivially constant. 
+ return control_flow_ops.switch_case( + branch_index, branch_fns={ + 0: f1, + 1: f2 + }, default=f2) + + with ops.device(self.device): + self.assertEqual(switch_case_test().numpy(), 17) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index bfdfe38305b..bdaeeafd295 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -316,6 +316,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index 1b15c09f7e3..fbd54f1ef39 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -21,13 +21,14 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &unpruned_branches_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { @@ -41,12 +42,29 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } } +std::pair<std::vector<NameAttrList>, xla::XlaOp> +XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) { + xla::Literal branch_index_literal; + bool branch_index_is_constant = + ctx->ConstantInput(0, &branch_index_literal).ok(); + + if (!branch_index_is_constant) { + return {unpruned_branches_, ctx->Input(0)}; + } + + int32 branch_index = branch_index_literal.Get<int32>({}); + if (branch_index < 0 || branch_index >= unpruned_branches_.size()) { + branch_index = unpruned_branches_.size() - 1; + } + + std::vector<NameAttrList> pruned_branch = {unpruned_branches_[branch_index]}; + return {pruned_branch, xla::ZerosLike(ctx->Input(0))}; +} + // TODO(b/35949885): There is duplication here with the handling of the // while_op/if_op. Refactor the common code out/rework. 
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { - xla::XlaBuilder* b = ctx->builder(); - int num_branches = branches_.size(); - OP_REQUIRES(ctx, num_branches >= 1, + OP_REQUIRES(ctx, !unpruned_branches_.empty(), errors::InvalidArgument("Must provide at least one case branch")); OP_REQUIRES(ctx, input_type(0) == DT_INT32, errors::InvalidArgument( @@ -55,6 +73,18 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { errors::InvalidArgument( "branch_index argument must be scalar for XLA compilation")); + xla::XlaBuilder* b = ctx->builder(); + + // We opportunistically prune out branches if the branch index is a + // compile-time constant. This is important in the context of the DeviceIndex + // ops (and other such ops that may come later) since we may have a Case with + // trivially unselected branches that cannot be compiled into HLO. + std::vector<NameAttrList> branches; + xla::XlaOp branch_index; + std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx); + + int num_branches = branches.size(); + VLOG(1) << "Building Case: " << input_types_.size() << " inputs"; std::vector<XlaCompiler::Argument> arguments(input_types_.size()); @@ -94,7 +124,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { std::vector<const FunctionBody*> case_bodies(num_branches); for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) { OP_REQUIRES_OK(ctx, FindMustBeConstNodes( - ctx, branches_[branch_idx], + ctx, branches[branch_idx], &case_branch_must_be_const_nodes[branch_idx], &case_bodies[branch_idx])); } @@ -133,7 +163,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches); for (int j = 0; j < num_branches; ++j) { OP_REQUIRES_OK(ctx, - compiler->CompileFunction(options, branches_[j], arguments, + compiler->CompileFunction(options, branches[j], arguments, &branch_results[j])); branch_results_p[j] = &branch_results[j]; } @@ -171,7 +201,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { for (int j = 0; j < num_branches; ++j) { branch_results[j] = {}; OP_REQUIRES_OK(ctx, - 
compiler->CompileFunction(options, branches_[j], arguments, + compiler->CompileFunction(options, branches[j], arguments, &branch_results[j])); } } @@ -277,7 +307,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { auto input_tuple = xla::Tuple(b, inputs); xla::XlaOp outputs = - xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations), + xla::Conditional(branch_index, absl::MakeSpan(result_computations), std::vector<xla::XlaOp>(num_branches, input_tuple)); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h index 4a61707864e..4d22a3db830 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.h +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -50,7 +50,16 @@ class XlaCaseOp : public XlaOpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp); - std::vector<NameAttrList> branches_; + // If the branch_index input is a constant: prunes out all but the branch + // corresponding to that constant branch index, and returns that branch and + // the literal 0 (as the first and second component of the pair). + // + // If the branch_index input is not a constant: returns unpruned_branches_ and + // the branch_index input. + std::pair<std::vector<NameAttrList>, xla::XlaOp> GetPrunedBranchesAndIndex( + XlaOpKernelContext* ctx); + + std::vector<NameAttrList> unpruned_branches_;  DataTypeVector input_types_; DataTypeVector output_types_; bool has_token_input_output_;