diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index d7d5ee02265..c97e71bf800 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -283,25 +283,27 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); + // TODO(b/155596779): Support TensorList args. bool has_tensor_list_arg = absl::c_any_of(args, [](const XlaCompiler::Argument arg) { return arg.kind == XlaCompiler::Argument::kTensorList; }); const ConfigProto* config = ctx->function_library()->config_proto(); - bool use_mlir = config && config->experimental().enable_mlir_bridge(); + bool use_mlir = config && config->experimental().enable_mlir_bridge() && + !has_tensor_list_arg; #ifdef LIBTPU_ON_GCE - if (use_mlir && has_tensor_list_arg) { + if (use_mlir) { LOG(WARNING) << "MLIR is not supported in this environment."; } return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); #else - // TODO(b/155596779): Support TensorList args. - if (!use_mlir || !has_tensor_list_arg) { + if (!use_mlir) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } + VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; std::vector control_rets; if (result_dtypes.empty()) { diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 1dfcf88e654..622a1ff3fd8 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -327,7 +327,6 @@ tf_xla_py_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -393,7 +392,6 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -416,7 +414,6 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -639,7 +636,6 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -696,7 +692,6 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -1018,7 +1013,6 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1089,7 +1083,6 @@ tf_xla_py_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1391,7 +1384,6 @@ tf_xla_py_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1548,7 +1540,6 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 1, # Times out in fastbuild mode. @@ -1790,7 +1781,6 @@ tf_xla_py_test( name = "fake_quant_ops_test", size = "medium", srcs = ["fake_quant_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index b8e8098bb6e..9c4da2becb6 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -935,7 +935,6 @@ cuda_py_test( distribute_py_test( name = "checkpointing_test", srcs = ["checkpointing_test.py"], - disable_mlir_bridge = False, main = "checkpointing_test.py", tags = [ "multi_and_single_gpu", @@ -1154,7 +1153,6 @@ distribute_py_test( name = "values_test", size = "medium", srcs = ["values_test.py"], - disable_mlir_bridge = False, main = "values_test.py", shard_count = 5, tags = [ @@ -1302,7 +1300,6 @@ distribute_py_test( distribute_py_test( name = "moving_averages_test", srcs = ["moving_averages_test.py"], - disable_mlir_bridge = False, main = "moving_averages_test.py", deps = [ ":combinations", diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 79c36c0f559..cc2622d4eb1 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -150,7 +150,6 @@ cuda_py_test( distribute_py_test( name = "checkpointing_test", srcs = ["checkpointing_test.py"], - disable_mlir_bridge = False, main = "checkpointing_test.py", tags = [ "multi_and_single_gpu", @@ -520,7 +519,6 @@ distribute_py_test( name = "keras_save_load_test", size = "medium", srcs = ["keras_save_load_test.py"], - disable_mlir_bridge = False, full_precision = True, main = "keras_save_load_test.py", shard_count = 7, @@ -790,7 +788,6 @@ distribute_py_test( name = "saved_model_save_load_test", size = "medium", srcs = ["saved_model_save_load_test.py"], - disable_mlir_bridge = False, full_precision = True, main = "saved_model_save_load_test.py", shard_count = 7, @@ -808,7 +805,6 @@ distribute_py_test( name = "saved_model_mixed_api_test", size = "medium", srcs = ["saved_model_mixed_api_test.py"], - disable_mlir_bridge = False, full_precision = True, main = "saved_model_mixed_api_test.py", shard_count = 7, diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index 0ac8abcbe18..bb9290b3e6f 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -57,7 +57,6 @@ tpu_py_test( "automatic_outside_compilation_test.py", ], disable_experimental = True, - disable_mlir_bridge = False, python_version = "PY3", tags = ["no_oss"], deps = [