tf numpy: Interop test for tf.Variable.
PiperOrigin-RevId: 317754511 Change-Id: I6bcabccaa626039c53269bf7e8837fef972bd537
This commit is contained in:
parent
8d99ecdbd5
commit
b5adbbcf0e
@ -94,25 +94,11 @@ class InteropTest(tf.test.TestCase):
|
||||
dx, dy = t.gradient([xx, yy], [x, y])
|
||||
|
||||
# # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
|
||||
# self.assertIsInstance(dx, np_arrays.ndarray)
|
||||
# self.assertIsInstance(dy, np_arrays.ndarray)
|
||||
# self.assertIsInstance(dx, np.ndarray)
|
||||
# self.assertIsInstance(dy, np.ndarray)
|
||||
self.assertAllClose(dx, 2.0)
|
||||
self.assertAllClose(dy, 3.0)
|
||||
|
||||
def testFunctionInterop(self):
|
||||
x = np.asarray(3.0)
|
||||
y = np.asarray(2.0)
|
||||
|
||||
add = lambda x, y: x + y
|
||||
add_fn = tf.function(add)
|
||||
|
||||
raw_result = add(x, y)
|
||||
fn_result = add_fn(x, y)
|
||||
|
||||
self.assertIsInstance(raw_result, np.ndarray)
|
||||
self.assertIsInstance(fn_result, np.ndarray)
|
||||
self.assertAllClose(raw_result, fn_result)
|
||||
|
||||
def testCondInterop(self):
|
||||
x = np.asarray(3.0)
|
||||
|
||||
@ -222,6 +208,66 @@ class InteropTest(tf.test.TestCase):
|
||||
# self.assertIsInstance(reduced, np.ndarray)
|
||||
self.assertAllClose(reduced, 15)
|
||||
|
||||
|
||||
class FunctionTest(InteropTest):
|
||||
|
||||
def testFunctionInterop(self):
|
||||
x = np.asarray(3.0)
|
||||
y = np.asarray(2.0)
|
||||
|
||||
add = lambda x, y: x + y
|
||||
add_fn = tf.function(add)
|
||||
|
||||
raw_result = add(x, y)
|
||||
fn_result = add_fn(x, y)
|
||||
|
||||
self.assertIsInstance(raw_result, np.ndarray)
|
||||
self.assertIsInstance(fn_result, np.ndarray)
|
||||
self.assertAllClose(raw_result, fn_result)
|
||||
|
||||
def testLen(self):
|
||||
|
||||
@tf.function
|
||||
def f(x):
|
||||
# Note that shape of input to len is data dependent.
|
||||
return len(np.where(x)[0])
|
||||
|
||||
t = np.asarray([True, False, True])
|
||||
with self.assertRaises(TypeError):
|
||||
f(t)
|
||||
|
||||
def testIter(self):
|
||||
|
||||
@tf.function
|
||||
def f(x):
|
||||
y, z = x
|
||||
return y, z
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
f(np.asarray([3, 4]))
|
||||
|
||||
def testIndex(self):
|
||||
|
||||
@tf.function
|
||||
def f(x):
|
||||
return [0, 1][x]
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
f(np.asarray([1]))
|
||||
|
||||
|
||||
class VariableTest(InteropTest):
|
||||
|
||||
def test(self):
|
||||
tf_var = tf.Variable(2.0)
|
||||
value = np.square(tf_var)
|
||||
self.assertIsInstance(value, np.ndarray)
|
||||
self.assertAllClose(4.0, value)
|
||||
with tf.control_dependencies([tf_var.assign_add(value)]):
|
||||
tf_var_value = tf_var.read_value()
|
||||
self.assertAllClose(6.0, tf_var_value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user