Fix logic for enabling MLIR bridge depending on has_tensor_list_arg.
It appears that the polarity of the use of has_tensor_list_arg was inadvertently flipped. Disable any MLIR bridge enabled tests that were passing because they weren't using the MLIR bridge due to this issue. PiperOrigin-RevId: 337125651 Change-Id: I93e9e61acda9a2aeffaee5cce13e93635d33f5a4
This commit is contained in:
parent
7e54bf3113
commit
b881485eb5
@ -283,25 +283,27 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
const NodeDef& node_def = ctx->op_kernel().def();
|
||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
||||
|
||||
// TODO(b/155596779): Support TensorList args.
|
||||
bool has_tensor_list_arg =
|
||||
absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
|
||||
return arg.kind == XlaCompiler::Argument::kTensorList;
|
||||
});
|
||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
||||
bool use_mlir = config && config->experimental().enable_mlir_bridge() &&
|
||||
!has_tensor_list_arg;
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
if (use_mlir && has_tensor_list_arg) {
|
||||
if (use_mlir) {
|
||||
LOG(WARNING) << "MLIR is not supported in this environment.";
|
||||
}
|
||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||
std::move(graph), args, result);
|
||||
#else
|
||||
// TODO(b/155596779): Support TensorList args.
|
||||
if (!use_mlir || !has_tensor_list_arg) {
|
||||
if (!use_mlir) {
|
||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||
std::move(graph), args, result);
|
||||
}
|
||||
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<std::string> control_rets;
|
||||
if (result_dtypes.empty()) {
|
||||
|
@ -327,7 +327,6 @@ tf_xla_py_test(
|
||||
name = "self_adjoint_eig_op_test",
|
||||
size = "medium",
|
||||
srcs = ["self_adjoint_eig_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -393,7 +392,6 @@ tf_xla_py_test(
|
||||
size = "small",
|
||||
timeout = "moderate",
|
||||
srcs = ["matrix_inverse_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -416,7 +414,6 @@ tf_xla_py_test(
|
||||
size = "small",
|
||||
timeout = "moderate",
|
||||
srcs = ["matrix_solve_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -639,7 +636,6 @@ tf_xla_py_test(
|
||||
name = "extract_image_patches_op_test",
|
||||
size = "small",
|
||||
srcs = ["extract_image_patches_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -696,7 +692,6 @@ tf_xla_py_test(
|
||||
name = "fft_test",
|
||||
size = "medium",
|
||||
srcs = ["fft_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 6,
|
||||
tags = [
|
||||
@ -1018,7 +1013,6 @@ tf_xla_py_test(
|
||||
"cpu",
|
||||
"cpu_ondemand",
|
||||
],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
@ -1089,7 +1083,6 @@ tf_xla_py_test(
|
||||
name = "reduce_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["reduce_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
@ -1391,7 +1384,6 @@ tf_xla_py_test(
|
||||
name = "unary_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["unary_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -1548,7 +1540,6 @@ tf_xla_py_test(
|
||||
name = "sort_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["sort_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 1,
|
||||
# Times out in fastbuild mode.
|
||||
@ -1790,7 +1781,6 @@ tf_xla_py_test(
|
||||
name = "fake_quant_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["fake_quant_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -935,7 +935,6 @@ cuda_py_test(
|
||||
distribute_py_test(
|
||||
name = "checkpointing_test",
|
||||
srcs = ["checkpointing_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
main = "checkpointing_test.py",
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
@ -1154,7 +1153,6 @@ distribute_py_test(
|
||||
name = "values_test",
|
||||
size = "medium",
|
||||
srcs = ["values_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
main = "values_test.py",
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
@ -1302,7 +1300,6 @@ distribute_py_test(
|
||||
distribute_py_test(
|
||||
name = "moving_averages_test",
|
||||
srcs = ["moving_averages_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
main = "moving_averages_test.py",
|
||||
deps = [
|
||||
":combinations",
|
||||
|
@ -150,7 +150,6 @@ cuda_py_test(
|
||||
distribute_py_test(
|
||||
name = "checkpointing_test",
|
||||
srcs = ["checkpointing_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
main = "checkpointing_test.py",
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
@ -520,7 +519,6 @@ distribute_py_test(
|
||||
name = "keras_save_load_test",
|
||||
size = "medium",
|
||||
srcs = ["keras_save_load_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
full_precision = True,
|
||||
main = "keras_save_load_test.py",
|
||||
shard_count = 7,
|
||||
@ -790,7 +788,6 @@ distribute_py_test(
|
||||
name = "saved_model_save_load_test",
|
||||
size = "medium",
|
||||
srcs = ["saved_model_save_load_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
full_precision = True,
|
||||
main = "saved_model_save_load_test.py",
|
||||
shard_count = 7,
|
||||
@ -808,7 +805,6 @@ distribute_py_test(
|
||||
name = "saved_model_mixed_api_test",
|
||||
size = "medium",
|
||||
srcs = ["saved_model_mixed_api_test.py"],
|
||||
disable_mlir_bridge = False,
|
||||
full_precision = True,
|
||||
main = "saved_model_mixed_api_test.py",
|
||||
shard_count = 7,
|
||||
|
@ -57,7 +57,6 @@ tpu_py_test(
|
||||
"automatic_outside_compilation_test.py",
|
||||
],
|
||||
disable_experimental = True,
|
||||
disable_mlir_bridge = False,
|
||||
python_version = "PY3",
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user