Update training unittests to TF2
This commit is contained in:
parent
7802e2f284
commit
123aeb0a44
@ -2,6 +2,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from deepspeech_training.util.helpers import ValueRange, get_value_range, pick_value_from_range, tf_pick_value_from_range
|
||||
|
||||
|
||||
@ -59,16 +60,16 @@ class TestValueRange(unittest.TestCase):
|
||||
|
||||
class TestPickValueFromFixedRange(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestPickValueFromFixedRange, self).__init__(*args, **kwargs)
|
||||
self.session = tf.Session()
|
||||
self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock')
|
||||
super().__init__(*args, **kwargs)
|
||||
self.session = tfv1.Session()
|
||||
|
||||
def _ending_tester(self, value_range, clock, expected):
|
||||
with tf.Session() as session:
|
||||
tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph)
|
||||
with self.session as session:
|
||||
clock_ph = tfv1.placeholder(dtype=tf.float64, name='clock')
|
||||
tf_pick = tf_pick_value_from_range(value_range, clock=clock_ph)
|
||||
|
||||
def run_pick(_, c):
|
||||
return session.run(tf_pick, feed_dict={self.clock_ph: c})
|
||||
return session.run(tf_pick, feed_dict={clock_ph: c})
|
||||
|
||||
is_int = isinstance(value_range.start, int)
|
||||
for pick, int_type, float_type in [(pick_value_from_range, int, float), (run_pick, np.int32, np.float32)]:
|
||||
@ -97,16 +98,16 @@ class TestPickValueFromFixedRange(unittest.TestCase):
|
||||
|
||||
class TestPickValueFromRandomizedRange(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestPickValueFromRandomizedRange, self).__init__(*args, **kwargs)
|
||||
self.session = tf.Session()
|
||||
self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock')
|
||||
super().__init__(*args, **kwargs)
|
||||
self.session = tfv1.Session()
|
||||
|
||||
def _ending_tester(self, value_range, clock_min, clock_max, expected_min, expected_max):
|
||||
with self.session as session:
|
||||
tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph)
|
||||
clock_ph = tfv1.placeholder(dtype=tf.float64, name='clock')
|
||||
tf_pick = tf_pick_value_from_range(value_range, clock=clock_ph)
|
||||
|
||||
def run_pick(_, c):
|
||||
return session.run(tf_pick, feed_dict={self.clock_ph: c})
|
||||
return session.run(tf_pick, feed_dict={clock_ph: c})
|
||||
|
||||
is_int = isinstance(value_range.start, int)
|
||||
clock_range = np.arange(clock_min, clock_max, (clock_max - clock_min) / 100.0)
|
||||
|
Loading…
Reference in New Issue
Block a user