[TF2XLA] Allow non-tensor arguments for experimental_get_compiler_ir
PiperOrigin-RevId: 334626929 Change-Id: Iaddd686b55f490331a270258f61c2b280a3cee17
This commit is contained in:
parent
e075feb345
commit
801c87d900
@ -985,7 +985,7 @@ class Function(object):
|
|||||||
fn_name = concrete_fn.name
|
fn_name = concrete_fn.name
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
canon_args, _, _, _ = \
|
_, _, _, filtered_flat_args = \
|
||||||
concrete_fn._function_spec.canonicalize_function_inputs(
|
concrete_fn._function_spec.canonicalize_function_inputs(
|
||||||
*args, **kwargs)
|
*args, **kwargs)
|
||||||
|
|
||||||
@ -999,7 +999,7 @@ class Function(object):
|
|||||||
return context.context().get_compiler_ir(
|
return context.context().get_compiler_ir(
|
||||||
stage=stage,
|
stage=stage,
|
||||||
function_name=fn_name,
|
function_name=fn_name,
|
||||||
args=list(canon_args) + concrete_fn.captured_inputs)
|
args=list(filtered_flat_args) + concrete_fn.captured_inputs)
|
||||||
|
|
||||||
return compiler_ir_generator
|
return compiler_ir_generator
|
||||||
|
|
||||||
|
@ -616,6 +616,18 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
'label',
|
'label',
|
||||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
|
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
|
||||||
|
|
||||||
|
def testGetCompilerIrNonTensors(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def f(l):
|
||||||
|
return l[0] + l[1]
|
||||||
|
|
||||||
|
l = [constant_op.constant(1.1), constant_op.constant(2.2)]
|
||||||
|
|
||||||
|
self.assertIn('tuple',
|
||||||
|
f.experimental_get_compiler_ir(l)())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
Reference in New Issue
Block a user