Fix enable MLIR bridge logic in xla_compiler

Update xla_compiler to pass config_proto to
GetMlirBridgeRolloutPolicy so the logic is consistent with other
calls to GetMlirBridgeRolloutPolicy.

PiperOrigin-RevId: 344898871
Change-Id: I0fd2fa933db4546d9e7b8a153da2390371d56c31
This commit is contained in:
Marissa Ikonomidis 2020-11-30 15:18:11 -08:00 committed by TensorFlower Gardener
parent a11003dcc3
commit 5994e13bbd
3 changed files with 23 additions and 7 deletions

View File

@ -350,6 +350,7 @@ cc_library(
":xla_helpers",
":xla_op_registry",
":xla_resource",
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/variant.h"
@ -551,7 +552,8 @@ static Status GetFunctionBody(const NameAttrList& function,
}
Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
const FunctionBody** fbody) {
const FunctionBody** fbody,
const ConfigProto** config_proto) {
// The function may be in either the local_flib_runtime_ or flib_runtime_.
// Look up the function in local first and if it is not found then look up the
// function in flib_runtime_.
@ -563,8 +565,14 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
TF_RETURN_WITH_CONTEXT_IF_ERROR(
GetFunctionBody(function, flib_runtime_, fbody),
"Local lookup failed with: ", status.error_message());
if (config_proto) {
*config_proto = flib_runtime_->config_proto();
}
VLOG(4) << "Function " << function.name() << " in flib_runtime_";
} else {
if (config_proto) {
*config_proto = local_flib_runtime_->config_proto();
}
VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
}
return Status::OK();
@ -728,7 +736,13 @@ Status XlaCompiler::CompileFunction(
}
const FunctionBody* fbody;
TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody));
const ConfigProto* config = nullptr;
TF_RETURN_IF_ERROR(FindFunctionBody(fn_name_attrs, &fbody, &config));
absl::optional<ConfigProto> config_proto;
if (config) {
config_proto = *config;
}
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckSignature(fbody->arg_types, args),
@ -789,16 +803,16 @@ Status XlaCompiler::CompileFunction(
}
VLOG(1) << "====================================================";
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(*graph, config_proto);
#ifdef LIBTPU_ON_GCE
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "MLIR is not supported in this environment.";
}
TF_RETURN_IF_ERROR(
CompileGraph(options, function_id, std::move(graph), args, result));
#else
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;

View File

@ -291,7 +291,8 @@ class XlaCompiler {
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
const FunctionBody** fbody);
const FunctionBody** fbody,
const ConfigProto** config_proto = nullptr);
private:
// Returns the optimized graph object in this function body.