Reduce number of session.run calls from variable_ops_test.
PiperOrigin-RevId: 167902092
This commit is contained in:
parent
b007bfdf2c
commit
40cb77d26e
@ -171,22 +171,26 @@ class StridedSliceAssignChecker(object):
|
||||
self.dtype = dtype
|
||||
self.test = test
|
||||
self.x_np = np.array(x).astype(dtype)
|
||||
# Randomly start on mode 0 or 1.
|
||||
self.which_mode = np.random.randint(2, size=1)[0]
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
self.which_mode = 1 - self.which_mode
|
||||
value = np.array(value).astype(self.dtype)
|
||||
|
||||
with self.test.test_session() as sess, self.test.test_scope():
|
||||
x = constant_op.constant(self.x_np, dtype=self.dtype)
|
||||
var = resource_variable_ops.ResourceVariable(x)
|
||||
sess.run(variables.variables_initializer([var]))
|
||||
val = sess.run(var[index].assign(value))
|
||||
# val_copy is used to check that tf.assign works equivalently to the
|
||||
# assign method above.
|
||||
val_copy = sess.run(state_ops.assign(var[index], value))
|
||||
|
||||
if self.which_mode == 0:
|
||||
val = sess.run(var[index].assign(value))
|
||||
else:
|
||||
assert self.which_mode == 1
|
||||
val = sess.run(state_ops.assign(var[index], value))
|
||||
valnp = np.copy(self.x_np)
|
||||
valnp[index] = np.array(value)
|
||||
self.test.assertAllEqual(val, valnp)
|
||||
self.test.assertAllEqual(val_copy, valnp)
|
||||
|
||||
|
||||
class SliceAssignTest(XLATestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user