Update context and tfe_wrapper to support mlir_bridge_rollout
Update eager/context.py and tfe_wrapper to support returning the real value of mlir_bridge_rollout (enabled/disabled/unspecified) instead of a bool. This gives users a clearer signal of whether or not the mlir bridge is being used. At the moment, the mlir bridge is only enabled when mlir_bridge_rollout is set to enabled but this will change in the future. PiperOrigin-RevId: 338124102 Change-Id: I5c93cbdd2815a698e6b41244db8eed716f4988e6
This commit is contained in:
parent
169cbde464
commit
1914c410b2
@ -31,6 +31,9 @@ class MlirBridgePass : public MlirOptimizationPass {
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||
return config_proto.experimental().enable_mlir_bridge() ||
|
||||
config_proto.experimental().mlir_bridge_rollout() ==
|
||||
tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_ENABLED ||
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
@ -50,7 +53,10 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||
return config_proto.experimental().enable_mlir_bridge() ||
|
||||
GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
config_proto.experimental().mlir_bridge_rollout() ==
|
||||
tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_ENABLED ||
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
}
|
||||
|
@ -948,7 +948,12 @@ class Context(object):
|
||||
if self._log_device_placement is not None:
|
||||
config.log_device_placement = self._log_device_placement
|
||||
|
||||
config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
|
||||
is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
|
||||
config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
|
||||
if (is_mlir_bridge_enabled ==
|
||||
config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
|
||||
config.experimental.enable_mlir_bridge = True
|
||||
|
||||
if self._enable_mlir_graph_optimization is not None:
|
||||
config.experimental.enable_mlir_graph_optimization = (
|
||||
self._enable_mlir_graph_optimization)
|
||||
|
@ -214,14 +214,23 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
|
||||
def testEnableMlirBridge(self):
|
||||
# Default value of enable_mlir_bridge is false.
|
||||
self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
|
||||
self.assertEqual(
|
||||
context.context().config.experimental.mlir_bridge_rollout,
|
||||
config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)
|
||||
|
||||
# Tests enabling mlir bridge.
|
||||
config.enable_mlir_bridge()
|
||||
self.assertTrue(context.context().config.experimental.enable_mlir_bridge)
|
||||
self.assertEqual(
|
||||
context.context().config.experimental.mlir_bridge_rollout,
|
||||
config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED)
|
||||
|
||||
# Tests disabling mlir bridge.
|
||||
config.disable_mlir_bridge()
|
||||
self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
|
||||
self.assertEqual(
|
||||
context.context().config.experimental.mlir_bridge_rollout,
|
||||
config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_DISABLED)
|
||||
|
||||
@reset_eager
|
||||
def testEnableMlirGraphOptimization(self):
|
||||
|
@ -580,8 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
|
||||
// MLIR Logic
|
||||
m.def("TF_IsMlirBridgeEnabled", [] {
|
||||
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
// Since python protobuf enums are integers, cast to an integer before
|
||||
// returning the enum to python.
|
||||
return static_cast<int32_t>(
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
|
||||
});
|
||||
m.def("TF_EnableMlirBridge", [](bool enabled) {
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
|
||||
|
Loading…
Reference in New Issue
Block a user