[TF2XLA] Allow non-tensor arguments for experimental_get_compiler_ir

PiperOrigin-RevId: 334626929
Change-Id: Iaddd686b55f490331a270258f61c2b280a3cee17
This commit is contained in:
George Karpenkov 2020-09-30 10:22:47 -07:00 committed by TensorFlower Gardener
parent e075feb345
commit 801c87d900
2 changed files with 14 additions and 2 deletions

View File

@ -985,7 +985,7 @@ class Function(object):
fn_name = concrete_fn.name
# pylint: disable=protected-access
canon_args, _, _, _ = \
_, _, _, filtered_flat_args = \
concrete_fn._function_spec.canonicalize_function_inputs(
*args, **kwargs)
@ -999,7 +999,7 @@ class Function(object):
return context.context().get_compiler_ir(
stage=stage,
function_name=fn_name,
args=list(canon_args) + concrete_fn.captured_inputs)
args=list(filtered_flat_args) + concrete_fn.captured_inputs)
return compiler_ir_generator

View File

@ -616,6 +616,18 @@ class DefFunctionTest(xla_test.XLATestCase):
'label',
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
def testGetCompilerIrNonTensors(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
def f(l):
return l[0] + l[1]
l = [constant_op.constant(1.1), constant_op.constant(2.2)]
self.assertIn('tuple',
f.experimental_get_compiler_ir(l)())
if __name__ == '__main__':
ops.enable_eager_execution()