Add uninitialized resource args to unsupported features for MLIR bridge

We currently don't support uninitialized resource variable arguments in the
MLIR bridge. With this change, we now take this into account as an unsupported
feature in the MLIR rollout policy which controls when the MLIR bridge is run.

PiperOrigin-RevId: 357854490
Change-Id: I216bd3a0a05488906798dac092ff95485242105a
This commit is contained in:
Michael Gester 2021-02-16 19:06:58 -08:00 committed by TensorFlower Gardener
parent a26dc6a474
commit 7e15dc2f89
10 changed files with 42 additions and 12 deletions

View File

@ -286,7 +286,9 @@ Status XlaCompilationCache::CompileSingleOp(
const ConfigProto* config = ctx->function_library()->config_proto();
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
bool use_mlir = config &&
GetMlirBridgeRolloutPolicy(*graph, *config) ==
GetMlirBridgeRolloutPolicy(
*graph, *config, /*uses_uninitialized_resource_args=*/
AnyUninitializedResourceArg(args)) ==
MlirBridgeRolloutPolicy::kEnabledByUser &&
node_def.op() != "VarIsInitializedOp";
if (!use_mlir) {

View File

@ -103,8 +103,15 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
if (flr->config_proto()) {
config_proto = *flr->config_proto();
}
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
// There is no easy way to check if we have uninitialized resource args here
// so we assume there are uninitialized resource args. This means that we
// might run the compilability checker in cases where we don't need to (when
// MLIR bridge is run later). Note that this is just temporary until
// b/171732021 gets fixed.
// We should also revisit if this check provides any value, otherwise we
// should remove it.
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
*fbody->graph, config_proto, /*uses_uninitialized_resource_args=*/true);
if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {

View File

@ -52,7 +52,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto,
bool record_stats) {
bool uses_uninitialized_resource_args, bool record_stats) {
switch (GetUserRequest(config_proto)) {
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
return MlirBridgeRolloutPolicy::kEnabledByUser;

View File

@ -51,7 +51,7 @@ enum class MlirBridgeRolloutPolicy {
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
const tensorflow::Graph& graph,
absl::optional<tensorflow::ConfigProto> config_proto,
bool record_stats = false);
bool uses_uninitialized_resource_args, bool record_stats = false);
} // namespace tensorflow

View File

@ -166,7 +166,11 @@ Status MlirFunctionOptimizationPass::Run(
// TODO(b/176852151): Remove this after dark launch completed.
// Capture stats relevant to graph properties used in dark launch.
GetMlirBridgeRolloutPolicy(**graph, config_proto, /*record_stats=*/true);
// We set `uses_uninitialized_resource_args` to false here because function
// optimization is not affected by uninitialized resource args.
GetMlirBridgeRolloutPolicy(**graph, config_proto,
/*uses_uninitialized_resource_args=*/false,
/*record_stats=*/true);
if (overall_state == MlirOptimizationPassState::Disabled) {
LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled "

View File

@ -558,6 +558,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)

View File

@ -91,8 +91,10 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
return MlirOptimizationPassState::Disabled;
}
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(graph, config_proto);
// We set `uses_uninitialized_resource_args` to false here because the first
// phase of the bridge is not affected by uninitialized resource args.
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
graph, config_proto, /*uses_uninitialized_resource_args=*/false);
switch (policy) {
case MlirBridgeRolloutPolicy::kEnabledByUser:
return MlirOptimizationPassState::Enabled;
@ -144,8 +146,10 @@ bool MlirBridgeV1CompatPass::IsEnabled(const DeviceSet* device_set,
// Do not run the bridge if it's enabled by the graph analysis,
// only run if it's enabled by the user explicitly.
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(graph, config_proto);
// We set `uses_uninitialized_resource_args` to false here because the first
// phase of the bridge is not affected by uninitialized resource args.
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
graph, config_proto, /*uses_uninitialized_resource_args=*/false);
return policy == MlirBridgeRolloutPolicy::kEnabledByUser;
}

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "llvm/ADT/STLExtras.h"
namespace tensorflow {
bool XlaArgument::operator==(const XlaArgument& other) const {
@ -50,4 +52,10 @@ bool XlaArgument::operator==(const XlaArgument& other) const {
return constant_value.tensor_data() == other.constant_value.tensor_data();
}
bool AnyUninitializedResourceArg(absl::Span<const XlaArgument> args) {
return llvm::any_of(args, [](const XlaArgument& arg) {
return arg.kind == XlaArgument::kResource && arg.type == DT_INVALID;
});
}
} // end namespace tensorflow

View File

@ -119,6 +119,9 @@ struct XlaArgument {
string ShapeHumanString() const;
};
// Returns true if any of `args` is an uninitialized resource variable.
bool AnyUninitializedResourceArg(absl::Span<const XlaArgument> args);
} // end namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_

View File

@ -804,8 +804,9 @@ Status XlaCompiler::CompileFunction(
}
VLOG(1) << "====================================================";
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(*graph, config_proto);
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
*graph, config_proto,
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;