Fix enable MLIR bridge logic in xla_compiler
Update xla_compiler to pass config_proto to GetMlirBridgeRolloutPolicy so the logic is consistent with other calls to GetMlirBridgeRolloutPolicy. PiperOrigin-RevId: 344898871 Change-Id: I0fd2fa933db4546d9e7b8a153da2390371d56c31
This commit is contained in:
parent
a11003dcc3
commit
5994e13bbd
@ -350,6 +350,7 @@ cc_library(
|
||||
":xla_helpers",
|
||||
":xla_op_registry",
|
||||
":xla_resource",
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
||||
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/variant.h"
|
||||
@ -551,7 +552,8 @@ static Status GetFunctionBody(const NameAttrList& function,
|
||||
}
|
||||
|
||||
Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
|
||||
const FunctionBody** fbody) {
|
||||
const FunctionBody** fbody,
|
||||
const ConfigProto** config_proto) {
|
||||
// The function may be in either the local_flib_runtime_ or flib_runtime_.
|
||||
// Look up the function in local first and if it is not found then look up the
|
||||
// function in flib_runtime_.
|
||||
@ -563,8 +565,14 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
GetFunctionBody(function, flib_runtime_, fbody),
|
||||
"Local lookup failed with: ", status.error_message());
|
||||
if (config_proto) {
|
||||
*config_proto = flib_runtime_->config_proto();
|
||||
}
|
||||
VLOG(4) << "Function " << function.name() << " in flib_runtime_";
|
||||
} else {
|
||||
if (config_proto) {
|
||||
*config_proto = local_flib_runtime_->config_proto();
|
||||
}
|
||||
VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
|
||||
}
|
||||
return Status::OK();
|
||||
@ -728,7 +736,13 @@ Status XlaCompiler::CompileFunction(
|
||||
}
|
||||
|
||||
const FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody));
|
||||
const ConfigProto* config = nullptr;
|
||||
TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
|
||||
|
||||
absl::optional<ConfigProto> config_proto;
|
||||
if (config) {
|
||||
config_proto = *config;
|
||||
}
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
CheckSignature(fbody->arg_types, args),
|
||||
@ -789,16 +803,16 @@ Status XlaCompiler::CompileFunction(
|
||||
}
|
||||
|
||||
VLOG(1) << "====================================================";
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(*graph, config_proto);
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
|
||||
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
|
||||
VLOG(1) << "MLIR is not supported in this environment.";
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(options, function_id, std::move(graph), args, result));
|
||||
#else
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
|
||||
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
|
||||
|
||||
@ -291,7 +291,8 @@ class XlaCompiler {
|
||||
|
||||
// Sets the function body `fbody` to the one registered as `function`.
|
||||
Status FindFunctionBody(const NameAttrList& function,
|
||||
const FunctionBody** fbody);
|
||||
const FunctionBody** fbody,
|
||||
const ConfigProto** config_proto = nullptr);
|
||||
|
||||
private:
|
||||
// Returns the optimized graph object in this function body.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user