diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 58211433496..d1ebd47b6c7 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -985,7 +985,7 @@ class Function(object): fn_name = concrete_fn.name # pylint: disable=protected-access - canon_args, _, _, _ = \ + _, _, _, filtered_flat_args = \ concrete_fn._function_spec.canonicalize_function_inputs( *args, **kwargs) @@ -999,7 +999,7 @@ class Function(object): return context.context().get_compiler_ir( stage=stage, function_name=fn_name, - args=list(canon_args) + concrete_fn.captured_inputs) + args=list(filtered_flat_args) + concrete_fn.captured_inputs) return compiler_ir_generator diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 80ca015a24d..b6997def0d2 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -616,6 +616,18 @@ class DefFunctionTest(xla_test.XLATestCase): 'label', f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot')) + def testGetCompilerIrNonTensors(self): + with ops.device('device:{}:0'.format(self.device)): + + @def_function.function(experimental_compile=True) + def f(l): + return l[0] + l[1] + + l = [constant_op.constant(1.1), constant_op.constant(2.2)] + + self.assertIn('tuple', + f.experimental_get_compiler_ir(l)()) + if __name__ == '__main__': ops.enable_eager_execution()