Update keras benchmark test to use public TF API.

PiperOrigin-RevId: 339489416
Change-Id: If526ae1f84e3516c85c683a823ad85d7e36baeb8
This commit is contained in:
Scott Zhu 2020-10-28 10:37:05 -07:00 committed by TensorFlower Gardener
parent 69d9fb6c3c
commit 9587103251
7 changed files with 7 additions and 7 deletions

View File

@ -45,7 +45,7 @@ def int_gen():
yield (np.random.randint(0, 5, (1,)), np.random.randint(0, 7, (1,)))
class BenchmarkLayer(benchmark.Benchmark):
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
"""Benchmark the layer forward pass."""
def run_dataset_implementation(self, batch_size):

View File

@ -36,7 +36,7 @@ FLAGS = flags.FLAGS
v2_compat.enable_v2_behavior()
class BenchmarkLayer(benchmark.Benchmark):
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
"""Benchmark the layer forward pass."""
def run_dataset_implementation(self, output_mode, batch_size, sequence_length,

View File

@ -48,7 +48,7 @@ def reduce_fn(state, values, epsilon=EPSILON):
return (discretization.merge_summaries(state_, summary, epsilon),)
class BenchmarkAdapt(benchmark.Benchmark):
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
"""Benchmark adapt."""
def run_dataset_implementation(self, num_elements, batch_size):

View File

@ -47,7 +47,7 @@ def word_gen():
yield "".join(random.choice(string.ascii_letters) for i in range(2))
class BenchmarkLayer(benchmark.Benchmark):
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
"""Benchmark the layer forward pass."""
def run_dataset_implementation(self, batch_size):

View File

@ -85,7 +85,7 @@ def image_augmentation(inputs, batch_size):
return img
class BenchmarkLayer(benchmark.Benchmark):
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
"""Benchmark the layer forward pass."""
def run_dataset_implementation(self, batch_size):

View File

@ -63,7 +63,7 @@ def get_top_k(dataset, k):
return sorted_vocab
class BenchmarkAdapt(benchmark.Benchmark):
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
"""Benchmark adapt."""
def run_numpy_implementation(self, num_elements, batch_size, k):

View File

@ -60,7 +60,7 @@ def reduce_fn(state, values):
return (k, n + batch_size, ex, ex2)
class BenchmarkAdapt(benchmark.Benchmark):
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
"""Benchmark adapt."""
def run_dataset_implementation(self, num_elements, batch_size):