[TF2XLA] Disambiguate the call without arguments to experimental_get_compiler_ir

Under the previous version, `f.get_experimental_compiler_ir(*args, **kw)('optimized_hlo')` was ambiguous.

PiperOrigin-RevId: 332527710
Change-Id: I234f32152c00b6a114bc25ebc737f65aba797f36
This commit is contained in:
George Karpenkov 2020-09-18 14:27:46 -07:00 committed by TensorFlower Gardener
parent 8ee24e7949
commit 285cb597ed
2 changed files with 14 additions and 5 deletions

View File

@ -979,10 +979,19 @@ class Function(object):
concrete_fn._function_spec.canonicalize_function_inputs(
*args, **kwargs)
return functools.partial(
context.context().get_compiler_ir,
function_name=fn_name,
args=list(canon_args) + concrete_fn.captured_inputs)
def compiler_ir_generator(stage='hlo'):
"""Returns compiler IR for the given `stage`.
Args:
stage: Stage at which to return the IR. Allowed values are 'hlo' and
'optimized_hlo'.
"""
return context.context().get_compiler_ir(
stage=stage,
function_name=fn_name,
args=list(canon_args) + concrete_fn.captured_inputs)
return compiler_ir_generator
@property
def python_function(self):

View File

@ -557,7 +557,7 @@ class DefFunctionTest(xla_test.XLATestCase):
b = random_ops.random_normal([2])
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
f.experimental_get_compiler_ir(a, b)('optimized_hlo'))
def testGetCompilerIrNotCompiled(self):
with ops.device('device:{}:0'.format(self.device)):