[XLA:Python] Make arguments and backend options mandatory arguments to execute_with_python_values.
A couple of tests weren't testing what they were supposed to be testing because they weren't passing a backend argument. Make the argument mandatory so it cannot be omitted. In the future we can probably remove the backend argument and instead derive it from the Executable, but more refactoring is needed first. PiperOrigin-RevId: 308684683 Change-Id: I3fcbec0de0225bd96ab8a3373c20726fd706bf1a
This commit is contained in:
parent
3db8df8ffa
commit
76f02c7f34
tensorflow/compiler/xla/python
@ -536,11 +536,9 @@ class CompileOptions(object):
|
||||
# There are different implementations of Executable for different backends.
|
||||
|
||||
|
||||
def execute_with_python_values(executable, arguments=(), backend=None):
|
||||
def execute_with_python_values(executable, arguments, backend):
|
||||
"""Execute on one replica with Python values as arguments and output."""
|
||||
|
||||
backend = backend or get_local_backend()
|
||||
|
||||
def put(arg):
|
||||
return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])
|
||||
|
||||
@ -549,7 +547,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
|
||||
return [x.to_py() for x in outputs]
|
||||
|
||||
|
||||
def execute_with_python_values_replicated(executable, arguments, backend=None):
|
||||
def execute_with_python_values_replicated(executable, arguments, backend):
|
||||
"""Execute on many replicas with Python values as arguments and output.
|
||||
|
||||
Arguments:
|
||||
@ -561,7 +559,6 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
|
||||
Returns:
|
||||
A list of python values, one per replica.
|
||||
"""
|
||||
backend = backend or get_local_backend()
|
||||
devices = executable.local_devices()
|
||||
# pylint: disable=g-complex-comprehension
|
||||
flat_args = [(arg, devices[replica])
|
||||
|
@ -401,7 +401,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
# Load and execute the proto
|
||||
c = xla_client.XlaComputation(serialized_proto)
|
||||
ans, = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c), backend=self.backend)
|
||||
self.backend.compile(c), (), backend=self.backend)
|
||||
np.testing.assert_equal(ans, np.int32(3))
|
||||
|
||||
tests.append(ComputationFromProtoTest)
|
||||
@ -563,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
|
||||
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
self.assertLen(result, 1)
|
||||
expected = np.array(x, dtype=dst_dtype)
|
||||
|
||||
@ -590,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))
|
||||
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
self.assertLen(result, 1)
|
||||
expected = x.view(dst_dtype)
|
||||
|
||||
@ -1126,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
ops.Constant(c, NumpyArrayBool([True, False, False, True]))
|
||||
])
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
self.assertLen(result, 3)
|
||||
np.testing.assert_equal(result[0], 42)
|
||||
np.testing.assert_allclose(result[1], [1.0, 2.0])
|
||||
@ -1165,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
|
||||
shape))
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
# since the result is random, we just check shape and uniqueness
|
||||
self.assertLen(result, 1)
|
||||
self.assertEqual(result[0].shape, shape)
|
||||
@ -1181,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
|
||||
shape))
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
# since the result is random, we just check shape, uniqueness, and range
|
||||
self.assertLen(result, 1)
|
||||
self.assertEqual(result[0].shape, shape)
|
||||
@ -1199,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
|
||||
shape))
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
# since the result is random, we just check shape, integrality, and range
|
||||
self.assertLen(result, 1)
|
||||
self.assertEqual(result[0].shape, shape)
|
||||
@ -1228,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
c = self._NewComputation()
|
||||
ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), backend=self.backend)
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
self.assertLen(result, 2)
|
||||
np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
|
||||
np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
|
||||
@ -1250,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
dimension=1,
|
||||
comparator=comparator)
|
||||
result = xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()))
|
||||
self.backend.compile(c.build()), (), backend=self.backend)
|
||||
self.assertLen(result, 2)
|
||||
np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
|
||||
np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
|
||||
@ -1734,7 +1734,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
|
||||
for item in to_infeed:
|
||||
result, = xla_client.execute_with_python_values(
|
||||
compiled_c, backend=self.backend)
|
||||
compiled_c, (), backend=self.backend)
|
||||
self.assertEqual(result, item)
|
||||
|
||||
@unittest.skipIf(cloud_tpu, "not implemented")
|
||||
@ -1751,7 +1751,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
device.transfer_to_infeed(to_infeed)
|
||||
|
||||
result = xla_client.execute_with_python_values(
|
||||
compiled_c, backend=self.backend)
|
||||
compiled_c, (), backend=self.backend)
|
||||
self.assertLen(result, 2)
|
||||
np.testing.assert_equal(result[0], to_infeed[0])
|
||||
np.testing.assert_equal(result[1], to_infeed[1])
|
||||
@ -1834,7 +1834,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
|
||||
def TestFun():
|
||||
return xla_client.execute_with_python_values(
|
||||
self.backend.compile(c.build()), [self.f32_scalar_2])
|
||||
self.backend.compile(c.build()), [self.f32_scalar_2], self.backend)
|
||||
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError, r"Invalid argument: Argument does not match.*"
|
||||
|
Loading…
Reference in New Issue
Block a user