diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index d5b2663af53..c9633fee42f 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -536,11 +536,9 @@ class CompileOptions(object): # There are different implementations of Executable for different backends. -def execute_with_python_values(executable, arguments=(), backend=None): +def execute_with_python_values(executable, arguments, backend): """Execute on one replica with Python values as arguments and output.""" - backend = backend or get_local_backend() - def put(arg): return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) @@ -549,7 +547,7 @@ def execute_with_python_values(executable, arguments=(), backend=None): return [x.to_py() for x in outputs] -def execute_with_python_values_replicated(executable, arguments, backend=None): +def execute_with_python_values_replicated(executable, arguments, backend): """Execute on many replicas with Python values as arguments and output. Arguments: @@ -561,7 +559,6 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): Returns: A list of python values, one per replica. """ - backend = backend or get_local_backend() devices = executable.local_devices() # pylint: disable=g-complex-comprehension flat_args = [(arg, devices[replica]) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index a0553c6a8e9..9f795d11d8d 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -401,7 +401,7 @@ def TestFactory(xla_backend, cloud_tpu=False): # Load and execute the proto c = xla_client.XlaComputation(serialized_proto) ans, = xla_client.execute_with_python_values( - self.backend.compile(c), backend=self.backend) + self.backend.compile(c), (), backend=self.backend) np.testing.assert_equal(ans, np.int32(3)) tests.append(ComputationFromProtoTest) @@ -563,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 1) expected = np.array(x, dtype=dst_dtype) @@ -590,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 1) expected = x.view(dst_dtype) @@ -1126,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Constant(c, NumpyArrayBool([True, False, False, True])) ]) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 3) np.testing.assert_equal(result[0], 42) np.testing.assert_allclose(result[1], [1.0, 2.0]) @@ -1165,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) # since the result is random, we just check shape and uniqueness self.assertLen(result, 1) self.assertEqual(result[0].shape, shape) @@ -1181,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) # since the result is random, we just check shape, uniqueness, and range self.assertLen(result, 1) self.assertEqual(result[0].shape, shape) @@ -1199,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, shape)) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) # since the result is random, we just check shape, integrality, and range self.assertLen(result, 1) self.assertEqual(result[0].shape, shape) @@ -1228,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False): c = self._NewComputation() ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) result = xla_client.execute_with_python_values( - self.backend.compile(c.build()), backend=self.backend) + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) @@ -1250,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False): dimension=1, comparator=comparator) result = xla_client.execute_with_python_values( - self.backend.compile(c.build())) + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) @@ -1734,7 +1734,7 @@ def TestFactory(xla_backend, cloud_tpu=False): for item in to_infeed: result, = xla_client.execute_with_python_values( - compiled_c, backend=self.backend) + compiled_c, (), backend=self.backend) self.assertEqual(result, item) @unittest.skipIf(cloud_tpu, "not implemented") @@ -1751,7 +1751,7 @@ def TestFactory(xla_backend, cloud_tpu=False): device.transfer_to_infeed(to_infeed) result = xla_client.execute_with_python_values( - compiled_c, backend=self.backend) + compiled_c, (), backend=self.backend) self.assertLen(result, 2) np.testing.assert_equal(result[0], to_infeed[0]) np.testing.assert_equal(result[1], to_infeed[1]) @@ -1834,7 +1834,7 @@ def TestFactory(xla_backend, cloud_tpu=False): def TestFun(): return xla_client.execute_with_python_values( - self.backend.compile(c.build()), [self.f32_scalar_2]) + self.backend.compile(c.build()), [self.f32_scalar_2], self.backend) self.assertRaisesRegex( RuntimeError, r"Invalid argument: Argument does not match.*"