[TF2XLA] Disambiguate the call without arguments to experimental_get_compiler_ir
Under the previous version, `f.get_experimental_compiler_ir(*args, **kw)('optimized_hlo')` was ambiguous. PiperOrigin-RevId: 332527710 Change-Id: I234f32152c00b6a114bc25ebc737f65aba797f36
This commit is contained in:
parent
8ee24e7949
commit
285cb597ed
@ -979,10 +979,19 @@ class Function(object):
|
||||
concrete_fn._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwargs)
|
||||
|
||||
return functools.partial(
|
||||
context.context().get_compiler_ir,
|
||||
function_name=fn_name,
|
||||
args=list(canon_args) + concrete_fn.captured_inputs)
|
||||
def compiler_ir_generator(stage='hlo'):
|
||||
"""Returns compiler IR for the given `stage`.
|
||||
|
||||
Args:
|
||||
stage: Stage at which to return the IR. Allowed values are 'hlo' and
|
||||
'optimized_hlo'.
|
||||
"""
|
||||
return context.context().get_compiler_ir(
|
||||
stage=stage,
|
||||
function_name=fn_name,
|
||||
args=list(canon_args) + concrete_fn.captured_inputs)
|
||||
|
||||
return compiler_ir_generator
|
||||
|
||||
@property
|
||||
def python_function(self):
|
||||
|
@ -557,7 +557,7 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
b = random_ops.random_normal([2])
|
||||
|
||||
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
|
||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
|
||||
f.experimental_get_compiler_ir(a, b)('optimized_hlo'))
|
||||
|
||||
def testGetCompilerIrNotCompiled(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
Loading…
x
Reference in New Issue
Block a user