Update keras benchmark test to use public TF API.
PiperOrigin-RevId: 339489416 Change-Id: If526ae1f84e3516c85c683a823ad85d7e36baeb8
This commit is contained in:
parent
69d9fb6c3c
commit
9587103251
@ -45,7 +45,7 @@ def int_gen():
|
||||
yield (np.random.randint(0, 5, (1,)), np.random.randint(0, 7, (1,)))
|
||||
|
||||
|
||||
class BenchmarkLayer(benchmark.Benchmark):
|
||||
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark the layer forward pass."""
|
||||
|
||||
def run_dataset_implementation(self, batch_size):
|
||||
|
@ -36,7 +36,7 @@ FLAGS = flags.FLAGS
|
||||
v2_compat.enable_v2_behavior()
|
||||
|
||||
|
||||
class BenchmarkLayer(benchmark.Benchmark):
|
||||
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark the layer forward pass."""
|
||||
|
||||
def run_dataset_implementation(self, output_mode, batch_size, sequence_length,
|
||||
|
@ -48,7 +48,7 @@ def reduce_fn(state, values, epsilon=EPSILON):
|
||||
return (discretization.merge_summaries(state_, summary, epsilon),)
|
||||
|
||||
|
||||
class BenchmarkAdapt(benchmark.Benchmark):
|
||||
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark adapt."""
|
||||
|
||||
def run_dataset_implementation(self, num_elements, batch_size):
|
||||
|
@ -47,7 +47,7 @@ def word_gen():
|
||||
yield "".join(random.choice(string.ascii_letters) for i in range(2))
|
||||
|
||||
|
||||
class BenchmarkLayer(benchmark.Benchmark):
|
||||
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark the layer forward pass."""
|
||||
|
||||
def run_dataset_implementation(self, batch_size):
|
||||
|
@ -85,7 +85,7 @@ def image_augmentation(inputs, batch_size):
|
||||
return img
|
||||
|
||||
|
||||
class BenchmarkLayer(benchmark.Benchmark):
|
||||
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark the layer forward pass."""
|
||||
|
||||
def run_dataset_implementation(self, batch_size):
|
||||
|
@ -63,7 +63,7 @@ def get_top_k(dataset, k):
|
||||
return sorted_vocab
|
||||
|
||||
|
||||
class BenchmarkAdapt(benchmark.Benchmark):
|
||||
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark adapt."""
|
||||
|
||||
def run_numpy_implementation(self, num_elements, batch_size, k):
|
||||
|
@ -60,7 +60,7 @@ def reduce_fn(state, values):
|
||||
return (k, n + batch_size, ex, ex2)
|
||||
|
||||
|
||||
class BenchmarkAdapt(benchmark.Benchmark):
|
||||
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
|
||||
"""Benchmark adapt."""
|
||||
|
||||
def run_dataset_implementation(self, num_elements, batch_size):
|
||||
|
Loading…
Reference in New Issue
Block a user