Update keras benchmark test to use public TF API.
PiperOrigin-RevId: 339489416 Change-Id: If526ae1f84e3516c85c683a823ad85d7e36baeb8
This commit is contained in:
parent
69d9fb6c3c
commit
9587103251
@ -45,7 +45,7 @@ def int_gen():
|
|||||||
yield (np.random.randint(0, 5, (1,)), np.random.randint(0, 7, (1,)))
|
yield (np.random.randint(0, 5, (1,)), np.random.randint(0, 7, (1,)))
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkLayer(benchmark.Benchmark):
|
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark the layer forward pass."""
|
"""Benchmark the layer forward pass."""
|
||||||
|
|
||||||
def run_dataset_implementation(self, batch_size):
|
def run_dataset_implementation(self, batch_size):
|
||||||
|
@ -36,7 +36,7 @@ FLAGS = flags.FLAGS
|
|||||||
v2_compat.enable_v2_behavior()
|
v2_compat.enable_v2_behavior()
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkLayer(benchmark.Benchmark):
|
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark the layer forward pass."""
|
"""Benchmark the layer forward pass."""
|
||||||
|
|
||||||
def run_dataset_implementation(self, output_mode, batch_size, sequence_length,
|
def run_dataset_implementation(self, output_mode, batch_size, sequence_length,
|
||||||
|
@ -48,7 +48,7 @@ def reduce_fn(state, values, epsilon=EPSILON):
|
|||||||
return (discretization.merge_summaries(state_, summary, epsilon),)
|
return (discretization.merge_summaries(state_, summary, epsilon),)
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkAdapt(benchmark.Benchmark):
|
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark adapt."""
|
"""Benchmark adapt."""
|
||||||
|
|
||||||
def run_dataset_implementation(self, num_elements, batch_size):
|
def run_dataset_implementation(self, num_elements, batch_size):
|
||||||
|
@ -47,7 +47,7 @@ def word_gen():
|
|||||||
yield "".join(random.choice(string.ascii_letters) for i in range(2))
|
yield "".join(random.choice(string.ascii_letters) for i in range(2))
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkLayer(benchmark.Benchmark):
|
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark the layer forward pass."""
|
"""Benchmark the layer forward pass."""
|
||||||
|
|
||||||
def run_dataset_implementation(self, batch_size):
|
def run_dataset_implementation(self, batch_size):
|
||||||
|
@ -85,7 +85,7 @@ def image_augmentation(inputs, batch_size):
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkLayer(benchmark.Benchmark):
|
class BenchmarkLayer(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark the layer forward pass."""
|
"""Benchmark the layer forward pass."""
|
||||||
|
|
||||||
def run_dataset_implementation(self, batch_size):
|
def run_dataset_implementation(self, batch_size):
|
||||||
|
@ -63,7 +63,7 @@ def get_top_k(dataset, k):
|
|||||||
return sorted_vocab
|
return sorted_vocab
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkAdapt(benchmark.Benchmark):
|
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark adapt."""
|
"""Benchmark adapt."""
|
||||||
|
|
||||||
def run_numpy_implementation(self, num_elements, batch_size, k):
|
def run_numpy_implementation(self, num_elements, batch_size, k):
|
||||||
|
@ -60,7 +60,7 @@ def reduce_fn(state, values):
|
|||||||
return (k, n + batch_size, ex, ex2)
|
return (k, n + batch_size, ex, ex2)
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkAdapt(benchmark.Benchmark):
|
class BenchmarkAdapt(benchmark.TensorFlowBenchmark):
|
||||||
"""Benchmark adapt."""
|
"""Benchmark adapt."""
|
||||||
|
|
||||||
def run_dataset_implementation(self, num_elements, batch_size):
|
def run_dataset_implementation(self, num_elements, batch_size):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user