Add uninitialized resource args to unsupported features for MLIR bridge
We currently don't support uninitialized resource variable arguments in the MLIR bridge. With this change, we now take this into account as an unsupported feature in the MLIR rollout policy which controls when the MLIR bridge is run. PiperOrigin-RevId: 357854490 Change-Id: I216bd3a0a05488906798dac092ff95485242105a
This commit is contained in:
parent
a26dc6a474
commit
7e15dc2f89
@ -286,7 +286,9 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
|
||||
bool use_mlir = config &&
|
||||
GetMlirBridgeRolloutPolicy(*graph, *config) ==
|
||||
GetMlirBridgeRolloutPolicy(
|
||||
*graph, *config, /*uses_uninitialized_resource_args=*/
|
||||
AnyUninitializedResourceArg(args)) ==
|
||||
MlirBridgeRolloutPolicy::kEnabledByUser &&
|
||||
node_def.op() != "VarIsInitializedOp";
|
||||
if (!use_mlir) {
|
||||
|
@ -103,8 +103,15 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
if (flr->config_proto()) {
|
||||
config_proto = *flr->config_proto();
|
||||
}
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
|
||||
// There is no easy way to check if we have uninitialized resource args here
|
||||
// so we assume there are uninitialized resource args. This means that we
|
||||
// might run the compilability checker in cases where we don't need to (when
|
||||
// MLIR bridge is run later). Note that this is just temporary until
|
||||
// b/171732021 gets fixed.
|
||||
// We should also revisit if this check provides any value, otherwise we
|
||||
// should remove it.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
|
||||
*fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true);
|
||||
if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||
|
@ -52,7 +52,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
|
||||
|
||||
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto,
|
||||
bool record_stats) {
|
||||
bool uses_uninitialized_resource_args, bool record_stats) {
|
||||
switch (GetUserRequest(config_proto)) {
|
||||
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
|
||||
return MlirBridgeRolloutPolicy::kEnabledByUser;
|
||||
|
@ -51,7 +51,7 @@ enum class MlirBridgeRolloutPolicy {
|
||||
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph,
|
||||
absl::optional<tensorflow::ConfigProto> config_proto,
|
||||
bool record_stats = false);
|
||||
bool uses_uninitialized_resource_args, bool record_stats = false);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -166,7 +166,11 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
|
||||
// TODO(b/176852151): Remove this after dark launch completed.
|
||||
// Capture stats relevant to graph properties used in dark launch.
|
||||
GetMlirBridgeRolloutPolicy(**graph, config_proto, /*record_stats=*/true);
|
||||
// We set `uses_uninitialized_resource_args` to false here because function
|
||||
// optimization is not affected by uninitialized resource args.
|
||||
GetMlirBridgeRolloutPolicy(**graph, config_proto,
|
||||
/*uses_uninitialized_resource_args=*/false,
|
||||
/*record_stats=*/true);
|
||||
|
||||
if (overall_state == MlirOptimizationPassState::Disabled) {
|
||||
LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled "
|
||||
|
@ -558,6 +558,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -91,8 +91,10 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
|
||||
return MlirOptimizationPassState::Disabled;
|
||||
}
|
||||
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(graph, config_proto);
|
||||
// We set `uses_uninitialized_resource_args` to false here because the first
|
||||
// phase of the bridge is not affected by uninitialized resource args.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
|
||||
graph, config_proto, /*uses_uninitialized_resource_args=*/false);
|
||||
switch (policy) {
|
||||
case MlirBridgeRolloutPolicy::kEnabledByUser:
|
||||
return MlirOptimizationPassState::Enabled;
|
||||
@ -144,8 +146,10 @@ bool MlirBridgeV1CompatPass::IsEnabled(const DeviceSet* device_set,
|
||||
|
||||
// Do not run the bridge if it's enabled by the graph analysis,
|
||||
// only run if it's enabled by the user explicitly.
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(graph, config_proto);
|
||||
// We set `uses_uninitialized_resource_args` to false here because the first
|
||||
// phase of the bridge is not affected by uninitialized resource args.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
|
||||
graph, config_proto, /*uses_uninitialized_resource_args=*/false);
|
||||
return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool XlaArgument::operator==(const XlaArgument& other) const {
|
||||
@ -50,4 +52,10 @@ bool XlaArgument::operator==(const XlaArgument& other) const {
|
||||
return constant_value.tensor_data() == other.constant_value.tensor_data();
|
||||
}
|
||||
|
||||
bool AnyUninitializedResourceArg(absl::Span<const XlaArgument> args) {
|
||||
return llvm::any_of(args, [](const XlaArgument& arg) {
|
||||
return arg.kind == XlaArgument::kResource && arg.type == DT_INVALID;
|
||||
});
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -119,6 +119,9 @@ struct XlaArgument {
|
||||
string ShapeHumanString() const;
|
||||
};
|
||||
|
||||
// Returns true if any of `args` is an uninitialized resource variable.
|
||||
bool AnyUninitializedResourceArg(absl::Span<const XlaArgument> args);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
|
||||
|
@ -804,8 +804,9 @@ Status XlaCompiler::CompileFunction(
|
||||
}
|
||||
|
||||
VLOG(1) << "====================================================";
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(*graph, config_proto);
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
|
||||
*graph, config_proto,
|
||||
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
|
||||
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
|
Loading…
x
Reference in New Issue
Block a user