Do not try to compile trivially dead branches in the Case tf2xla lowering

This is important for the upcoming DeviceIndex op which can be used to select
one of many implementations depending on the device, and some of them may not be
compilable by tf2xla.

PiperOrigin-RevId: 317376420
Change-Id: I6428df6f4da238e5d2bc3618d51c579e34454945
This commit is contained in:
Sanjoy Das 2020-06-19 14:07:16 -07:00 committed by TensorFlower Gardener
parent 83fe1bad15
commit 94bf57d06c
5 changed files with 156 additions and 9 deletions

View File

@ -1453,6 +1453,26 @@ tf_xla_py_test(
],
)
# XLA test target for case_test.py, which exercises the tf2xla lowering of the
# Case op through tf.function(experimental_compile=True).
tf_xla_py_test(
name = "case_test",
size = "small",
srcs = ["case_test.py"],
disabled_backends = ["cpu_ondemand"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
use_xla_device = False, # Uses tf.function(experimental_compile=True)
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)
tf_xla_py_test(
name = "gather_test",
size = "medium",

View File

@ -0,0 +1,87 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for case statements in XLA."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import test
class CaseTest(xla_test.XLATestCase):
  """Checks XLA compilation of `switch_case` (the Case op)."""

  def testCaseBasic(self):
    """Every branch, including the default, is selectable at run time."""

    @def_function.function(experimental_compile=True)
    def run_switch_case(branch_index):

      def branch_zero():
        return array_ops.constant(17)

      def branch_one():
        return array_ops.constant(31)

      def fallback():
        return array_ops.constant(-1)

      indexed_branches = {0: branch_zero, 1: branch_one}
      return control_flow_ops.switch_case(
          branch_index, branch_fns=indexed_branches, default=fallback)

    # Indices 2 and 3 have no explicit branch, so they hit the default.
    index_to_expected = ((0, 17), (1, 31), (2, -1), (3, -1))
    with ops.device(self.device):
      for index, expected in index_to_expected:
        result = run_switch_case(array_ops.constant(index))
        self.assertEqual(result.numpy(), expected)

  def testBranchIsPruned(self):
    """A branch that XLA cannot compile is fine if it is trivially dead."""

    @def_function.function(experimental_compile=True)
    def run_switch_case():
      constant_index = array_ops.constant(0)

      def compilable_branch():
        return array_ops.constant(17)

      def uncompilable_branch():
        # These ops have no XLA lowering, so compiling this branch would fail.
        image_ops.decode_image(io_ops.read_file('/tmp/bmp'))
        return array_ops.constant(31)

      # The branch index is trivially constant, so only the selected branch
      # should be compiled; the uncompilable one must never reach the compiler.
      indexed_branches = {0: compilable_branch, 1: uncompilable_branch}
      return control_flow_ops.switch_case(
          constant_index,
          branch_fns=indexed_branches,
          default=uncompilable_branch)

    with ops.device(self.device):
      result = run_switch_case()
      self.assertEqual(result.numpy(), 17)
# Script entry point: eager execution is enabled before the test runner starts.
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -316,6 +316,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -21,13 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &unpruned_branches_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
@ -41,12 +42,29 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
}
}
// Chooses which branches must be compiled and which index selects among them.
//
// When input 0 (the branch index) is not a compile-time constant, every branch
// is returned together with the original index operand. When it is a constant,
// only the selected branch is returned, paired with a zero index; a constant
// that is negative or past the end selects the last branch.
std::pair<std::vector<NameAttrList>, xla::XlaOp>
XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) {
  xla::Literal index_literal;
  if (!ctx->ConstantInput(0, &index_literal).ok()) {
    // Index is only known at run time: nothing can be pruned.
    return {unpruned_branches_, ctx->Input(0)};
  }

  int32 selected = index_literal.Get<int32>({});
  const bool in_range = selected >= 0 && selected < unpruned_branches_.size();
  if (!in_range) {
    // Out-of-range indices fall back to the last branch.
    selected = unpruned_branches_.size() - 1;
  }

  // Exactly one branch survives, so the index that selects it is zero.
  std::vector<NameAttrList> surviving_branch = {unpruned_branches_[selected]};
  return {surviving_branch, xla::ZerosLike(ctx->Input(0))};
}
// TODO(b/35949885): There is duplication here with the handling of the
// while_op/if_op. Refactor the common code out/rework.
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaBuilder* b = ctx->builder();
int num_branches = branches_.size();
OP_REQUIRES(ctx, num_branches >= 1,
OP_REQUIRES(ctx, !unpruned_branches_.empty(),
errors::InvalidArgument("Must provide at least one case branch"));
OP_REQUIRES(ctx, input_type(0) == DT_INT32,
errors::InvalidArgument(
@ -55,6 +73,18 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
errors::InvalidArgument(
"branch_index argument must be scalar for XLA compilation"));
xla::XlaBuilder* b = ctx->builder();
// We opportunistically prune out branches if the branch index is a
// compile-time constant. This is important in the context of the DeviceIndex
// ops (and other such ops that may come later) since we may have a Case with
// trivially unselected branches that cannot be compiled into HLO.
std::vector<NameAttrList> branches;
xla::XlaOp branch_index;
std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx);
int num_branches = branches.size();
VLOG(1) << "Building Case: " << input_types_.size() << " inputs";
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
@ -94,7 +124,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
std::vector<const FunctionBody*> case_bodies(num_branches);
for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) {
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(
ctx, branches_[branch_idx],
ctx, branches[branch_idx],
&case_branch_must_be_const_nodes[branch_idx],
&case_bodies[branch_idx]));
}
@ -133,7 +163,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches);
for (int j = 0; j < num_branches; ++j) {
OP_REQUIRES_OK(ctx,
compiler->CompileFunction(options, branches_[j], arguments,
compiler->CompileFunction(options, branches[j], arguments,
&branch_results[j]));
branch_results_p[j] = &branch_results[j];
}
@ -171,7 +201,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
for (int j = 0; j < num_branches; ++j) {
branch_results[j] = {};
OP_REQUIRES_OK(ctx,
compiler->CompileFunction(options, branches_[j], arguments,
compiler->CompileFunction(options, branches[j], arguments,
&branch_results[j]));
}
}
@ -277,7 +307,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
auto input_tuple = xla::Tuple(b, inputs);
xla::XlaOp outputs =
xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations),
xla::Conditional(branch_index, absl::MakeSpan(result_computations),
std::vector<xla::XlaOp>(num_branches, input_tuple));
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {

View File

@ -50,7 +50,16 @@ class XlaCaseOp : public XlaOpKernel {
private:
TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp);
std::vector<NameAttrList> branches_;
// If the branch_index input is a constant: prunes out all but the branch
// corresponding to that constant branch index, and returns that branch and
// the literal 0 (as the first and second component of the pair). A constant
// index that is negative or not less than the number of branches selects the
// last branch.
//
// If the branch_index input is not a constant: returns unpruned_branches_ and
// the branch_index input.
std::pair<std::vector<NameAttrList>, xla::XlaOp> GetPrunedBranchesAndIndex(
XlaOpKernelContext* ctx);
std::vector<NameAttrList> unpruned_branches_;
DataTypeVector input_types_;
DataTypeVector output_types_;
bool has_token_input_output_;