Merge pull request #41714 from jonathanchu33:add-benchmarks

PiperOrigin-RevId: 323677178
Change-Id: I8b6d8c683a5df922464b6edad13f381b47e583d0
This commit is contained in:
TensorFlower Gardener 2020-07-28 16:31:42 -07:00
commit f38ff8449b

View File

@ -476,6 +476,21 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
func = lambda: f(m, m, transpose_b=transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
def _benchmark_defun_matmul_with_signature(self,
m,
num_iters,
execution_mode=None):
def func_matmul(m):
return math_ops.matmul(m, m)
f = function.defun(
func_matmul,
input_signature=[tensor_spec.TensorSpec([2, 2], dtypes.float32)])
func = lambda: f(m)
self._run(func, num_iters, execution_mode=execution_mode)
def _benchmark_defun_args_matmul(self, m, num_iters, execution_mode=None):
@def_function.function
@ -576,6 +591,18 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
self._benchmark_defun_matmul(
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
def benchmark_defun_matmul_2_by_2_CPU_with_signature(self):
with context.device(CPU):
m = self._m_2_by_2.cpu()
self._benchmark_defun_matmul_with_signature(
m, num_iters=self._num_iters_2_by_2)
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_args_matmul_2_by_2_CPU(self):
with context.device(CPU):
m = self._m_2_by_2.cpu()
self._benchmark_defun_args_matmul(m, num_iters=self._num_iters_2_by_2)
@test_util.disable_tfrt("async not supported")
def benchmark_defun_matmul_2_by_2_CPU_async(self):
with context.device(CPU):
@ -651,6 +678,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
m, transpose_b=False, num_iters=self._num_iters_2_by_2)
@test_util.disable_tfrt("copy to GPU not supported")
def benchmark_defun_matmul_2_by_2_GPU_with_signature(self):
if not context.num_gpus():
return
with context.device(GPU):
m = self._m_2_by_2.gpu()
self._benchmark_defun_matmul_with_signature(
m, num_iters=self._num_iters_2_by_2)
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_args_matmul_2_by_2_GPU(self):
if not context.num_gpus():
return