Delay transpiler initialzation to mitigate effects of circular imports.

PiperOrigin-RevId: 350787978
Change-Id: I134b53638c8c3c959724b713c5c517b6330b2546
This commit is contained in:
Dan Moldovan 2021-01-08 10:19:19 -08:00 committed by TensorFlower Gardener
parent c584ef0ee7
commit d2f068134d
2 changed files with 21 additions and 20 deletions

View File

@ -334,7 +334,7 @@ _AG_FIXED_RETURN_TYPE = {
QN = qual_names.QN
# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER._extra_locals['ag__'] # pylint:disable=protected-access
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__'] # pylint:disable=protected-access
class TFRTypeResolver(type_inference.Resolver):

View File

@ -209,30 +209,31 @@ class PyToTF(transpiler.PyToPy):
def __init__(self):
super(PyToTF, self).__init__()
# TODO(mdan): Move into core or replace with an actual importable module.
# Craft a module that exposes the external API as well as certain
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
ag_internal.ConversionOptions = converter.ConversionOptions
ag_internal.STD = converter.STANDARD_OPTIONS
ag_internal.Feature = converter.Feature
ag_internal.utils = utils
ag_internal.FunctionScope = function_wrappers.FunctionScope
ag_internal.with_function_scope = function_wrappers.with_function_scope
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
ag_internal.__dict__.update(special_functions.__dict__)
ag_internal.__dict__.update(operators.__dict__)
self._extra_locals = {'ag__': ag_internal}
self._extra_locals = None
def get_transformed_name(self, node):
return 'tf__' + super(PyToTF, self).get_transformed_name(node)
def get_extra_locals(self):
if self._extra_locals is None:
# TODO(mdan): Move into core or replace with an actual importable module.
# Craft a module that exposes the external API as well as certain
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
ag_internal.ConversionOptions = converter.ConversionOptions
ag_internal.STD = converter.STANDARD_OPTIONS
ag_internal.Feature = converter.Feature
ag_internal.utils = utils
ag_internal.FunctionScope = function_wrappers.FunctionScope
ag_internal.with_function_scope = function_wrappers.with_function_scope
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
ag_internal.__dict__.update(special_functions.__dict__)
ag_internal.__dict__.update(operators.__dict__)
self._extra_locals = {'ag__': ag_internal}
return self._extra_locals
def get_caching_key(self, ctx):