From 801c87d900291ebe9fb5984b8de1d3cd0b93d0f2 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 30 Sep 2020 10:22:47 -0700 Subject: [PATCH] [TF2XLA] Allow non-tensor arguments for experimental_get_compiler_ir PiperOrigin-RevId: 334626929 Change-Id: Iaddd686b55f490331a270258f61c2b280a3cee17 --- tensorflow/python/eager/def_function.py | 4 ++-- tensorflow/python/eager/def_function_xla_jit_test.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 58211433496..d1ebd47b6c7 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -985,7 +985,7 @@ class Function(object): fn_name = concrete_fn.name # pylint: disable=protected-access - canon_args, _, _, _ = \ + _, _, _, filtered_flat_args = \ concrete_fn._function_spec.canonicalize_function_inputs( *args, **kwargs) @@ -999,7 +999,7 @@ class Function(object): return context.context().get_compiler_ir( stage=stage, function_name=fn_name, - args=list(canon_args) + concrete_fn.captured_inputs) + args=list(filtered_flat_args) + concrete_fn.captured_inputs) return compiler_ir_generator diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 80ca015a24d..b6997def0d2 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -616,6 +616,18 @@ class DefFunctionTest(xla_test.XLATestCase): 'label', f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot')) + def testGetCompilerIrNonTensors(self): + with ops.device('device:{}:0'.format(self.device)): + + @def_function.function(experimental_compile=True) + def f(l): + return l[0] + l[1] + + l = [constant_op.constant(1.1), constant_op.constant(2.2)] + + self.assertIn('tuple', + f.experimental_get_compiler_ir(l)()) + if __name__ == '__main__': ops.enable_eager_execution()