Update training unittests to TF2

This commit is contained in:
Reuben Morais 2021-01-03 11:07:53 +00:00
parent 7802e2f284
commit 123aeb0a44

View File

@ -2,6 +2,7 @@ import unittest
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from deepspeech_training.util.helpers import ValueRange, get_value_range, pick_value_from_range, tf_pick_value_from_range
@ -59,16 +60,16 @@ class TestValueRange(unittest.TestCase):
class TestPickValueFromFixedRange(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestPickValueFromFixedRange, self).__init__(*args, **kwargs)
self.session = tf.Session()
self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock')
super().__init__(*args, **kwargs)
self.session = tfv1.Session()
def _ending_tester(self, value_range, clock, expected):
with tf.Session() as session:
tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph)
with self.session as session:
clock_ph = tfv1.placeholder(dtype=tf.float64, name='clock')
tf_pick = tf_pick_value_from_range(value_range, clock=clock_ph)
def run_pick(_, c):
return session.run(tf_pick, feed_dict={self.clock_ph: c})
return session.run(tf_pick, feed_dict={clock_ph: c})
is_int = isinstance(value_range.start, int)
for pick, int_type, float_type in [(pick_value_from_range, int, float), (run_pick, np.int32, np.float32)]:
@ -97,16 +98,16 @@ class TestPickValueFromFixedRange(unittest.TestCase):
class TestPickValueFromRandomizedRange(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestPickValueFromRandomizedRange, self).__init__(*args, **kwargs)
self.session = tf.Session()
self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock')
super().__init__(*args, **kwargs)
self.session = tfv1.Session()
def _ending_tester(self, value_range, clock_min, clock_max, expected_min, expected_max):
with self.session as session:
tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph)
clock_ph = tfv1.placeholder(dtype=tf.float64, name='clock')
tf_pick = tf_pick_value_from_range(value_range, clock=clock_ph)
def run_pick(_, c):
return session.run(tf_pick, feed_dict={self.clock_ph: c})
return session.run(tf_pick, feed_dict={clock_ph: c})
is_int = isinstance(value_range.start, int)
clock_range = np.arange(clock_min, clock_max, (clock_max - clock_min) / 100.0)