diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 61ff6bc191f..a7d4e054e9f 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -286,7 +286,9 @@ Status XlaCompilationCache::CompileSingleOp( const ConfigProto* config = ctx->function_library()->config_proto(); // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR. bool use_mlir = config && - GetMlirBridgeRolloutPolicy(*graph, *config) == + GetMlirBridgeRolloutPolicy( + *graph, *config, /*uses_uninitialized_resource_args=*/ + AnyUninitializedResourceArg(args)) == MlirBridgeRolloutPolicy::kEnabledByUser && node_def.op() != "VarIsInitializedOp"; if (!use_mlir) { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 602c2d2cbfe..054ac9987c6 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -103,8 +103,15 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, if (flr->config_proto()) { config_proto = *flr->config_proto(); } - MlirBridgeRolloutPolicy policy = - GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto); + // There is no easy way to check if we have uninitialized resource args here + // so we assume there are uninitialized resource args. This means that we + // might run the compilability checker in cases where we don't need to (when + // MLIR bridge is run later). Note that this is just temporary until + // b/171732021 gets fixed. + // We should also revisit if this check provides any value, otherwise we + // should remove it. + MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( + *fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true); if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) { RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { diff --git a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc index ac3e59eb8a8..a3239aa9421 100644 --- a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc +++ b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.cc @@ -52,7 +52,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest( MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, absl::optional config_proto, - bool record_stats) { + bool uses_uninitialized_resource_args, bool record_stats) { switch (GetUserRequest(config_proto)) { case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED: return MlirBridgeRolloutPolicy::kEnabledByUser; diff --git a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h index f029ad8c8c2..4fff6b7081e 100644 --- a/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h +++ b/tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h @@ -51,7 +51,7 @@ enum class MlirBridgeRolloutPolicy { MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( const tensorflow::Graph& graph, absl::optional config_proto, - bool record_stats = false); + bool uses_uninitialized_resource_args, bool record_stats = false); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index e0c8a97a8a4..07234718dc5 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -166,7 +166,11 @@ Status MlirFunctionOptimizationPass::Run( // TODO(b/176852151): Remove this after dark launch completed. // Capture stats relevant to graph properties used in dark launch. - GetMlirBridgeRolloutPolicy(**graph, config_proto, /*record_stats=*/true); + // We set `uses_uninitialized_resource_args` to false here because function + // optimization is not affected by uninitialized resource args. + GetMlirBridgeRolloutPolicy(**graph, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*record_stats=*/true); if (overall_state == MlirOptimizationPassState::Disabled) { LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled " diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 31329590f1d..5d551753654 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -558,6 +558,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index eff56d5d528..f58001f3144 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -91,8 +91,10 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( return MlirOptimizationPassState::Disabled; } - MlirBridgeRolloutPolicy policy = - GetMlirBridgeRolloutPolicy(graph, config_proto); + // We set `uses_uninitialized_resource_args` to false here because the first + // phase of the bridge is not affected by uninitialized resource args. + MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( + graph, config_proto, /*uses_uninitialized_resource_args=*/false); switch (policy) { case MlirBridgeRolloutPolicy::kEnabledByUser: return MlirOptimizationPassState::Enabled; @@ -144,8 +146,10 @@ bool MlirBridgeV1CompatPass::IsEnabled(const DeviceSet* device_set, // Do not run the bridge if it's enabled by the graph analysis, // only run if it's enabled by the user explicitly. - MlirBridgeRolloutPolicy policy = - GetMlirBridgeRolloutPolicy(graph, config_proto); + // We set `uses_uninitialized_resource_args` to false here because the first + // phase of the bridge is not affected by uninitialized resource args. + MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( + graph, config_proto, /*uses_uninitialized_resource_args=*/false); return policy == MlirBridgeRolloutPolicy::kEnabledByUser; } diff --git a/tensorflow/compiler/tf2xla/xla_argument.cc b/tensorflow/compiler/tf2xla/xla_argument.cc index fe31025386e..8b91dd3870b 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.cc +++ b/tensorflow/compiler/tf2xla/xla_argument.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "llvm/ADT/STLExtras.h" + namespace tensorflow { bool XlaArgument::operator==(const XlaArgument& other) const { @@ -50,4 +52,10 @@ bool XlaArgument::operator==(const XlaArgument& other) const { return constant_value.tensor_data() == other.constant_value.tensor_data(); } +bool AnyUninitializedResourceArg(absl::Span args) { + return llvm::any_of(args, [](const XlaArgument& arg) { + return arg.kind == XlaArgument::kResource && arg.type == DT_INVALID; + }); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index c304c479f87..4b785bb9f29 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -119,6 +119,9 @@ struct XlaArgument { string ShapeHumanString() const; }; +// Returns true if any of `args` is an uninitialized resource variable. +bool AnyUninitializedResourceArg(absl::Span args); + } // end namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c04d1e02658..b566cc8d6b9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -804,8 +804,9 @@ Status XlaCompiler::CompileFunction( } VLOG(1) << "===================================================="; - MlirBridgeRolloutPolicy policy = - GetMlirBridgeRolloutPolicy(*graph, config_proto); + MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( + *graph, config_proto, + /*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args)); if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) { VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info;