Fix logic for enabling MLIR bridge depending on has_tensor_list_arg.
It appears that the polarity of the use of has_tensor_list_arg was inadvertently flipped. Disable any MLIR bridge enabled tests that were passing because they weren't using the MLIR bridge due to this issue. PiperOrigin-RevId: 337125651 Change-Id: I93e9e61acda9a2aeffaee5cce13e93635d33f5a4
This commit is contained in:
parent
7e54bf3113
commit
b881485eb5
@ -283,25 +283,27 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
const NodeDef& node_def = ctx->op_kernel().def();
|
const NodeDef& node_def = ctx->op_kernel().def();
|
||||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
||||||
|
|
||||||
|
// TODO(b/155596779): Support TensorList args.
|
||||||
bool has_tensor_list_arg =
|
bool has_tensor_list_arg =
|
||||||
absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
|
absl::c_any_of(args, [](const XlaCompiler::Argument arg) {
|
||||||
return arg.kind == XlaCompiler::Argument::kTensorList;
|
return arg.kind == XlaCompiler::Argument::kTensorList;
|
||||||
});
|
});
|
||||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
bool use_mlir = config && config->experimental().enable_mlir_bridge() &&
|
||||||
|
!has_tensor_list_arg;
|
||||||
#ifdef LIBTPU_ON_GCE
|
#ifdef LIBTPU_ON_GCE
|
||||||
if (use_mlir && has_tensor_list_arg) {
|
if (use_mlir) {
|
||||||
LOG(WARNING) << "MLIR is not supported in this environment.";
|
LOG(WARNING) << "MLIR is not supported in this environment.";
|
||||||
}
|
}
|
||||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||||
std::move(graph), args, result);
|
std::move(graph), args, result);
|
||||||
#else
|
#else
|
||||||
// TODO(b/155596779): Support TensorList args.
|
if (!use_mlir) {
|
||||||
if (!use_mlir || !has_tensor_list_arg) {
|
|
||||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||||
std::move(graph), args, result);
|
std::move(graph), args, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VLOG(1) << "Using MLIR bridge";
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
std::vector<std::string> control_rets;
|
std::vector<std::string> control_rets;
|
||||||
if (result_dtypes.empty()) {
|
if (result_dtypes.empty()) {
|
||||||
|
@ -327,7 +327,6 @@ tf_xla_py_test(
|
|||||||
name = "self_adjoint_eig_op_test",
|
name = "self_adjoint_eig_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["self_adjoint_eig_op_test.py"],
|
srcs = ["self_adjoint_eig_op_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -393,7 +392,6 @@ tf_xla_py_test(
|
|||||||
size = "small",
|
size = "small",
|
||||||
timeout = "moderate",
|
timeout = "moderate",
|
||||||
srcs = ["matrix_inverse_op_test.py"],
|
srcs = ["matrix_inverse_op_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -416,7 +414,6 @@ tf_xla_py_test(
|
|||||||
size = "small",
|
size = "small",
|
||||||
timeout = "moderate",
|
timeout = "moderate",
|
||||||
srcs = ["matrix_solve_op_test.py"],
|
srcs = ["matrix_solve_op_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -639,7 +636,6 @@ tf_xla_py_test(
|
|||||||
name = "extract_image_patches_op_test",
|
name = "extract_image_patches_op_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["extract_image_patches_op_test.py"],
|
srcs = ["extract_image_patches_op_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -696,7 +692,6 @@ tf_xla_py_test(
|
|||||||
name = "fft_test",
|
name = "fft_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["fft_test.py"],
|
srcs = ["fft_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 6,
|
shard_count = 6,
|
||||||
tags = [
|
tags = [
|
||||||
@ -1018,7 +1013,6 @@ tf_xla_py_test(
|
|||||||
"cpu",
|
"cpu",
|
||||||
"cpu_ondemand",
|
"cpu_ondemand",
|
||||||
],
|
],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = [
|
tags = [
|
||||||
@ -1089,7 +1083,6 @@ tf_xla_py_test(
|
|||||||
name = "reduce_ops_test",
|
name = "reduce_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["reduce_ops_test.py"],
|
srcs = ["reduce_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = [
|
tags = [
|
||||||
@ -1391,7 +1384,6 @@ tf_xla_py_test(
|
|||||||
name = "unary_ops_test",
|
name = "unary_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["unary_ops_test.py"],
|
srcs = ["unary_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
@ -1548,7 +1540,6 @@ tf_xla_py_test(
|
|||||||
name = "sort_ops_test",
|
name = "sort_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["sort_ops_test.py"],
|
srcs = ["sort_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 1,
|
shard_count = 1,
|
||||||
# Times out in fastbuild mode.
|
# Times out in fastbuild mode.
|
||||||
@ -1790,7 +1781,6 @@ tf_xla_py_test(
|
|||||||
name = "fake_quant_ops_test",
|
name = "fake_quant_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["fake_quant_ops_test.py"],
|
srcs = ["fake_quant_ops_test.py"],
|
||||||
enable_mlir_bridge = True,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
|
@ -935,7 +935,6 @@ cuda_py_test(
|
|||||||
distribute_py_test(
|
distribute_py_test(
|
||||||
name = "checkpointing_test",
|
name = "checkpointing_test",
|
||||||
srcs = ["checkpointing_test.py"],
|
srcs = ["checkpointing_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
main = "checkpointing_test.py",
|
main = "checkpointing_test.py",
|
||||||
tags = [
|
tags = [
|
||||||
"multi_and_single_gpu",
|
"multi_and_single_gpu",
|
||||||
@ -1154,7 +1153,6 @@ distribute_py_test(
|
|||||||
name = "values_test",
|
name = "values_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["values_test.py"],
|
srcs = ["values_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
main = "values_test.py",
|
main = "values_test.py",
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = [
|
tags = [
|
||||||
@ -1302,7 +1300,6 @@ distribute_py_test(
|
|||||||
distribute_py_test(
|
distribute_py_test(
|
||||||
name = "moving_averages_test",
|
name = "moving_averages_test",
|
||||||
srcs = ["moving_averages_test.py"],
|
srcs = ["moving_averages_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
main = "moving_averages_test.py",
|
main = "moving_averages_test.py",
|
||||||
deps = [
|
deps = [
|
||||||
":combinations",
|
":combinations",
|
||||||
|
@ -150,7 +150,6 @@ cuda_py_test(
|
|||||||
distribute_py_test(
|
distribute_py_test(
|
||||||
name = "checkpointing_test",
|
name = "checkpointing_test",
|
||||||
srcs = ["checkpointing_test.py"],
|
srcs = ["checkpointing_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
main = "checkpointing_test.py",
|
main = "checkpointing_test.py",
|
||||||
tags = [
|
tags = [
|
||||||
"multi_and_single_gpu",
|
"multi_and_single_gpu",
|
||||||
@ -520,7 +519,6 @@ distribute_py_test(
|
|||||||
name = "keras_save_load_test",
|
name = "keras_save_load_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["keras_save_load_test.py"],
|
srcs = ["keras_save_load_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
full_precision = True,
|
full_precision = True,
|
||||||
main = "keras_save_load_test.py",
|
main = "keras_save_load_test.py",
|
||||||
shard_count = 7,
|
shard_count = 7,
|
||||||
@ -790,7 +788,6 @@ distribute_py_test(
|
|||||||
name = "saved_model_save_load_test",
|
name = "saved_model_save_load_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["saved_model_save_load_test.py"],
|
srcs = ["saved_model_save_load_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
full_precision = True,
|
full_precision = True,
|
||||||
main = "saved_model_save_load_test.py",
|
main = "saved_model_save_load_test.py",
|
||||||
shard_count = 7,
|
shard_count = 7,
|
||||||
@ -808,7 +805,6 @@ distribute_py_test(
|
|||||||
name = "saved_model_mixed_api_test",
|
name = "saved_model_mixed_api_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["saved_model_mixed_api_test.py"],
|
srcs = ["saved_model_mixed_api_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
full_precision = True,
|
full_precision = True,
|
||||||
main = "saved_model_mixed_api_test.py",
|
main = "saved_model_mixed_api_test.py",
|
||||||
shard_count = 7,
|
shard_count = 7,
|
||||||
|
@ -57,7 +57,6 @@ tpu_py_test(
|
|||||||
"automatic_outside_compilation_test.py",
|
"automatic_outside_compilation_test.py",
|
||||||
],
|
],
|
||||||
disable_experimental = True,
|
disable_experimental = True,
|
||||||
disable_mlir_bridge = False,
|
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = ["no_oss"],
|
tags = ["no_oss"],
|
||||||
deps = [
|
deps = [
|
||||||
|
Loading…
x
Reference in New Issue
Block a user