[TF2XLA] Allow non-tensor arguments for experimental_get_compiler_ir
PiperOrigin-RevId: 334626929 Change-Id: Iaddd686b55f490331a270258f61c2b280a3cee17
This commit is contained in:
parent
e075feb345
commit
801c87d900
@ -985,7 +985,7 @@ class Function(object):
|
||||
fn_name = concrete_fn.name
|
||||
|
||||
# pylint: disable=protected-access
|
||||
canon_args, _, _, _ = \
|
||||
_, _, _, filtered_flat_args = \
|
||||
concrete_fn._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwargs)
|
||||
|
||||
@ -999,7 +999,7 @@ class Function(object):
|
||||
return context.context().get_compiler_ir(
|
||||
stage=stage,
|
||||
function_name=fn_name,
|
||||
args=list(canon_args) + concrete_fn.captured_inputs)
|
||||
args=list(filtered_flat_args) + concrete_fn.captured_inputs)
|
||||
|
||||
return compiler_ir_generator
|
||||
|
||||
|
@ -616,6 +616,18 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
'label',
|
||||
f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))
|
||||
|
||||
def testGetCompilerIrNonTensors(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(l):
|
||||
return l[0] + l[1]
|
||||
|
||||
l = [constant_op.constant(1.1), constant_op.constant(2.2)]
|
||||
|
||||
self.assertIn('tuple',
|
||||
f.experimental_get_compiler_ir(l)())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user