[TF2XLA] Disambiguate the call without arguments to experimental_get_compiler_ir
Under the previous version, `f.get_experimental_compiler_ir(*args, **kw)('optimized_hlo')` was ambiguous.
PiperOrigin-RevId: 332527710
Change-Id: I234f32152c00b6a114bc25ebc737f65aba797f36
This commit is contained in:
parent
8ee24e7949
commit
285cb597ed
@ -979,10 +979,19 @@ class Function(object):
|
|||||||
concrete_fn._function_spec.canonicalize_function_inputs(
|
concrete_fn._function_spec.canonicalize_function_inputs(
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
return functools.partial(
|
def compiler_ir_generator(stage='hlo'):
|
||||||
context.context().get_compiler_ir,
|
"""Returns compiler IR for the given `stage`.
|
||||||
function_name=fn_name,
|
|
||||||
args=list(canon_args) + concrete_fn.captured_inputs)
|
Args:
|
||||||
|
stage: Stage at which to return the IR. Allowed values are 'hlo' and
|
||||||
|
'optimized_hlo'.
|
||||||
|
"""
|
||||||
|
return context.context().get_compiler_ir(
|
||||||
|
stage=stage,
|
||||||
|
function_name=fn_name,
|
||||||
|
args=list(canon_args) + concrete_fn.captured_inputs)
|
||||||
|
|
||||||
|
return compiler_ir_generator
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def python_function(self):
|
def python_function(self):
|
||||||
|
|||||||
@ -557,7 +557,7 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
b = random_ops.random_normal([2])
|
b = random_ops.random_normal([2])
|
||||||
|
|
||||||
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
|
self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
|
||||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))
|
f.experimental_get_compiler_ir(a, b)('optimized_hlo'))
|
||||||
|
|
||||||
def testGetCompilerIrNotCompiled(self):
|
def testGetCompilerIrNotCompiled(self):
|
||||||
with ops.device('device:{}:0'.format(self.device)):
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user