Fix logic for enabling MLIR bridge depending on has_tensor_list_arg.

It appears that the polarity of the use of has_tensor_list_arg was
inadvertently flipped.

Disable any MLIR bridge enabled tests that were passing because they weren't
using the MLIR bridge due to this issue.

PiperOrigin-RevId: 337125651
Change-Id: I93e9e61acda9a2aeffaee5cce13e93635d33f5a4
This commit is contained in:
Richard Uhler 2020-10-14 10:56:30 -07:00 committed by TensorFlower Gardener
parent 7e54bf3113
commit b881485eb5
5 changed files with 6 additions and 22 deletions

View File

@ -283,25 +283,27 @@ Status XlaCompilationCache::CompileSingleOp(
const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
// TODO(b/155596779): Support TensorList args.
bool has_tensor_list_arg =
absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kTensorList;
});
const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge();
bool use_mlir = config && config->experimental().enable_mlir_bridge() &&
!has_tensor_list_arg;
#ifdef LIBTPU_ON_GCE
if (use_mlir && has_tensor_list_arg) {
if (use_mlir) {
LOG(WARNING) << "MLIR is not supported in this environment.";
}
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
#else
// TODO(b/155596779): Support TensorList args.
if (!use_mlir || !has_tensor_list_arg) {
if (!use_mlir) {
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
}
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
if (result_dtypes.empty()) {

View File

@ -327,7 +327,6 @@ tf_xla_py_test(
name = "self_adjoint_eig_op_test",
size = "medium",
srcs = ["self_adjoint_eig_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -393,7 +392,6 @@ tf_xla_py_test(
size = "small",
timeout = "moderate",
srcs = ["matrix_inverse_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -416,7 +414,6 @@ tf_xla_py_test(
size = "small",
timeout = "moderate",
srcs = ["matrix_solve_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -639,7 +636,6 @@ tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -696,7 +692,6 @@ tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 6,
tags = [
@ -1018,7 +1013,6 @@ tf_xla_py_test(
"cpu",
"cpu_ondemand",
],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 5,
tags = [
@ -1089,7 +1083,6 @@ tf_xla_py_test(
name = "reduce_ops_test",
size = "medium",
srcs = ["reduce_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 5,
tags = [
@ -1391,7 +1384,6 @@ tf_xla_py_test(
name = "unary_ops_test",
size = "medium",
srcs = ["unary_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1548,7 +1540,6 @@ tf_xla_py_test(
name = "sort_ops_test",
size = "medium",
srcs = ["sort_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 1,
# Times out in fastbuild mode.
@ -1790,7 +1781,6 @@ tf_xla_py_test(
name = "fake_quant_ops_test",
size = "medium",
srcs = ["fake_quant_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -935,7 +935,6 @@ cuda_py_test(
distribute_py_test(
name = "checkpointing_test",
srcs = ["checkpointing_test.py"],
disable_mlir_bridge = False,
main = "checkpointing_test.py",
tags = [
"multi_and_single_gpu",
@ -1154,7 +1153,6 @@ distribute_py_test(
name = "values_test",
size = "medium",
srcs = ["values_test.py"],
disable_mlir_bridge = False,
main = "values_test.py",
shard_count = 5,
tags = [
@ -1302,7 +1300,6 @@ distribute_py_test(
distribute_py_test(
name = "moving_averages_test",
srcs = ["moving_averages_test.py"],
disable_mlir_bridge = False,
main = "moving_averages_test.py",
deps = [
":combinations",

View File

@ -150,7 +150,6 @@ cuda_py_test(
distribute_py_test(
name = "checkpointing_test",
srcs = ["checkpointing_test.py"],
disable_mlir_bridge = False,
main = "checkpointing_test.py",
tags = [
"multi_and_single_gpu",
@ -520,7 +519,6 @@ distribute_py_test(
name = "keras_save_load_test",
size = "medium",
srcs = ["keras_save_load_test.py"],
disable_mlir_bridge = False,
full_precision = True,
main = "keras_save_load_test.py",
shard_count = 7,
@ -790,7 +788,6 @@ distribute_py_test(
name = "saved_model_save_load_test",
size = "medium",
srcs = ["saved_model_save_load_test.py"],
disable_mlir_bridge = False,
full_precision = True,
main = "saved_model_save_load_test.py",
shard_count = 7,
@ -808,7 +805,6 @@ distribute_py_test(
name = "saved_model_mixed_api_test",
size = "medium",
srcs = ["saved_model_mixed_api_test.py"],
disable_mlir_bridge = False,
full_precision = True,
main = "saved_model_mixed_api_test.py",
shard_count = 7,

View File

@ -57,7 +57,6 @@ tpu_py_test(
"automatic_outside_compilation_test.py",
],
disable_experimental = True,
disable_mlir_bridge = False,
python_version = "PY3",
tags = ["no_oss"],
deps = [