tf numpy: Interop test for tf.Variable.

PiperOrigin-RevId: 317754511
Change-Id: I6bcabccaa626039c53269bf7e8837fef972bd537
This commit is contained in:
A. Unique TensorFlower 2020-06-22 16:05:37 -07:00 committed by TensorFlower Gardener
parent 8d99ecdbd5
commit b5adbbcf0e

View File

@ -94,25 +94,11 @@ class InteropTest(tf.test.TestCase):
dx, dy = t.gradient([xx, yy], [x, y])
# # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
# self.assertIsInstance(dx, np_arrays.ndarray)
# self.assertIsInstance(dy, np_arrays.ndarray)
# self.assertIsInstance(dx, np.ndarray)
# self.assertIsInstance(dy, np.ndarray)
self.assertAllClose(dx, 2.0)
self.assertAllClose(dy, 3.0)
def testFunctionInterop(self):
x = np.asarray(3.0)
y = np.asarray(2.0)
add = lambda x, y: x + y
add_fn = tf.function(add)
raw_result = add(x, y)
fn_result = add_fn(x, y)
self.assertIsInstance(raw_result, np.ndarray)
self.assertIsInstance(fn_result, np.ndarray)
self.assertAllClose(raw_result, fn_result)
def testCondInterop(self):
x = np.asarray(3.0)
@ -222,6 +208,66 @@ class InteropTest(tf.test.TestCase):
# self.assertIsInstance(reduced, np.ndarray)
self.assertAllClose(reduced, 15)
class FunctionTest(InteropTest):
def testFunctionInterop(self):
x = np.asarray(3.0)
y = np.asarray(2.0)
add = lambda x, y: x + y
add_fn = tf.function(add)
raw_result = add(x, y)
fn_result = add_fn(x, y)
self.assertIsInstance(raw_result, np.ndarray)
self.assertIsInstance(fn_result, np.ndarray)
self.assertAllClose(raw_result, fn_result)
def testLen(self):
@tf.function
def f(x):
# Note that shape of input to len is data dependent.
return len(np.where(x)[0])
t = np.asarray([True, False, True])
with self.assertRaises(TypeError):
f(t)
def testIter(self):
@tf.function
def f(x):
y, z = x
return y, z
with self.assertRaises(TypeError):
f(np.asarray([3, 4]))
def testIndex(self):
@tf.function
def f(x):
return [0, 1][x]
with self.assertRaises(TypeError):
f(np.asarray([1]))
class VariableTest(InteropTest):
def test(self):
tf_var = tf.Variable(2.0)
value = np.square(tf_var)
self.assertIsInstance(value, np.ndarray)
self.assertAllClose(4.0, value)
with tf.control_dependencies([tf_var.assign_add(value)]):
tf_var_value = tf_var.read_value()
self.assertAllClose(6.0, tf_var_value)
if __name__ == '__main__':
tf.compat.v1.enable_eager_execution()
tf.test.main()