Delay transpiler initialzation to mitigate effects of circular imports.

PiperOrigin-RevId: 350757373
Change-Id: If4aa3b49264e539cfc1d828aeaf98b2bc0345a2e
This commit is contained in:
A. Unique TensorFlower 2021-01-08 07:13:10 -08:00 committed by TensorFlower Gardener
parent 59974d69d2
commit 16696c7587
2 changed files with 20 additions and 21 deletions

View File

@ -334,7 +334,7 @@ _AG_FIXED_RETURN_TYPE = {
QN = qual_names.QN
# TODO(mdan): Fix this with an importable module.
AG_MODULE = api._TRANSPILER.get_extra_locals()['ag__'] # pylint:disable=protected-access
AG_MODULE = api._TRANSPILER._extra_locals['ag__'] # pylint:disable=protected-access
class TFRTypeResolver(type_inference.Resolver):

View File

@ -209,13 +209,7 @@ class PyToTF(transpiler.PyToPy):
def __init__(self):
super(PyToTF, self).__init__()
self._extra_locals = None
def get_transformed_name(self, node):
return 'tf__' + super(PyToTF, self).get_transformed_name(node)
def get_extra_locals(self):
if self._extra_locals is None:
# TODO(mdan): Move into core or replace with an actual importable module.
# Craft a module that exposes the external API as well as certain
# internal modules.
@ -234,6 +228,11 @@ class PyToTF(transpiler.PyToPy):
ag_internal.__dict__.update(operators.__dict__)
self._extra_locals = {'ag__': ag_internal}
def get_transformed_name(self, node):
return 'tf__' + super(PyToTF, self).get_transformed_name(node)
def get_extra_locals(self):
return self._extra_locals
def get_caching_key(self, ctx):