[XLA:Python] Make arguments and backend options mandatory arguments to execute_with_python_values.

A couple of tests weren't testing what they were supposed to be testing because they weren't passing a backend argument. Make the argument mandatory so it cannot be omitted.

In the future we can probably remove the backend argument and instead derive it from the Executable, but more refactoring is needed first.

PiperOrigin-RevId: 308684683
Change-Id: I3fcbec0de0225bd96ab8a3373c20726fd706bf1a
This commit is contained in:
Peter Hawkins 2020-04-27 13:19:08 -07:00 committed by TensorFlower Gardener
parent 3db8df8ffa
commit 76f02c7f34
2 changed files with 14 additions and 17 deletions
tensorflow/compiler/xla/python

View File

@ -536,11 +536,9 @@ class CompileOptions(object):
# There are different implementations of Executable for different backends.
def execute_with_python_values(executable, arguments=(), backend=None):
def execute_with_python_values(executable, arguments, backend):
"""Execute on one replica with Python values as arguments and output."""
backend = backend or get_local_backend()
def put(arg):
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
@ -549,7 +547,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
return [x.to_py() for x in outputs]
def execute_with_python_values_replicated(executable, arguments, backend=None):
def execute_with_python_values_replicated(executable, arguments, backend):
"""Execute on many replicas with Python values as arguments and output.
Arguments:
@ -561,7 +559,6 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
Returns:
A list of python values, one per replica.
"""
backend = backend or get_local_backend()
devices = executable.local_devices()
# pylint: disable=g-complex-comprehension
flat_args = [(arg, devices[replica])

View File

@ -401,7 +401,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
# Load and execute the proto
c = xla_client.XlaComputation(serialized_proto)
ans, = xla_client.execute_with_python_values(
self.backend.compile(c), backend=self.backend)
self.backend.compile(c), (), backend=self.backend)
np.testing.assert_equal(ans, np.int32(3))
tests.append(ComputationFromProtoTest)
@ -563,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
self.assertLen(result, 1)
expected = np.array(x, dtype=dst_dtype)
@ -590,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
self.assertLen(result, 1)
expected = x.view(dst_dtype)
@ -1126,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
ops.Constant(c, NumpyArrayBool([True, False, False, True]))
])
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
self.assertLen(result, 3)
np.testing.assert_equal(result[0], 42)
np.testing.assert_allclose(result[1], [1.0, 2.0])
@ -1165,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
shape))
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
# since the result is random, we just check shape and uniqueness
self.assertLen(result, 1)
self.assertEqual(result[0].shape, shape)
@ -1181,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
shape))
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
# since the result is random, we just check shape, uniqueness, and range
self.assertLen(result, 1)
self.assertEqual(result[0].shape, shape)
@ -1199,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
shape))
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
# since the result is random, we just check shape, integrality, and range
self.assertLen(result, 1)
self.assertEqual(result[0].shape, shape)
@ -1228,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
c = self._NewComputation()
ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()), backend=self.backend)
self.backend.compile(c.build()), (), backend=self.backend)
self.assertLen(result, 2)
np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
@ -1250,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
dimension=1,
comparator=comparator)
result = xla_client.execute_with_python_values(
self.backend.compile(c.build()))
self.backend.compile(c.build()), (), backend=self.backend)
self.assertLen(result, 2)
np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
@ -1734,7 +1734,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
for item in to_infeed:
result, = xla_client.execute_with_python_values(
compiled_c, backend=self.backend)
compiled_c, (), backend=self.backend)
self.assertEqual(result, item)
@unittest.skipIf(cloud_tpu, "not implemented")
@ -1751,7 +1751,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
device.transfer_to_infeed(to_infeed)
result = xla_client.execute_with_python_values(
compiled_c, backend=self.backend)
compiled_c, (), backend=self.backend)
self.assertLen(result, 2)
np.testing.assert_equal(result[0], to_infeed[0])
np.testing.assert_equal(result[1], to_infeed[1])
@ -1834,7 +1834,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
def TestFun():
return xla_client.execute_with_python_values(
self.backend.compile(c.build()), [self.f32_scalar_2])
self.backend.compile(c.build()), [self.f32_scalar_2], self.backend)
self.assertRaisesRegex(
RuntimeError, r"Invalid argument: Argument does not match.*"