Update context and tfe_wrapper to support mlir_bridge_rollout

Update eager/context.py and tfe_wrapper to support returning
the real value of mlir_bridge_rollout (enabled/disabled/unspecified)
instead of a bool. This gives users a clearer signal of whether
or not the mlir bridge is being used. At the moment, the mlir
bridge is only enabled when mlir_bridge_rollout is set to
enabled but this will change in the future.

PiperOrigin-RevId: 338124102
Change-Id: I5c93cbdd2815a698e6b41244db8eed716f4988e6
This commit is contained in:
Marissa Ikonomidis 2020-10-20 13:20:19 -07:00 committed by TensorFlower Gardener
parent 169cbde464
commit 1914c410b2
4 changed files with 26 additions and 4 deletions

View File

@@ -31,6 +31,9 @@ class MlirBridgePass : public MlirOptimizationPass {
// Returns true when the MLIR bridge should run for this session: the
// explicit enable_mlir_bridge proto flag is set, the proto's
// mlir_bridge_rollout state is ENABLED, or the process-wide command-line
// flag has been set to the ENABLED rollout state.
// NOTE(review): tf_mlir_enable_mlir_bridge is compared against the rollout
// enum here, so it is presumably the rollout enum type rather than a bool —
// confirm against mlir_graph_optimization_pass.h.
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge() ||
config_proto.experimental().mlir_bridge_rollout() ==
tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_ENABLED ||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_ENABLED;
@@ -50,7 +53,10 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
// Returns true when the MLIR bridge should run for this V1-compat session:
// the explicit enable_mlir_bridge proto flag is set, the proto's
// mlir_bridge_rollout state is ENABLED, or the process-wide command-line
// flag has been set to the ENABLED rollout state.
// (The diff rendering had fused the removed pre-image line
// `GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==` into this
// expression, producing an invalid chained comparison; it is dropped here.)
bool IsEnabled(const ConfigProto& config_proto) const override {
  return config_proto.experimental().enable_mlir_bridge() ||
         config_proto.experimental().mlir_bridge_rollout() ==
             tensorflow::ConfigProto::Experimental::
                 MLIR_BRIDGE_ROLLOUT_ENABLED ||
         tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
             tensorflow::ConfigProto::Experimental::
                 MLIR_BRIDGE_ROLLOUT_ENABLED;
}

View File

@@ -948,7 +948,12 @@ class Context(object):
if self._log_device_placement is not None:
  config.log_device_placement = self._log_device_placement

# TF_IsMlirBridgeEnabled() returns the mlir_bridge_rollout enum value
# (ENABLED / DISABLED / UNSPECIFIED), not a bool.  Record it on the config
# and only flip the legacy bool flag when the rollout is explicitly ENABLED.
# (The diff rendering had kept the removed pre-image line that assigned the
# enum straight into the bool enable_mlir_bridge field; it is dropped here.)
is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
if (is_mlir_bridge_enabled ==
    config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
  config.experimental.enable_mlir_bridge = True

if self._enable_mlir_graph_optimization is not None:
  config.experimental.enable_mlir_graph_optimization = (
      self._enable_mlir_graph_optimization)

View File

@@ -214,14 +214,23 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
def testEnableMlirBridge(self):
  """mlir_bridge_rollout must track enabling/disabling of the MLIR bridge."""
  # Default: bridge off, rollout state unspecified.
  self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
  self.assertEqual(
      context.context().config.experimental.mlir_bridge_rollout,
      config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)

  # Enabling the bridge sets both the legacy bool and the rollout enum.
  config.enable_mlir_bridge()
  self.assertTrue(context.context().config.experimental.enable_mlir_bridge)
  self.assertEqual(
      context.context().config.experimental.mlir_bridge_rollout,
      config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED)

  # Disabling clears the bool and marks the rollout DISABLED.
  config.disable_mlir_bridge()
  self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
  self.assertEqual(
      context.context().config.experimental.mlir_bridge_rollout,
      config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_DISABLED)
@reset_eager
def testEnableMlirGraphOptimization(self):

View File

@@ -580,8 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// MLIR Logic
// Expose the mlir_bridge_rollout state (ENABLED/DISABLED/UNSPECIFIED) to
// Python instead of a bool, so callers can distinguish "unspecified" from
// "disabled".
// (The diff rendering had kept the removed pre-image return statement —
// the `== MLIR_BRIDGE_ROLLOUT_ENABLED` bool comparison — ahead of the
// replacement return, leaving two fused returns; it is dropped here.)
m.def("TF_IsMlirBridgeEnabled", [] {
  // Since python protobuf enums are integers, cast to an integer before
  // returning the enum to python.
  return static_cast<int32_t>(
      tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
});
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =