From 40cb77d26e43ec45c07064a0642408c27cd44d82 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 7 Sep 2017 13:22:20 -0700 Subject: [PATCH] Reduce number of session.run calls from variable_ops_test. PiperOrigin-RevId: 167902092 --- tensorflow/compiler/tests/variable_ops_test.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index a6b59fc731e..fdf3f9fb6ad 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -171,22 +171,26 @@ class StridedSliceAssignChecker(object): self.dtype = dtype self.test = test self.x_np = np.array(x).astype(dtype) + # Randomly start on mode 0 or 1. + self.which_mode = np.random.randint(2, size=1)[0] def __setitem__(self, index, value): + self.which_mode = 1 - self.which_mode value = np.array(value).astype(self.dtype) with self.test.test_session() as sess, self.test.test_scope(): x = constant_op.constant(self.x_np, dtype=self.dtype) var = resource_variable_ops.ResourceVariable(x) sess.run(variables.variables_initializer([var])) - val = sess.run(var[index].assign(value)) - # val_copy is used to check that tf.assign works equivalently to the - # assign method above. - val_copy = sess.run(state_ops.assign(var[index], value)) + + if self.which_mode == 0: + val = sess.run(var[index].assign(value)) + else: + assert self.which_mode == 1 + val = sess.run(state_ops.assign(var[index], value)) valnp = np.copy(self.x_np) valnp[index] = np.array(value) self.test.assertAllEqual(val, valnp) - self.test.assertAllEqual(val_copy, valnp) class SliceAssignTest(XLATestCase):