diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index a7e2423aa76..6c93d8faaef 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -476,6 +476,21 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: f(m, m, transpose_b=transpose_b) self._run(func, num_iters, execution_mode=execution_mode) + def _benchmark_defun_matmul_with_signature(self, + m, + num_iters, + execution_mode=None): + + def func_matmul(m): + return math_ops.matmul(m, m) + + f = function.defun( + func_matmul, + input_signature=[tensor_spec.TensorSpec([2, 2], dtypes.float32)]) + + func = lambda: f(m) + self._run(func, num_iters, execution_mode=execution_mode) + def _benchmark_defun_args_matmul(self, m, num_iters, execution_mode=None): @def_function.function @@ -576,6 +591,18 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) + def benchmark_defun_matmul_2_by_2_CPU_with_signature(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_defun_matmul_with_signature( + m, num_iters=self._num_iters_2_by_2) + + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") + def benchmark_defun_args_matmul_2_by_2_CPU(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_defun_args_matmul(m, num_iters=self._num_iters_2_by_2) + @test_util.disable_tfrt("async not supported") def benchmark_defun_matmul_2_by_2_CPU_async(self): with context.device(CPU): @@ -651,6 +678,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m, transpose_b=False, num_iters=self._num_iters_2_by_2) @test_util.disable_tfrt("copy to GPU not supported") + def benchmark_defun_matmul_2_by_2_GPU_with_signature(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = self._m_2_by_2.gpu() + self._benchmark_defun_matmul_with_signature( + m, num_iters=self._num_iters_2_by_2) + + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_args_matmul_2_by_2_GPU(self): if not context.num_gpus(): return