diff --git a/tensorflow/virtual_root_template_v1.__init__.py b/tensorflow/virtual_root_template_v1.__init__.py index 9603ddca5c0..d341de2721a 100644 --- a/tensorflow/virtual_root_template_v1.__init__.py +++ b/tensorflow/virtual_root_template_v1.__init__.py @@ -129,4 +129,8 @@ try: del examples except NameError: pass + +# Manually patch keras and estimator so tf.keras and tf.estimator work +keras = _sys.modules["tensorflow.keras"] +if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] # LINT.ThenChange(//tensorflow/virtual_root_template_v2.__init__.py.oss) diff --git a/tensorflow/virtual_root_template_v2.__init__.py b/tensorflow/virtual_root_template_v2.__init__.py index dc3011c96ee..b3e101902be 100644 --- a/tensorflow/virtual_root_template_v2.__init__.py +++ b/tensorflow/virtual_root_template_v2.__init__.py @@ -122,4 +122,8 @@ try: del examples except NameError: pass + +# Manually patch keras and estimator so tf.keras and tf.estimator work +keras = _sys.modules["tensorflow.keras"] +if not _root_estimator: estimator = _sys.modules["tensorflow.estimator"] # LINT.ThenChange(//tensorflow/virtual_root_template_v1.__init__.py.oss)