Reduce number of session.run calls from variable_ops_test.

PiperOrigin-RevId: 167902092
This commit is contained in:
A. Unique TensorFlower 2017-09-07 13:22:20 -07:00 committed by TensorFlower Gardener
parent b007bfdf2c
commit 40cb77d26e

View File

@ -171,22 +171,26 @@ class StridedSliceAssignChecker(object):
self.dtype = dtype
self.test = test
self.x_np = np.array(x).astype(dtype)
# Randomly start on mode 0 or 1.
self.which_mode = np.random.randint(2, size=1)[0]
def __setitem__(self, index, value):
self.which_mode = 1 - self.which_mode
value = np.array(value).astype(self.dtype)
with self.test.test_session() as sess, self.test.test_scope():
x = constant_op.constant(self.x_np, dtype=self.dtype)
var = resource_variable_ops.ResourceVariable(x)
sess.run(variables.variables_initializer([var]))
val = sess.run(var[index].assign(value))
# val_copy is used to check that tf.assign works equivalently to the
# assign method above.
val_copy = sess.run(state_ops.assign(var[index], value))
if self.which_mode == 0:
val = sess.run(var[index].assign(value))
else:
assert self.which_mode == 1
val = sess.run(state_ops.assign(var[index], value))
valnp = np.copy(self.x_np)
valnp[index] = np.array(value)
self.test.assertAllEqual(val, valnp)
self.test.assertAllEqual(val_copy, valnp)
class SliceAssignTest(XLATestCase):