Do not try to compile trivially dead branches in the Case tf2xla lowering
This is important for the upcoming DeviceIndex op which can be used to select one of many implementations depending on the device, and some of them may not be compilable by tf2xla. PiperOrigin-RevId: 317376420 Change-Id: I6428df6f4da238e5d2bc3618d51c579e34454945
This commit is contained in:
parent
83fe1bad15
commit
94bf57d06c
@ -1453,6 +1453,26 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "case_test",
|
||||
size = "small",
|
||||
srcs = ["case_test.py"],
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
],
|
||||
use_xla_device = False, # Uses tf.function(experimental_compile=True)
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "gather_test",
|
||||
size = "medium",
|
||||
|
87
tensorflow/compiler/tests/case_test.py
Normal file
87
tensorflow/compiler/tests/case_test.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Case (switch_case) lowering in XLA."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CaseTest(xla_test.XLATestCase):
|
||||
|
||||
def testCaseBasic(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def switch_case_test(branch_index):
|
||||
|
||||
def f1():
|
||||
return array_ops.constant(17)
|
||||
|
||||
def f2():
|
||||
return array_ops.constant(31)
|
||||
|
||||
def f3():
|
||||
return array_ops.constant(-1)
|
||||
|
||||
return control_flow_ops.switch_case(
|
||||
branch_index, branch_fns={
|
||||
0: f1,
|
||||
1: f2
|
||||
}, default=f3)
|
||||
|
||||
with ops.device(self.device):
|
||||
self.assertEqual(switch_case_test(array_ops.constant(0)).numpy(), 17)
|
||||
self.assertEqual(switch_case_test(array_ops.constant(1)).numpy(), 31)
|
||||
self.assertEqual(switch_case_test(array_ops.constant(2)).numpy(), -1)
|
||||
self.assertEqual(switch_case_test(array_ops.constant(3)).numpy(), -1)
|
||||
|
||||
def testBranchIsPruned(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def switch_case_test():
|
||||
branch_index = array_ops.constant(0)
|
||||
|
||||
def f1():
|
||||
return array_ops.constant(17)
|
||||
|
||||
def f2():
|
||||
# Some operations that XLA cannot compile.
|
||||
image_ops.decode_image(io_ops.read_file('/tmp/bmp'))
|
||||
return array_ops.constant(31)
|
||||
|
||||
# This tests that we do not try to compile all branches if the branch
|
||||
# index is trivially constant.
|
||||
return control_flow_ops.switch_case(
|
||||
branch_index, branch_fns={
|
||||
0: f1,
|
||||
1: f2
|
||||
}, default=f2)
|
||||
|
||||
with ops.device(self.device):
|
||||
self.assertEqual(switch_case_test().numpy(), 17)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
@ -316,6 +316,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:constants",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -21,13 +21,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &unpruned_branches_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
|
||||
if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
|
||||
@ -41,12 +42,29 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<NameAttrList>, xla::XlaOp>
|
||||
XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) {
|
||||
xla::Literal branch_index_literal;
|
||||
bool branch_index_is_constant =
|
||||
ctx->ConstantInput(0, &branch_index_literal).ok();
|
||||
|
||||
if (!branch_index_is_constant) {
|
||||
return {unpruned_branches_, ctx->Input(0)};
|
||||
}
|
||||
|
||||
int32 branch_index = branch_index_literal.Get<int32>({});
|
||||
if (branch_index < 0 || branch_index >= unpruned_branches_.size()) {
|
||||
branch_index = unpruned_branches_.size() - 1;
|
||||
}
|
||||
|
||||
std::vector<NameAttrList> pruned_branch = {unpruned_branches_[branch_index]};
|
||||
return {pruned_branch, xla::ZerosLike(ctx->Input(0))};
|
||||
}
|
||||
|
||||
// TODO(b/35949885): There is duplication here with the handling of the
|
||||
// while_op/if_op. Refactor the common code out/rework.
|
||||
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
int num_branches = branches_.size();
|
||||
OP_REQUIRES(ctx, num_branches >= 1,
|
||||
OP_REQUIRES(ctx, !unpruned_branches_.empty(),
|
||||
errors::InvalidArgument("Must provide at least one case branch"));
|
||||
OP_REQUIRES(ctx, input_type(0) == DT_INT32,
|
||||
errors::InvalidArgument(
|
||||
@ -55,6 +73,18 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
errors::InvalidArgument(
|
||||
"branch_index argument must be scalar for XLA compilation"));
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
|
||||
// We opportunistically prune out branches if the branch index is a
|
||||
// compile-time constant. This is important in the context of the DeviceIndex
|
||||
// ops (and other such ops that may come later) since we may have a Case with
|
||||
// trivially unselected branches that cannot be compiled into HLO.
|
||||
std::vector<NameAttrList> branches;
|
||||
xla::XlaOp branch_index;
|
||||
std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx);
|
||||
|
||||
int num_branches = branches.size();
|
||||
|
||||
VLOG(1) << "Building Case: " << input_types_.size() << " inputs";
|
||||
|
||||
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
|
||||
@ -94,7 +124,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
std::vector<const FunctionBody*> case_bodies(num_branches);
|
||||
for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) {
|
||||
OP_REQUIRES_OK(ctx, FindMustBeConstNodes(
|
||||
ctx, branches_[branch_idx],
|
||||
ctx, branches[branch_idx],
|
||||
&case_branch_must_be_const_nodes[branch_idx],
|
||||
&case_bodies[branch_idx]));
|
||||
}
|
||||
@ -133,7 +163,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches);
|
||||
for (int j = 0; j < num_branches; ++j) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
compiler->CompileFunction(options, branches_[j], arguments,
|
||||
compiler->CompileFunction(options, branches[j], arguments,
|
||||
&branch_results[j]));
|
||||
branch_results_p[j] = &branch_results[j];
|
||||
}
|
||||
@ -171,7 +201,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
for (int j = 0; j < num_branches; ++j) {
|
||||
branch_results[j] = {};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
compiler->CompileFunction(options, branches_[j], arguments,
|
||||
compiler->CompileFunction(options, branches[j], arguments,
|
||||
&branch_results[j]));
|
||||
}
|
||||
}
|
||||
@ -277,7 +307,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
auto input_tuple = xla::Tuple(b, inputs);
|
||||
|
||||
xla::XlaOp outputs =
|
||||
xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations),
|
||||
xla::Conditional(branch_index, absl::MakeSpan(result_computations),
|
||||
std::vector<xla::XlaOp>(num_branches, input_tuple));
|
||||
// Sets non-variable outputs.
|
||||
for (int i = 0; i < output_types_.size(); ++i) {
|
||||
|
@ -50,7 +50,16 @@ class XlaCaseOp : public XlaOpKernel {
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp);
|
||||
|
||||
std::vector<NameAttrList> branches_;
|
||||
// If the branch_index input is a constant: prunes out all but the branch
|
||||
// corresponding to that constant branch index, and returns that branch and
|
||||
// a zero branch index (as the first and second components of the pair).
|
||||
//
|
||||
// If the branch_index input is not a constant: returns unpruned_branches_ and
|
||||
// the branch_index input.
|
||||
std::pair<std::vector<NameAttrList>, xla::XlaOp> GetPrunedBranchesAndIndex(
|
||||
XlaOpKernelContext* ctx);
|
||||
|
||||
std::vector<NameAttrList> unpruned_branches_;
|
||||
DataTypeVector input_types_;
|
||||
DataTypeVector output_types_;
|
||||
bool has_token_input_output_;
|
||||
|
Loading…
Reference in New Issue
Block a user