Lookup functions in optional FunctionLibraryDefinition when checking for supported features for MLIR bridge.
FunctionOptimizationPass passes in an additional FunctionLibraryDefinition in addition to the one attached to the graph. In some cases, this FunctionLibraryDefinition contains the function definitions that the passed in graph does not contain. PiperOrigin-RevId: 361161819 Change-Id: I9719b61073f45942cbf6210fd03f13a6bd361246
This commit is contained in:
parent
017ec03f9b
commit
901112863e
@ -285,7 +285,8 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
|
||||
bool use_mlir = config &&
|
||||
GetMlirBridgeRolloutPolicy(
|
||||
*graph, *config, /*uses_uninitialized_resource_args=*/
|
||||
*graph, /*function_library=*/nullptr,
|
||||
*config, /*uses_uninitialized_resource_args=*/
|
||||
AnyUninitializedResourceArg(args)) ==
|
||||
MlirBridgeRolloutPolicy::kEnabledByUser &&
|
||||
node_def.op() != "VarIsInitializedOp";
|
||||
|
@ -51,7 +51,9 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
|
||||
}
|
||||
|
||||
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto,
|
||||
const tensorflow::Graph& graph,
|
||||
const FunctionLibraryDefinition* function_library,
|
||||
absl::optional<ConfigProto> config_proto,
|
||||
bool uses_uninitialized_resource_args, bool record_stats) {
|
||||
switch (GetUserRequest(config_proto)) {
|
||||
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
|
||||
|
@ -50,6 +50,7 @@ enum class MlirBridgeRolloutPolicy {
|
||||
// to decide whether to emit metrics on unsupported features of the graph.
|
||||
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph,
|
||||
const FunctionLibraryDefinition* function_library,
|
||||
absl::optional<tensorflow::ConfigProto> config_proto,
|
||||
bool uses_uninitialized_resource_args, bool record_stats = false);
|
||||
|
||||
|
@ -137,7 +137,7 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
num_passes_shadow_enabled = 0, num_passes_fallback_enabled = 0;
|
||||
for (const auto& pass_registration : registry_->passes()) {
|
||||
MlirOptimizationPassState pass_state = pass_registration.pass->GetPassState(
|
||||
&device_set, config_proto, **graph);
|
||||
&device_set, config_proto, **graph, *flib_def);
|
||||
per_pass_state.push_back(pass_state);
|
||||
switch (pass_state) {
|
||||
case MlirOptimizationPassState::ShadowEnabled: {
|
||||
@ -168,7 +168,7 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
// Capture stats relevant to graph properties used in dark launch.
|
||||
// We set `uses_uninitialized_resource_args` to false here because function
|
||||
// optimization is not affected by uninitialized resource args.
|
||||
GetMlirBridgeRolloutPolicy(**graph, config_proto,
|
||||
GetMlirBridgeRolloutPolicy(**graph, flib_def, config_proto,
|
||||
/*uses_uninitialized_resource_args=*/false,
|
||||
/*record_stats=*/true);
|
||||
|
||||
@ -235,8 +235,8 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
if (pass_state == MlirOptimizationPassState::Enabled ||
|
||||
(pass_state == MlirOptimizationPassState::ShadowEnabled &&
|
||||
overall_state == MlirOptimizationPassState::ShadowEnabled)) {
|
||||
pass_status =
|
||||
pass_registration.pass->Run(config_proto, *module_ref, **graph);
|
||||
pass_status = pass_registration.pass->Run(config_proto, *module_ref,
|
||||
**graph, *flib_def);
|
||||
} else if (pass_state == MlirOptimizationPassState::ShadowEnabled ||
|
||||
pass_state == MlirOptimizationPassState::FallbackEnabled) {
|
||||
// Make sure when the pass is:
|
||||
@ -244,8 +244,8 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
// FallbackEnabled, it only modifies the MLIR module in case of
|
||||
// no failures.
|
||||
auto module_ref_clone = module_ref->clone();
|
||||
pass_status =
|
||||
pass_registration.pass->Run(config_proto, module_ref_clone, **graph);
|
||||
pass_status = pass_registration.pass->Run(config_proto, module_ref_clone,
|
||||
**graph, *flib_def);
|
||||
if (pass_state == MlirOptimizationPassState::FallbackEnabled &&
|
||||
pass_status.ok()) {
|
||||
module_ref = module_ref_clone;
|
||||
@ -329,8 +329,9 @@ Status MlirV1CompatGraphOptimizationPass::Run(
|
||||
if (options.is_function_graph || !registry_->pass()) return Status::OK();
|
||||
|
||||
auto pass = registry_->pass();
|
||||
auto pass_state = pass->GetPassState(
|
||||
options.device_set, options.session_options->config, **options.graph);
|
||||
auto pass_state =
|
||||
pass->GetPassState(options.device_set, options.session_options->config,
|
||||
**options.graph, *options.flib_def);
|
||||
|
||||
// Do not run V1 compatibility pass in shadow mode.
|
||||
if (pass_state == MlirOptimizationPassState::Disabled ||
|
||||
|
@ -63,12 +63,16 @@ class MlirOptimizationPass {
|
||||
// module will be committed.
|
||||
// `device_set` can be nullptr if the devices information is not
|
||||
// available or no device specific filtering is required.
|
||||
// `function_library` contains function definitions for function calls in
|
||||
// `graph` not included in the `graph` FunctionLibraryDefinition.
|
||||
virtual MlirOptimizationPassState GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const = 0;
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const = 0;
|
||||
|
||||
virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph) = 0;
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) = 0;
|
||||
};
|
||||
|
||||
class MlirOptimizationPassRegistry {
|
||||
@ -159,7 +163,8 @@ class MlirV1CompatOptimizationPass {
|
||||
// on exact values.
|
||||
virtual MlirOptimizationPassState GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const = 0;
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const = 0;
|
||||
|
||||
virtual Status Run(const GraphOptimizationPassOptions& options,
|
||||
mlir::ModuleOp module) = 0;
|
||||
|
@ -32,12 +32,14 @@ class MockMlirOptimizationPass : public MlirOptimizationPass {
|
||||
// MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
|
||||
// instead.
|
||||
MOCK_CONST_METHOD0(name, llvm::StringRef());
|
||||
MOCK_CONST_METHOD3(GetPassState,
|
||||
MlirOptimizationPassState(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph));
|
||||
MOCK_METHOD3(Run, Status(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph));
|
||||
MOCK_CONST_METHOD4(GetPassState,
|
||||
MlirOptimizationPassState(
|
||||
const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto, const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library));
|
||||
MOCK_METHOD4(Run, Status(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library));
|
||||
};
|
||||
|
||||
class MockMlirV1CompatOptimizationPass : public MlirV1CompatOptimizationPass {
|
||||
@ -45,10 +47,11 @@ class MockMlirV1CompatOptimizationPass : public MlirV1CompatOptimizationPass {
|
||||
// MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
|
||||
// instead.
|
||||
MOCK_CONST_METHOD0(name, llvm::StringRef());
|
||||
MOCK_CONST_METHOD3(GetPassState,
|
||||
MlirOptimizationPassState(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph));
|
||||
MOCK_CONST_METHOD4(GetPassState,
|
||||
MlirOptimizationPassState(
|
||||
const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto, const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library));
|
||||
MOCK_METHOD2(Run, Status(const GraphOptimizationPassOptions& options,
|
||||
mlir::ModuleOp module));
|
||||
};
|
||||
@ -59,15 +62,17 @@ class ModifyMlirModulePass : public MlirOptimizationPass {
|
||||
// MOCK_METHOD does not work on Windows build, using MOCK_CONST_METHODX
|
||||
// instead.
|
||||
MOCK_CONST_METHOD0(name, llvm::StringRef());
|
||||
MOCK_CONST_METHOD3(GetPassState,
|
||||
MlirOptimizationPassState(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph));
|
||||
MOCK_CONST_METHOD4(GetPassState,
|
||||
MlirOptimizationPassState(
|
||||
const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto, const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library));
|
||||
|
||||
// Just modify MLIR module so that we can check whether original TF graph
|
||||
// has changed or not.
|
||||
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph) override {
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) override {
|
||||
mlir::Builder b(module.getContext());
|
||||
auto producer = b.getNamedAttr("producer", b.getI32IntegerAttr(0));
|
||||
auto min_consumer = b.getNamedAttr("min_consumer", b.getI32IntegerAttr(0));
|
||||
@ -95,9 +100,9 @@ class MlirGraphOptimizationPassTest : public Test {
|
||||
auto optimization_pass =
|
||||
std::make_unique<NiceMock<MockMlirOptimizationPass>>();
|
||||
|
||||
ON_CALL(*optimization_pass, GetPassState(_, _, _))
|
||||
ON_CALL(*optimization_pass, GetPassState(_, _, _, _))
|
||||
.WillByDefault(Return(pass_state));
|
||||
ON_CALL(*optimization_pass, Run(_, _, _))
|
||||
ON_CALL(*optimization_pass, Run(_, _, _, _))
|
||||
.WillByDefault(Return(pass_run_result));
|
||||
MlirOptimizationPassRegistry::Global().Add(pass_priority++,
|
||||
std::move(optimization_pass));
|
||||
@ -111,7 +116,7 @@ class MlirGraphOptimizationPassTest : public Test {
|
||||
// Add FallbackEnabled pass that modifies the graph.
|
||||
auto optimization_pass =
|
||||
std::make_unique<NiceMock<ModifyMlirModulePass>>(run_status);
|
||||
ON_CALL(*optimization_pass, GetPassState(_, _, _))
|
||||
ON_CALL(*optimization_pass, GetPassState(_, _, _, _))
|
||||
.WillByDefault(Return(pass_state));
|
||||
MlirOptimizationPassRegistry::Global().Add(10,
|
||||
std::move(optimization_pass));
|
||||
|
@ -31,9 +31,11 @@ using ConfigProto = ::tensorflow::ConfigProto;
|
||||
using Graph = ::tensorflow::Graph;
|
||||
} // namespace
|
||||
|
||||
Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto,
|
||||
ModuleOp module, const Graph& graph) {
|
||||
if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
|
||||
Status MlirGraphOptimizationPass::Run(
|
||||
const ConfigProto& config_proto, ModuleOp module, const Graph& graph,
|
||||
const tensorflow::FunctionLibraryDefinition& function_library) {
|
||||
if (GetPassState(/*device_set=*/nullptr, config_proto, graph,
|
||||
function_library) ==
|
||||
::tensorflow::MlirOptimizationPassState::Disabled) {
|
||||
VLOG(1) << "Skipping MLIR Graph Optimization Pass"
|
||||
<< ", session flag not enabled";
|
||||
|
@ -30,15 +30,18 @@ class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass {
|
||||
::tensorflow::MlirOptimizationPassState GetPassState(
|
||||
const ::tensorflow::DeviceSet* device_set,
|
||||
const ::tensorflow::ConfigProto& config_proto,
|
||||
const tensorflow::Graph& graph) const override {
|
||||
const tensorflow::Graph& graph,
|
||||
const tensorflow::FunctionLibraryDefinition& function_library)
|
||||
const override {
|
||||
return config_proto.experimental().enable_mlir_graph_optimization()
|
||||
? tensorflow::MlirOptimizationPassState::Enabled
|
||||
: tensorflow::MlirOptimizationPassState::Disabled;
|
||||
}
|
||||
|
||||
::tensorflow::Status Run(const ::tensorflow::ConfigProto& config_proto,
|
||||
ModuleOp module,
|
||||
const ::tensorflow::Graph& graph) override;
|
||||
::tensorflow::Status Run(
|
||||
const ::tensorflow::ConfigProto& config_proto, ModuleOp module,
|
||||
const ::tensorflow::Graph& graph,
|
||||
const tensorflow::FunctionLibraryDefinition& function_library) override;
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
|
@ -32,16 +32,18 @@ namespace tfr {
|
||||
|
||||
MlirOptimizationPassState GraphDecomposePass::GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const {
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const {
|
||||
const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str());
|
||||
return tfr_lib_env_val != nullptr ? MlirOptimizationPassState::Enabled
|
||||
: MlirOptimizationPassState::Disabled;
|
||||
}
|
||||
|
||||
Status GraphDecomposePass::Run(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph) {
|
||||
if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
|
||||
MlirOptimizationPassState::Disabled) {
|
||||
Status GraphDecomposePass::Run(
|
||||
const ConfigProto& config_proto, mlir::ModuleOp module, const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) {
|
||||
if (GetPassState(/*device_set=*/nullptr, config_proto, graph,
|
||||
function_library) == MlirOptimizationPassState::Disabled) {
|
||||
LOG_FIRST_N(INFO, 1) << "Skipping Graph Decomposition Pass, decomposition"
|
||||
" library was not found";
|
||||
return Status::OK();
|
||||
|
@ -35,12 +35,14 @@ class GraphDecomposePass : public MlirOptimizationPass {
|
||||
// to MLIR even no tf composition file is found.
|
||||
::tensorflow::MlirOptimizationPassState GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const override;
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const override;
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
// API integrated with the Tensorflow runtime.
|
||||
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph) override;
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) override;
|
||||
};
|
||||
|
||||
} // namespace tfr
|
||||
|
@ -75,7 +75,8 @@ bool HasTPUDevice(const DeviceSet& device_set) {
|
||||
} // namespace
|
||||
|
||||
// Analyzes the user requested policy as well as the contents of the graph and
|
||||
// determines whether the MLIR Bridge should be run.
|
||||
// function_library_definition to determine whether the MLIR Bridge should be
|
||||
// run.
|
||||
//
|
||||
// If the user explicitly requests the bridge be enabled or disabled, this
|
||||
// function will respect the request. If the user does not explicitly request
|
||||
@ -85,7 +86,8 @@ bool HasTPUDevice(const DeviceSet& device_set) {
|
||||
// redundant for TF2 graphs.
|
||||
MlirOptimizationPassState MlirBridgePass::GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const {
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const {
|
||||
// Skip MLIR TPU Bridge if no TPU devices found.
|
||||
if (device_set && !HasTPUDevice(*device_set)) {
|
||||
return MlirOptimizationPassState::Disabled;
|
||||
@ -93,8 +95,9 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
|
||||
|
||||
// We set `uses_uninitialized_resource_args` to false here because the first
|
||||
// phase of the bridge is not affected by uninitialized resource args.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
|
||||
graph, config_proto, /*uses_uninitialized_resource_args=*/false);
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(graph, &function_library, config_proto,
|
||||
/*uses_uninitialized_resource_args=*/false);
|
||||
switch (policy) {
|
||||
case MlirBridgeRolloutPolicy::kEnabledByUser:
|
||||
return MlirOptimizationPassState::Enabled;
|
||||
@ -128,11 +131,12 @@ namespace {
|
||||
// operation. The kernel for these operations is responsible to lower the
|
||||
// encapsulated graph to a particular device.
|
||||
Status MlirBridgePass::Run(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module, const Graph& graph) {
|
||||
mlir::ModuleOp module, const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) {
|
||||
// Set device_set to nullptr here as the device specific checks are performed
|
||||
// based on the devices in the module.
|
||||
if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
|
||||
MlirOptimizationPassState::Disabled) {
|
||||
if (GetPassState(/*device_set=*/nullptr, config_proto, graph,
|
||||
function_library) == MlirOptimizationPassState::Disabled) {
|
||||
LOG_AT_LEAST_ONCE("Skipping MLIR TPU Bridge, session flag not enabled");
|
||||
mlir_bridge_gauge_v2->GetCell()->Set(false);
|
||||
return Status::OK();
|
||||
@ -156,7 +160,8 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
|
||||
|
||||
MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph) const {
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const {
|
||||
// Skip MLIR TPU Bridge if no TPU devices found.
|
||||
if (device_set && !HasTPUDevice(*device_set))
|
||||
return MlirOptimizationPassState::Disabled;
|
||||
@ -166,7 +171,8 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState(
|
||||
// We set `uses_uninitialized_resource_args` to false here because the first
|
||||
// phase of the bridge is not affected by uninitialized resource args.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
|
||||
graph, config_proto, /*uses_uninitialized_resource_args=*/false);
|
||||
graph, /*function_library=*/&function_library, config_proto,
|
||||
/*uses_uninitialized_resource_args=*/false);
|
||||
return (policy == MlirBridgeRolloutPolicy::kEnabledByUser)
|
||||
? MlirOptimizationPassState::Enabled
|
||||
: MlirOptimizationPassState::Disabled;
|
||||
@ -180,7 +186,8 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
|
||||
// Set device_set to nullptr here as the device specific checks are performed
|
||||
// based on the devices in the module.
|
||||
if (GetPassState(/*device_set=*/nullptr, options.session_options->config,
|
||||
**options.graph) == MlirOptimizationPassState::Disabled) {
|
||||
**options.graph,
|
||||
*options.flib_def) == MlirOptimizationPassState::Disabled) {
|
||||
LOG_AT_LEAST_ONCE(
|
||||
"Skipping MLIR TPU Bridge V1 Compat, session flag not enabled");
|
||||
mlir_bridge_gauge_v1->GetCell()->Set(false);
|
||||
|
@ -30,14 +30,16 @@ class MlirBridgePass : public MlirOptimizationPass {
|
||||
public:
|
||||
llvm::StringRef name() const override { return "bridge"; }
|
||||
|
||||
MlirOptimizationPassState GetPassState(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const override;
|
||||
MlirOptimizationPassState GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const override;
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
// API integrated with the Tensorflow runtime.
|
||||
Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph) override;
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) override;
|
||||
};
|
||||
|
||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||
@ -47,9 +49,10 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||
public:
|
||||
llvm::StringRef name() const override { return "bridge"; }
|
||||
|
||||
MlirOptimizationPassState GetPassState(const DeviceSet* device_set,
|
||||
const ConfigProto& config_proto,
|
||||
const Graph& graph) const override;
|
||||
MlirOptimizationPassState GetPassState(
|
||||
const DeviceSet* device_set, const ConfigProto& config_proto,
|
||||
const Graph& graph,
|
||||
const FunctionLibraryDefinition& function_library) const override;
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
// API integrated with the Tensorflow runtime.
|
||||
|
@ -807,7 +807,7 @@ Status XlaCompiler::CompileFunction(
|
||||
MlirBridgeRolloutPolicy policy = MlirBridgeRolloutPolicy::kDisabledByUser;
|
||||
if (options.is_entry_computation) {
|
||||
policy = GetMlirBridgeRolloutPolicy(
|
||||
*graph, config_proto,
|
||||
*graph, /*function_library=*/nullptr, config_proto,
|
||||
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
|
||||
}
|
||||
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
|
||||
|
Loading…
Reference in New Issue
Block a user