From 85f7677b4ac3ebbe444711adfd7a45f18b1b6b2b Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 18 Mar 2020 10:56:23 -0700
Subject: [PATCH] [XLA:Python] Update tpu_driver to add the same automatic
 tupling of arguments and untupling of results present in the local client.

Update tests to use the automatic untupling support.

PiperOrigin-RevId: 301623333
Change-Id: I1233e6a63eaea2bfef2ac7a85bf1b55b820361d1
---
 .../python/tpu_driver/client/tpu_client.cc    |  58 ++-
 .../xla/python/tpu_driver/client/tpu_client.h |  12 +-
 .../tpu_driver/client/tpu_client_extension.cc |   6 +-
 tensorflow/compiler/xla/python/xla.cc         |  35 --
 tensorflow/compiler/xla/python/xla_client.py  |  28 +-
 .../compiler/xla/python/xla_client_test.py    | 387 +++++++++---------
 6 files changed, 275 insertions(+), 251 deletions(-)

diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
index 706db57c4ac..56ac640cb6c 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
@@ -227,8 +227,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
 
 /* static */
 StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
-    const std::vector<PyTpuBuffer*> buffers,
-    std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device) {
+    absl::Span<PyTpuBuffer* const> buffers, std::shared_ptr<PyTpuClient> client,
+    std::shared_ptr<Device> device) {
   std::vector<Shape> child_shapes;
   std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
   std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
@@ -611,8 +611,8 @@ Status WaitForExecuteEvent(tpu_driver::Event* event) {
   return opt_status.value();
 }
 
-StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
-    absl::Span<PyTpuBuffer* const> argument_handles) {
+StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> PyTpuExecutable::Execute(
+    absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments) {
   if (num_replicas() != 1) {
     return InvalidArgument(
         "Attempted to execute computation with %d replicas using Execute().",
@@ -624,9 +624,18 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
         num_partitions());
   }
 
-  std::vector<PyTpuBuffer*> all_core_arguments(argument_handles.begin(),
-                                               argument_handles.end());
+  std::vector<PyTpuBuffer*> all_core_arguments;
+  std::unique_ptr<PyTpuBuffer> tupled_arguments;
+  if (tuple_arguments) {
+    TF_ASSIGN_OR_RETURN(tupled_arguments,
+                        PyTpuBuffer::MakeTuple(argument_handles, client_,
+                                               local_devices_.front()));
+    all_core_arguments = {tupled_arguments.get()};
+  } else {
+    all_core_arguments = std::vector<PyTpuBuffer*>(argument_handles.begin(),
+                                                   argument_handles.end());
+  }
 
   ExecuteResult result =
       ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles,
                     /*replica=*/0, /*partition=*/0, RunId());
@@ -638,12 +647,19 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
     return status;
   }
 
-  return std::move(result.buffer);
+  if (result.buffer->on_host_shape().IsTuple()) {
+    return result.buffer->DestructureTuple();
+  } else {
+    std::vector<std::unique_ptr<PyTpuBuffer>> outputs;
+    outputs.push_back(std::move(result.buffer));
+    return outputs;
+  }
 }
 
-StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
+StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
 PyTpuExecutable::ExecuteOnLocalDevices(
-    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
+    absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
+    bool tuple_arguments) {
   tensorflow::profiler::TraceMe traceme(
       "PyTpuExecutable::ExecuteOnLocalDevices");
 
@@ -661,6 +677,20 @@ PyTpuExecutable::ExecuteOnLocalDevices(
           << " num_partitions=" << num_partitions()
           << " num_local_devices=" << num_local_devices;
 
+  std::vector<std::unique_ptr<PyTpuBuffer>> tupled_arguments;
+  std::vector<std::vector<PyTpuBuffer*>> tupled_argument_pointers;
+  if (tuple_arguments) {
+    tupled_arguments.resize(argument_handles.size());
+    tupled_argument_pointers.resize(argument_handles.size());
+    for (int i = 0; i < num_local_devices; ++i) {
+      TF_ASSIGN_OR_RETURN(tupled_arguments[i],
+                          PyTpuBuffer::MakeTuple(argument_handles[i], client_,
+                                                 local_devices_.at(i)));
+      tupled_argument_pointers[i] = {tupled_arguments[i].get()};
+    }
+    argument_handles = tupled_argument_pointers;
+  }
+
   absl::Mutex results_lock;
   std::vector<ExecuteResult> results(num_local_devices);
 
@@ -702,9 +732,15 @@ PyTpuExecutable::ExecuteOnLocalDevices(
   }
   VLOG(1) << "Replicated execution complete.";
 
-  std::vector<std::unique_ptr<PyTpuBuffer>> wrapped_results(num_local_devices);
+  std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>> wrapped_results(
+      num_local_devices);
   for (int i = 0; i < num_local_devices; ++i) {
-    wrapped_results[i] = std::move(results[i].buffer);
+    if (results[i].buffer->on_host_shape().IsTuple()) {
+      TF_ASSIGN_OR_RETURN(wrapped_results[i],
+                          results[i].buffer->DestructureTuple());
+    } else {
+      wrapped_results[i].push_back(std::move(results[i].buffer));
+    }
   }
   return wrapped_results;
 }
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
index 4b7670707fb..2b1ac4a3044 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
@@ -166,7 +166,7 @@ class PyTpuBuffer {
 
   // Supports nested tuple creation.
   static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
-      const std::vector<PyTpuBuffer*> buffers,
+      absl::Span<PyTpuBuffer* const> buffers,
       std::shared_ptr<PyTpuClient> client, std::shared_ptr<Device> device);
 
   PyTpuBuffer() = delete;
@@ -308,15 +308,17 @@ class PyTpuExecutable {
   // TODO(power): Both Execute and ExecutePerOnLocalDevices block and wait
   // inside for computation to finish. Coordinate with JAX code change to see if
   // we can make both Execute and ExecutePerReplica non-blocking.
-  StatusOr<std::unique_ptr<PyTpuBuffer>> Execute(
-      absl::Span<PyTpuBuffer* const> argument_handles);
+  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
+      absl::Span<PyTpuBuffer* const> argument_handles, bool tuple_arguments);
 
   // Execute on local devices. Takes a sequence of argument lists (one argument
   // list per local device) and returns a tuple of results (one result per local
   // device). The number of argument lists must be equal to the local device
   // count.
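+  // If tuple_arguments is true, the argument buffers for each local device
+  // are first packed into a single tuple parameter, mirroring the local
+  // client; if a computation returns a tuple, each per-device result is
+  // destructured into a vector of element buffers.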
-  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> ExecuteOnLocalDevices(
-      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
+  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
+  ExecuteOnLocalDevices(
+      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles,
+      bool tuple_arguments);
 
   void Delete() { executables_.clear(); }
 
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
index 0dcb9dc4c84..b4e8afb5853 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
@@ -203,9 +203,11 @@ PYBIND11_MODULE(tpu_client_extension, m) {
            &PyTpuExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyTpuExecutable::Delete)
       .def("Execute", &PyTpuExecutable::Execute,
-           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
+           py::arg("tuple_arguments"))
       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
-           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"),
+           py::arg("tuple_arguments"));
 
   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
       .def_property_readonly("coords", &TpuDevice::coords)
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index ceea02f2374..d42636cde79 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -1133,20 +1133,6 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("SizeOfGeneratedCodeInBytes",
           &PyLocalExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyLocalExecutable::Delete)
-      .def(
-          "Execute",
-          [](const PyLocalExecutable& executable,
-             absl::Span<PyLocalBuffer* const> args)
-              -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
-            py::gil_scoped_release gil_release;
-            TF_ASSIGN_OR_RETURN(
-                std::vector<std::unique_ptr<PyLocalBuffer>> output,
-                executable.Execute(args, ExecuteOptions()));
-            return WrapWithClient(executable.client()->shared_from_this(),
-                                  std::move(output.front()));
-          },
-          py::arg("arguments"))
-      // TODO(phawkins): remove in favor of overload that returns a vector.
       .def(
           "Execute",
           [](const PyLocalExecutable& executable,
@@ -1168,27 +1154,6 @@ PYBIND11_MODULE(xla_extension, m) {
             return outputs;
           },
           py::arg("arguments"), py::arg("tuple_arguments"))
-      // TODO(phawkins): remove in favor of overload that returns a vector.
-      .def(
-          "ExecuteOnLocalDevices",
-          [](const PyLocalExecutable& executable,
-             absl::Span<const std::vector<PyLocalBuffer*>> args)
-              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
-            py::gil_scoped_release gil_release;
-            TF_ASSIGN_OR_RETURN(
-                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
-                    output_buffers,
-                executable.ExecuteOnLocalDevices(args, ExecuteOptions()));
-            std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
-            outputs.reserve(output_buffers.size());
-            for (auto& buffers : output_buffers) {
-              outputs.push_back(
-                  WrapWithClient(executable.client()->shared_from_this(),
-                                 std::move(buffers.front())));
-            }
-            return outputs;
-          },
-          py::arg("arguments"))
       .def(
           "ExecuteOnLocalDevices",
           [](const PyLocalExecutable& executable,
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index b6948b6d84d..d4df503677c 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -42,7 +42,6 @@ from tensorflow.compiler.xla.python.xla_extension import ops
 # consistency with XLA.
 # pylint: disable=invalid-name
 
-
 profiler = _xla.profiler
 
 
@@ -454,8 +453,8 @@ def transfer_to_infeed(value, device=None):
   Args:
     value: the value that the caller would like to enqueue into the XLA infeed
       queue
-    device: the device to infeed the value to. Each device has a
-      distinct infeed queue.
+    device: the device to infeed the value to. Each device has a distinct infeed
+      queue.
   """
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
@@ -501,7 +500,6 @@ def computation_count():
     '''Returns the number of computations per replica.'''
   """
 
-
 Device = _xla.Device
 
@@ -633,7 +631,8 @@ def execute_with_python_values(executable, arguments=(), backend=None):
         arg, device=executable.local_devices()[0], backend=backend)
 
   arguments = [put(arg) for arg in arguments]
-  return executable.Execute(arguments).to_py()
+  outputs = executable.Execute(arguments, tuple_arguments=False)
+  return [x.to_py() for x in outputs]
 
 
 def execute_with_python_values_replicated(executable, arguments, backend=None):
@@ -641,8 +640,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
 
   Arguments:
     executable: the program to run.
-    arguments: a list of lists of Python values indexed by
-      `[replica][arg_num]` to pass as inputs.
+    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
+      to pass as inputs.
     backend: the backend we are targeting.
 
   Returns:
@@ -661,7 +660,8 @@ def execute_with_python_values_replicated(executable, arguments, backend=None):
   for replica_args in arguments:
     arg_buffers.append(flat_arg_buffers[:len(replica_args)])
     flat_arg_buffers = flat_arg_buffers[len(replica_args):]
-  return [out.to_py() for out in executable.ExecuteOnLocalDevices(arg_buffers)]
+  return [[x.to_py() for x in xs] for xs in executable.ExecuteOnLocalDevices(
+      arg_buffers, tuple_arguments=False)]
 
 
 class PaddingType(enum.Enum):
@@ -787,6 +787,7 @@ class ComputationBuilder(object):
       shape: a `Shape` describing the shape of the infed value.
       token: an optional `XlaOp` representing a token after which the infeed
         effect should be sequenced.
+
     Returns:
       An XlaOp, representing a (value, token) pair.
     """
@@ -805,6 +806,7 @@ class ComputationBuilder(object):
       operand: an `XlaOp` representing the data to outfeed.
       token: an `XlaOp` representing a token after which the outfeed should be
         sequenced.
+
     Returns:
       An `XlaOp` representing a token.
     """
@@ -880,7 +882,10 @@ class ComputationBuilder(object):
     """
     return self.Constant(np.array(value, dtype=np.bool))
 
-  def ParameterWithShape(self, shape, name=None, parameter_num=None,
+  def ParameterWithShape(self,
+                         shape,
+                         name=None,
+                         parameter_num=None,
                          replicated=False):
     """Enqueues a Parameter op onto the computation, given a shape.
 
@@ -891,8 +896,8 @@
         next linear parameter number is used. The default value capability can
         be used for auto-numbering. If you're using auto-numbering for some
         parameters, use it for *all* parameters to avoid clashes.
-      replicated: whether to mark the parameter's leaves as replicated. May be
-        a bool, in which case it applies to all leaves, or an iterable of bools.
+      replicated: whether to mark the parameter's leaves as replicated. May be a
+        bool, in which case it applies to all leaves, or an iterable of bools.
 
     Returns:
       An XlaOp.
@@ -1791,6 +1796,7 @@ def register_custom_call_target(name, fn, platform='cpu'):
   """
   _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
 
+
 # Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 72b536ade68..b28a97837fe 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -55,12 +55,14 @@ class ComputationTest(absltest.TestCase): def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): assert expected is not None - result = self._Execute(c, arguments) - # Numpy's comparison methods are a bit too lenient by treating inputs as - # "array-like", meaning that scalar 4 will be happily compared equal to - # [[4]]. We'd like to be more strict so assert shapes as well. - self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) - assert_func(result, expected) + results = self._Execute(c, arguments) + self.assertLen(results, len(expected)) + for result, e in zip(results, expected): + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. + self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) + assert_func(result, e) def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) @@ -166,32 +168,32 @@ class ComputationsWithConstantsTest(ComputationTest): def testConstantScalarSumS8(self): c = self._NewComputation() c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) - self._ExecuteAndCompareExact(c, expected=np.int8(3)) + self._ExecuteAndCompareExact(c, expected=[np.int8(3)]) def testConstantScalarSumBF16(self): c = self._NewComputation() c.Add(c.Constant(bfloat16(1.11)), c.Constant(bfloat16(3.14))) - self._ExecuteAndCompareClose(c, expected=bfloat16(4.25)) + self._ExecuteAndCompareClose(c, expected=[bfloat16(4.25)]) def testConstantScalarSumF32(self): c = self._NewComputation() c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=4.25) + self._ExecuteAndCompareClose(c, expected=[4.25]) def testConstantScalarSumF64(self): c = self._NewComputation() c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=4.25) + self._ExecuteAndCompareClose(c, expected=[4.25]) def testConstantScalarSumS32(self): c = self._NewComputation() c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) - self._ExecuteAndCompareClose(c, expected=3) + self._ExecuteAndCompareClose(c, expected=[3]) def testConstantScalarSumS64(self): c = self._NewComputation() c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) - self._ExecuteAndCompareClose(c, expected=3) + self._ExecuteAndCompareClose(c, expected=[3]) def testConstantVectorMulF16(self): c = self._NewComputation() @@ -199,108 +201,108 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16)), c.Constant(np.array([-1.2, 2, -2, -3], np.float16))) self._ExecuteAndCompareClose( - c, expected=np.array([-3, 6.6, 2.4, -2.1], np.float16), rtol=2e-3) + c, expected=[np.array([-3, 6.6, 2.4, -2.1], np.float16)], rtol=2e-3) def testConstantVectorMulF32(self): c = self._NewComputation() c.Mul( c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) - self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + self._ExecuteAndCompareClose(c, expected=[[-3, 6.6, 
2.4, -2.1]]) def testConstantVectorMulF64(self): c = self._NewComputation() c.Mul( c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) - self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + self._ExecuteAndCompareClose(c, expected=[[-3, 6.6, 2.4, -2.1]]) def testConstantVectorScalarDivF32(self): c = self._NewComputation() c.Div( c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), c.ConstantF32Scalar(2.0)) - self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + self._ExecuteAndCompareClose(c, expected=[[0.75, 1.25, 1.5, -5.4]]) def testConstantVectorScalarDivF64(self): c = self._NewComputation() c.Div( c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), c.ConstantF64Scalar(2.0)) - self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + self._ExecuteAndCompareClose(c, expected=[[0.75, 1.25, 1.5, -5.4]]) def testConstantVectorScalarPowF32(self): c = self._NewComputation() c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) - self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) def testConstantVectorScalarPowF64(self): c = self._NewComputation() c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) - self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) def testIota(self): c = self._NewComputation() c.Iota(np.float32, 10) - self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + self._ExecuteAndCompareExact(c, expected=[np.arange(10, dtype=np.float32)]) def testBroadcastedIota(self): c = self._NewComputation() c.BroadcastedIota(np.int64, (2, 3), 1) expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) - self._ExecuteAndCompareExact(c, expected=expected) + self._ExecuteAndCompareExact(c, expected=[expected]) def testBooleanAnd(self): c = self._NewComputation() c.And( c.Constant(NumpyArrayBool([True, False, True, False])), c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) def testBooleanOr(self): c = self._NewComputation() c.Or( c.Constant(NumpyArrayBool([True, False, True, False])), c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) def testBooleanXor(self): c = self._NewComputation() c.Xor( c.Constant(NumpyArrayBool([True, False, True, False])), c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) def testSum2DF32(self): c = self._NewComputation() c.Add( c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) - self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) def testShiftLeft(self): c = self._NewComputation() c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2]))) - self._ExecuteAndCompareClose(c, expected=[12]) + self._ExecuteAndCompareClose(c, expected=[[12]]) def testShiftRightArithmetic(self): c = self._NewComputation() c.ShiftRightArithmetic( c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1]))) - 
self._ExecuteAndCompareClose(c, expected=[-1]) + self._ExecuteAndCompareClose(c, expected=[[-1]]) def testShiftRightLogical(self): c = self._NewComputation() c.ShiftRightLogical( c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) + self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) def testSum2DF64(self): c = self._NewComputation() c.Add( c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) - self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) def testSum2DWith1DBroadcastDim0F32(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -311,7 +313,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF32([10, 20, 30])), broadcast_dimensions=(0,)) self._ExecuteAndCompareClose( - c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) def testSum2DWith1DBroadcastDim0F64(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -322,7 +324,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF64([10, 20, 30])), broadcast_dimensions=(0,)) self._ExecuteAndCompareClose( - c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) def testSum2DWith1DBroadcastDim1F32(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -333,7 +335,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF32([10, 20, 30])), broadcast_dimensions=(1,)) self._ExecuteAndCompareClose( - c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) def testSum2DWith1DBroadcastDim1F64(self): # sum of a 2D array with a 1D array where the latter is replicated across @@ -344,7 +346,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.Constant(NumpyArrayF64([10, 20, 30])), broadcast_dimensions=(1,)) self._ExecuteAndCompareClose( - c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) def testConstantAxpyF32(self): c = self._NewComputation() @@ -353,7 +355,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.ConstantF32Scalar(2), c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), c.Constant(NumpyArrayF32([100, -100, 200, -200]))) - self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + self._ExecuteAndCompareClose(c, expected=[[104.4, -93.4, 208.8, -189]]) def testConstantAxpyF64(self): c = self._NewComputation() @@ -362,7 +364,7 @@ class ComputationsWithConstantsTest(ComputationTest): c.ConstantF64Scalar(2), c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), c.Constant(NumpyArrayF64([100, -100, 200, -200]))) - self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + self._ExecuteAndCompareClose(c, expected=[[104.4, -93.4, 208.8, -189]]) def testCustomCall(self): c = self._NewComputation() @@ -376,7 +378,7 @@ class ComputationsWithConstantsTest(ComputationTest): xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), )) - self._ExecuteAndCompareClose(c, expected=0.75) + self._ExecuteAndCompareClose(c, expected=[0.75]) class ParametersTest(ComputationTest): @@ -400,7 +402,7 @@ class ParametersTest(ComputationTest): 
self._ExecuteAndCompareClose( c, arguments=[self.f32_scalar_2, self.f32_4vector], - expected=[-4.6, 6.6, -8.6, 10.6]) + expected=[[-4.6, 6.6, -8.6, 10.6]]) def testScalarTimesVectorAutonumberF64(self): c = self._NewComputation() @@ -410,7 +412,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f64_scalar_2, self.f64_4vector], - expected=[-4.6, 6.6, -8.6, 10.6]) + expected=[[-4.6, 6.6, -8.6, 10.6]]) def testScalarTimesVectorS32(self): c = self._NewComputation() @@ -420,7 +422,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareExact( c, arguments=[self.s32_scalar_3, self.s32_4vector], - expected=[30, 45, -6, 21]) + expected=[[30, 45, -6, 21]]) def testScalarTimesVectorS64(self): c = self._NewComputation() @@ -430,7 +432,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareExact( c, arguments=[self.s64_scalar_3, self.s64_4vector], - expected=[30, 45, -6, 21]) + expected=[[30, 45, -6, 21]]) def testScalarMinusVectorExplicitNumberingF32(self): # Use explicit numbering and pass parameter_num first. Sub is used since @@ -443,7 +445,7 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f32_scalar_2, self.f32_4vector], - expected=[-4.3, 1.3, -6.3, 3.3]) + expected=[[-4.3, 1.3, -6.3, 3.3]]) def testScalarMinusVectorExplicitNumberingF64(self): # Use explicit numbering and pass parameter_num first. Sub is used since @@ -456,28 +458,22 @@ class ParametersTest(ComputationTest): self._ExecuteAndCompareClose( c, arguments=[self.f64_scalar_2, self.f64_4vector], - expected=[-4.3, 1.3, -6.3, 3.3]) + expected=[[-4.3, 1.3, -6.3, 3.3]]) class BufferTest(ComputationTest): """Tests focusing on execution with Buffers.""" - def _Execute(self, c, arguments): - compiled_c = c.Build().Compile() - arg_buffers = [xla_client.Buffer.from_pyval(arg) for arg in arguments] - result_buffer = compiled_c.Execute(arg_buffers) - return result_buffer.to_py() - def testConstantSum(self): c = self._NewComputation() c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=4.25) + self._ExecuteAndCompareClose(c, expected=[4.25]) def testOneParameterSum(self): c = self._NewComputation() c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) self._ExecuteAndCompareClose( - c, arguments=[NumpyArrayF32(1.11)], expected=4.25) + c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) def testTwoParameterSum(self): c = self._NewComputation() @@ -485,8 +481,10 @@ class BufferTest(ComputationTest): c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ParameterFromNumpy(NumpyArrayF32(0.))) self._ExecuteAndCompareClose( - c, arguments=[NumpyArrayF32(1.11), - NumpyArrayF32(3.14)], expected=4.25) + c, + arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], + expected=[4.25]) def testCannotCallWithDeletedBuffers(self): c = self._NewComputation() @@ -496,7 +494,7 @@ class BufferTest(ComputationTest): arg_buffer = xla_client.Buffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(RuntimeError): - compiled_c.Execute([arg_buffer]) + compiled_c.Execute([arg_buffer], tuple_arguments=False) def testDestructureTupleEmpty(self): device = xla_client.get_local_backend().devices()[0] @@ -646,7 +644,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])), ) c.Concatenate(args, dimension=0) - self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]) def 
testConcatenateF64(self): c = self._NewComputation() @@ -655,7 +653,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])), ) c.Concatenate(args, dimension=0) - self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]) def testConvertElementType(self): xla_types = { @@ -672,11 +670,12 @@ class SingleOpTest(ComputationTest): c.ConvertElementType(x, xla_types[dst_dtype]) result = xla_client.execute_with_python_values(c.Build().Compile()) + self.assertLen(result, 1) expected = np.array(template, dtype=dst_dtype) - self.assertEqual(result.shape, expected.shape) - self.assertEqual(result.dtype, expected.dtype) - np.testing.assert_equal(result, expected) + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) x = [0, 1, 0, 0, 1] for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): @@ -699,11 +698,12 @@ class SingleOpTest(ComputationTest): c.BitcastConvertType(x, dst_etype) result = xla_client.execute_with_python_values(c.Build().Compile()) + self.assertLen(result, 1) expected = np.array(template, src_dtype).view(dst_dtype) - self.assertEqual(result.shape, expected.shape) - self.assertEqual(result.dtype, expected.dtype) - np.testing.assert_equal(result, expected) + self.assertEqual(result[0].shape, expected.shape) + self.assertEqual(result[0].dtype, expected.dtype) + np.testing.assert_equal(result[0], expected) x = [0, 1, 0, 0, 1] for xla_types in [xla_x32_types, xla_x64_types]: @@ -720,7 +720,7 @@ class SingleOpTest(ComputationTest): for lhs in samples[:1]: c = self._NewComputation() c.AllToAll(c.Constant(lhs), 0, 0) - self._ExecuteAndCompareExact(c, expected=lhs) + self._ExecuteAndCompareExact(c, expected=[lhs]) def testCrossReplicaSumOneReplica(self): samples = [ @@ -732,12 +732,12 @@ class SingleOpTest(ComputationTest): for lhs in samples: c = self._NewComputation() c.CrossReplicaSum(c.Constant(lhs)) - self._ExecuteAndCompareExact(c, expected=lhs) + self._ExecuteAndCompareExact(c, expected=[lhs]) def testReplicaId(self): c = self._NewComputation() _ = c.ReplicaId() - self._ExecuteAndCompareExact(c, expected=0) + self._ExecuteAndCompareExact(c, expected=[0]) def testCrossReplicaSumOneReplicaWithSingletonGroup(self): samples = [ @@ -749,35 +749,35 @@ class SingleOpTest(ComputationTest): for lhs in samples: c = self._NewComputation() c.CrossReplicaSum(c.Constant(lhs), [[0]]) - self._ExecuteAndCompareExact(c, expected=lhs) + self._ExecuteAndCompareExact(c, expected=[lhs]) def testDotMatrixVectorF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF32([[10.0], [20.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotMatrixVectorF64(self): c = self._NewComputation() lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF64([[10.0], [20.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotMatrixMatrixF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, 
expected=[np.dot(lhs, rhs)]) def testDotMatrixMatrixF64(self): c = self._NewComputation() lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) def testDotGeneral(self): c = self._NewComputation() @@ -786,7 +786,7 @@ class SingleOpTest(ComputationTest): rhs = NumpyArrayF32(rng.randn(10, 4, 5)) dimension_numbers = (([2], [1]), ([0], [0])) c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) def testDotGeneralWithDotDimensionNumbersProto(self): c = self._NewComputation() @@ -801,7 +801,7 @@ class SingleOpTest(ComputationTest): dimension_numbers.rhs_batch_dimensions.append(0) c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) def testDotGeneralWithPrecisionConfig(self): c = self._NewComputation() @@ -817,7 +817,7 @@ class SingleOpTest(ComputationTest): c.Constant(rhs), dimension_numbers, precision_config=config) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) def testConvF32Same(self): c = self._NewComputation() @@ -831,7 +831,7 @@ class SingleOpTest(ComputationTest): [880., 940., 1000., 380.], [1120., 1180., 1240., 460.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvF32Valid(self): c = self._NewComputation() @@ -844,7 +844,7 @@ class SingleOpTest(ComputationTest): [640., 700., 760.], [1120., 1180., 1240.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvWithGeneralPaddingF32(self): c = self._NewComputation() @@ -864,7 +864,7 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [40., 50., 0.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvGeneralDilatedF32(self): c = self._NewComputation() @@ -885,7 +885,7 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [40., 50., 0.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvGeneralDilatedF32WithPrecisionConfig(self): c = self._NewComputation() @@ -915,7 +915,7 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [40., 50., 0.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, expected=[result]) def testConvGeneralDilatedPermutedF32(self): c = self._NewComputation() @@ -933,7 +933,8 @@ class SingleOpTest(ComputationTest): pads, lhs_dilation, rhs_dilation, dimension_numbers) result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], [40., 50., 0.]]]]) - self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + self._ExecuteAndCompareClose( + c, expected=[np.transpose(result, (1, 3, 0, 2))]) def testConvGeneralDilatedGroupedConvolutionF32(self): c = self._NewComputation() @@ -960,92 +961,92 @@ class SingleOpTest(ComputationTest): [0., 0., 0.], [480., 530., 220.], ]]]) - self._ExecuteAndCompareClose(c, expected=result) + self._ExecuteAndCompareClose(c, 
expected=[result]) def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) c.Not(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=~arr) + self._ExecuteAndCompareClose(c, expected=[~arr]) def testPopulationCount(self): c = self._NewComputation() arr = NumpyArrayS32([3, 0, 1]) c.PopulationCount(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.array([2, 0, 1])) + self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) def testCountLeadingZeros(self): c = self._NewComputation() arr = NumpyArrayS32([0x7FFF, 0x12345678]) c.Clz(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[17, 3]) + self._ExecuteAndCompareClose(c, expected=[[17, 3]]) def testExp(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Exp(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) def testExpm1(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Expm1(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Round(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.round(arr)) + self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) def testLog(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Log(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.log(arr)) + self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) def testLog1p(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Log1p(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Neg(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=-arr) + self._ExecuteAndCompareClose(c, expected=[-arr]) def testFloor(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Floor(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.floor(arr)) + self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) def testCeil(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Ceil(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) + self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) def testAbs(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) c.Abs(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.abs(arr)) + self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) def testTanh(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) c.Tanh(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) def testTrans(self): def _TransposeAndTest(array): c = self._NewComputation() c.Trans(c.Constant(array)) - self._ExecuteAndCompareClose(c, expected=array.T) + self._ExecuteAndCompareClose(c, expected=[array.T]) # Test square and non-square matrices in both default (C) and F orders. 
for array_fun in [NumpyArrayF32, NumpyArrayF64]: @@ -1060,7 +1061,7 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Transpose(c.Constant(array), permutation) expected = np.transpose(array, permutation) - self._ExecuteAndCompareClose(c, expected=expected) + self._ExecuteAndCompareClose(c, expected=[expected]) _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) @@ -1077,14 +1078,14 @@ class SingleOpTest(ComputationTest): c.Eq( c.Constant(NumpyArrayS32([1, 2, 3, 4])), c.Constant(NumpyArrayS32([4, 2, 3, 1]))) - self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) def testNe(self): c = self._NewComputation() c.Ne( c.Constant(NumpyArrayS32([1, 2, 3, 4])), c.Constant(NumpyArrayS32([4, 2, 3, 1]))) - self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) c.Ne( c.Constant(NumpyArrayF32([-2.0, 0.0, @@ -1092,42 +1093,44 @@ class SingleOpTest(ComputationTest): float("nan")])), c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) self._ExecuteAndAssertWith( - np.testing.assert_allclose, c, (), expected=[True, False, True, True]) + np.testing.assert_allclose, c, (), expected=[[True, False, True, True]]) def testGt(self): c = self._NewComputation() c.Gt( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) + self._ExecuteAndCompareExact( + c, expected=[[False, True, True, False, False]]) def testGe(self): c = self._NewComputation() c.Ge( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False, False]]) def testLt(self): c = self._NewComputation() c.Lt( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) + self._ExecuteAndCompareExact( + c, expected=[[False, False, False, True, True]]) def testLe(self): c = self._NewComputation() c.Le( c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True, True]]) def testMax(self): c = self._NewComputation() c.Max( c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) - self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) + self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) def testMaxExplicitBroadcastDim0(self): c = self._NewComputation() @@ -1135,7 +1138,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayF32([3, 4, 5])), broadcast_dimensions=(0,)) - self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) + self._ExecuteAndCompareExact( + c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) def testMaxExplicitBroadcastDim1(self): c = self._NewComputation() @@ -1143,14 +1147,15 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayF32([3, 4, 5])), 
broadcast_dimensions=(1,)) - self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) + self._ExecuteAndCompareExact( + c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) def testMin(self): c = self._NewComputation() c.Min( c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) - self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) def testPad(self): c = self._NewComputation() @@ -1159,8 +1164,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)]) self._ExecuteAndCompareClose( c, - expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) def testPadWithPaddingConfig(self): c = self._NewComputation() @@ -1176,8 +1181,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32(0.0)), padding_config) self._ExecuteAndCompareClose( c, - expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) def testReshape(self): c = self._NewComputation() @@ -1185,14 +1190,14 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), dimensions=[0, 1], new_sizes=[2, 3]) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) def testCollapse(self): c = self._NewComputation() c.Collapse( c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), dimensions=[1, 2]) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) def testRev(self): c = self._NewComputation() @@ -1200,7 +1205,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), dimensions=[0, 2]) self._ExecuteAndCompareExact( - c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) + c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) def testReducePrecision(self): c = self._NewComputation() @@ -1208,7 +1213,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), exponent_bits=8, mantissa_bits=7) - self._ExecuteAndCompareClose(c, expected=[float.fromhex("0x1.32p-3")]) + self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) def testClampF32(self): c = self._NewComputation() @@ -1216,7 +1221,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayF32(-1)), c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), c.Constant(NumpyArrayF32(2))) - self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) def testClampS32(self): c = self._NewComputation() @@ -1224,7 +1229,7 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32(-1)), c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), c.Constant(NumpyArrayS32(2))) - self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) def testSelect(self): c = self._NewComputation() @@ -1232,14 +1237,14 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayBool([True, False, False, True, False])), 
c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) - self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) + self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) def testSlice(self): c = self._NewComputation() c.Slice( c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], [3, 2]) - self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) def testSliceInDim(self): c = self._NewComputation() @@ -1249,21 +1254,21 @@ class SingleOpTest(ComputationTest): limit_index=2, stride=1, dimno=1) - self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) + self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) c.SliceInDim( c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), start_index=0, limit_index=3, stride=2, dimno=0) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) def testDynamicSlice(self): c = self._NewComputation() c.DynamicSlice( c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayS32([1, 0])), [2, 2]) - self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) def testDynamicUpdateSlice(self): c = self._NewComputation() @@ -1271,7 +1276,8 @@ class SingleOpTest(ComputationTest): c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), c.Constant(NumpyArrayS32([1, 1]))) - self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) + self._ExecuteAndCompareExact( + c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) def testTuple(self): c = self._NewComputation() @@ -1279,7 +1285,7 @@ class SingleOpTest(ComputationTest): c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), c.Constant(NumpyArrayBool([True, False, False, True]))) result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertIsInstance(result, tuple) + self.assertLen(result, 3) np.testing.assert_equal(result[0], 42) np.testing.assert_allclose(result[1], [1.0, 2.0]) np.testing.assert_equal(result[2], [True, False, False, True]) @@ -1290,20 +1296,20 @@ class SingleOpTest(ComputationTest): c.Tuple( c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), c.Constant(NumpyArrayBool([True, False, False, True]))), 1) - self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) + self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) def testBroadcast(self): c = self._NewComputation() c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) self._ExecuteAndCompareExact( - c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) def testBroadcastInDim(self): c = self._NewComputation() c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0]) - self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]]) c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1]) - self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) def testRngNormal(self): shape = (2, 3) @@ -1314,8 +1320,9 @@ class SingleOpTest(ComputationTest): dims=shape) result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape and uniqueness - 
self.assertEqual(result.shape, shape) - self.assertLen(np.unique(result), np.prod(shape)) + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) def testRngUniformF32(self): lo, hi = 2., 4. @@ -1327,10 +1334,11 @@ class SingleOpTest(ComputationTest): dims=shape) result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, uniqueness, and range - self.assertEqual(result.shape, shape) - self.assertLen(np.unique(result), np.prod(shape)) - self.assertTrue(np.all(lo <= result)) - self.assertTrue(np.all(result < hi)) + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) def testRngUniformS32(self): lo, hi = 2, 4 @@ -1342,24 +1350,25 @@ class SingleOpTest(ComputationTest): dims=shape) result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, integrality, and range - self.assertEqual(result.shape, shape) - self.assertEqual(result.dtype, np.int32) - self.assertTrue(np.all(lo <= result)) - self.assertTrue(np.all(result < hi)) + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertEqual(result[0].dtype, np.int32) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) def testCholesky(self): l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], dtype=np.float32) c = self._NewComputation() c.Cholesky(c.Constant(np.dot(l, l.T))) - self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) def testSort(self): keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) c = self._NewComputation() c.Sort(c.Constant(keys)) self._ExecuteAndCompareClose( - c, expected=np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)) + c, expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) def testSortKeyVal(self): keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) @@ -1367,7 +1376,7 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Sort((c.Constant(keys), c.Constant(values)), dimension=0) result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertIsInstance(result, tuple) + self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) @@ -1387,7 +1396,7 @@ class SingleOpTest(ComputationTest): dimension=1, comparator=comparator) result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertIsInstance(result, tuple) + self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) @@ -1437,12 +1446,14 @@ class SingleOpTest(ComputationTest): transpose_a=True) self._ExecuteAndCompareClose( c, - expected=np.array([ - [0.5, 0.08333334, 0.04629629, 0.03367003], - [2.5, -0.25, -0.1388889, -0.1010101], - [4.5, -0.58333331, -0.32407406, -0.23569024], + expected=[ + np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32) ], - dtype=np.float32), rtol=1e-4) def testIsConstant(self): @@ -1467,7 +1478,7 @@ class SingleOpTest(ComputationTest): dnums.index_vector_dim = 2 c 
= self._NewComputation() c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) - g = self._Execute(c, ()) + g, = self._Execute(c, ()) expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) np.testing.assert_allclose(g, expected, rtol=1e-4) @@ -1480,30 +1491,30 @@ class SingleOpTest(ComputationTest): c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) self._ExecuteAndCompareClose( - c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) # IFFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) self._ExecuteAndCompareClose( - c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) # RFFT b = rng.randn(*shape).astype(np.float32) c = self._NewComputation() c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) self._ExecuteAndCompareClose( - c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) # IRFFT c = self._NewComputation() c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) self._ExecuteAndCompareClose( - c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4) + c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4) def testNextAfter(self): c = self._NewComputation() c.NextAfter( c.Constant(np.array([1, 2], dtype=np.float32)), c.Constant(np.array([2, 1], dtype=np.float32))) - out = self._Execute(c, ()) + out, = self._Execute(c, ()) eps = np.finfo(np.float32).eps np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out) @@ -1515,7 +1526,7 @@ class SingleOpTest(ComputationTest): c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x)) expected = np.array( [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) - self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4) + self._ExecuteAndCompareClose(c, expected=[expected], rtol=1e-4) class EmbeddedComputationsTest(ComputationTest): @@ -1656,38 +1667,38 @@ class EmbeddedComputationsTest(ComputationTest): c.Call( self._CreateMulF32By2Computation(), operands=(c.ConstantF32Scalar(5.0),)) - self._ExecuteAndCompareClose(c, expected=10.0) + self._ExecuteAndCompareClose(c, expected=[10.0]) def testCallF64(self): c = self._NewComputation() c.Call( self._CreateMulF64By2Computation(), operands=(c.ConstantF64Scalar(5.0),)) - self._ExecuteAndCompareClose(c, expected=10.0) + self._ExecuteAndCompareClose(c, expected=[10.0]) def testMapEachElementToS32Constant(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], self._CreateConstantS32Computation(), [0]) - self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) def testMapEachElementToS64Constant(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], self._CreateConstantS64Computation(), [0]) - self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) def testMapMulBy2F32(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], self._CreateMulF32By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) def testMapMulBy2F64(self): c = self._NewComputation() c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], self._CreateMulF64By2Computation(), [0]) - 
-    self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
+    self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]])

   def testSimpleMapChainF32(self):
     # Chains a map of constant-f32 with a map of mul-by-2
@@ -1695,7 +1706,7 @@ class EmbeddedComputationsTest(ComputationTest):
     const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
                       self._CreateConstantF32Computation(), [0])
     c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
-    self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
+    self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])

   def testSimpleMapChainF64(self):
     # Chains a map of constant-f64 with a map of mul-by-2
@@ -1703,21 +1714,21 @@ class EmbeddedComputationsTest(ComputationTest):
     const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
                       self._CreateConstantF64Computation(), [0])
     c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
-    self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
+    self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])

   def testDivVectorsWithMapF32(self):
     c = self._NewComputation()
     c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
            c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
           self._CreateBinaryDivF32Computation(), [0])
-    self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
+    self._ExecuteAndCompareClose(c, expected=[[0.2, 0.4, 0.75, 1.0]])

   def testDivVectorsWithMapF64(self):
     c = self._NewComputation()
     c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
            c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
           self._CreateBinaryDivF64Computation(), [0])
-    self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
+    self._ExecuteAndCompareClose(c, expected=[[0.2, 0.4, 0.75, 1.0]])

   def testSelectAndScatterF32(self):
     c = self._NewComputation()
@@ -1730,7 +1741,7 @@ class EmbeddedComputationsTest(ComputationTest):
         source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
         init_value=c.Constant(NumpyArrayF32(1)),
         scatter=self._CreateBinaryAddF32Computation())
-    self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
+    self._ExecuteAndCompareClose(c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]])

   def testSelectAndScatterF64(self):
     c = self._NewComputation()
@@ -1743,7 +1754,7 @@ class EmbeddedComputationsTest(ComputationTest):
         source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
         init_value=c.Constant(NumpyArrayF64(1)),
         scatter=self._CreateBinaryAddF64Computation())
-    self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
+    self._ExecuteAndCompareClose(c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]])

   def testReduce1DtoScalarF32(self):
     c = self._NewComputation()
@@ -1752,7 +1763,7 @@ class EmbeddedComputationsTest(ComputationTest):
         init_value=c.ConstantF32Scalar(0),
         computation_to_apply=self._CreateBinaryAddF32Computation(),
         dimensions=[0])
-    self._ExecuteAndCompareClose(c, expected=10)
+    self._ExecuteAndCompareClose(c, expected=[10])

   def testReduce1DtoScalarF64(self):
     c = self._NewComputation()
@@ -1761,7 +1772,7 @@ class EmbeddedComputationsTest(ComputationTest):
         init_value=c.ConstantF64Scalar(0),
         computation_to_apply=self._CreateBinaryAddF64Computation(),
         dimensions=[0])
-    self._ExecuteAndCompareClose(c, expected=10)
+    self._ExecuteAndCompareClose(c, expected=[10])

   def testReduce2DTo1DDim0F32(self):
     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1771,7 +1782,7 @@ class EmbeddedComputationsTest(ComputationTest):
         init_value=c.ConstantF32Scalar(0),
         computation_to_apply=self._CreateBinaryAddF32Computation(),
         dimensions=[0])
-    self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
+    self._ExecuteAndCompareClose(c, expected=[[5, 7, 9]])

   def testReduce2DTo1DDim0F64(self):
     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1781,7 +1792,7 @@ class EmbeddedComputationsTest(ComputationTest):
         init_value=c.ConstantF64Scalar(0),
         computation_to_apply=self._CreateBinaryAddF64Computation(),
         dimensions=[0])
-    self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
+    self._ExecuteAndCompareClose(c, expected=[[5, 7, 9]])

   def testReduce2DTo1DDim1F32(self):
     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1791,7 +1802,7 @@ class EmbeddedComputationsTest(ComputationTest):
         init_value=c.ConstantF32Scalar(0),
         computation_to_apply=self._CreateBinaryAddF32Computation(),
         dimensions=[1])
-    self._ExecuteAndCompareClose(c, expected=[6, 15])
+    self._ExecuteAndCompareClose(c, expected=[[6, 15]])

   def testReduce2DTo1DDim1F64(self):
     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1801,7 +1812,7 @@ class EmbeddedComputationsTest(ComputationTest):
         init_value=c.ConstantF64Scalar(0),
         computation_to_apply=self._CreateBinaryAddF64Computation(),
         dimensions=[1])
-    self._ExecuteAndCompareClose(c, expected=[6, 15])
+    self._ExecuteAndCompareClose(c, expected=[[6, 15]])

   def testReduce3DAllPossibleWaysF32(self):
     input_array = self._MakeSample3DArrayF32()
@@ -1814,7 +1825,7 @@ class EmbeddedComputationsTest(ComputationTest):
           computation_to_apply=self._CreateBinaryAddF32Computation(),
           dimensions=dims)
       self._ExecuteAndCompareClose(
-          c, expected=np.sum(input_array, axis=tuple(dims)))
+          c, expected=[np.sum(input_array, axis=tuple(dims))])

     _ReduceAndTest(0)
     _ReduceAndTest(0, 1)
@@ -1833,7 +1844,7 @@ class EmbeddedComputationsTest(ComputationTest):
           computation_to_apply=self._CreateBinaryAddF64Computation(),
           dimensions=dims)
       self._ExecuteAndCompareClose(
-          c, expected=np.sum(input_array, axis=tuple(dims)))
+          c, expected=[np.sum(input_array, axis=tuple(dims))])

     _ReduceAndTest(0)
     _ReduceAndTest(0)
@@ -1852,7 +1863,7 @@ class EmbeddedComputationsTest(ComputationTest):
         window_dimensions=(2, 1),
         window_strides=(1, 1),
         padding=xla_client.PaddingType.VALID)
-    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
+    self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])

   def testReduceWindowSameUnitStridesF32(self):
     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1864,7 +1875,7 @@ class EmbeddedComputationsTest(ComputationTest):
         window_dimensions=(2, 1),
         window_strides=(1, 1),
         padding=xla_client.PaddingType.SAME)
-    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
+    self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])

   def testReduceWindowValidGeneralStridesF32(self):
     input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1876,7 +1887,7 @@ class EmbeddedComputationsTest(ComputationTest):
         window_dimensions=(2, 1),
         window_strides=(1, 2),
         padding=xla_client.PaddingType.VALID)
-    self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
+    self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])

   def testReduceWindowValidUnitStridesF64(self):
     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1888,7 +1899,7 @@ class EmbeddedComputationsTest(ComputationTest):
         window_dimensions=(2, 1),
         window_strides=(1, 1),
         padding=xla_client.PaddingType.VALID)
-    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
+    self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])

   def testReduceWindowSameUnitStridesF64(self):
     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1900,7 +1911,7 @@ class EmbeddedComputationsTest(ComputationTest):
         window_dimensions=(2, 1),
         window_strides=(1, 1),
         padding=xla_client.PaddingType.SAME)
-    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
+    self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])

   def testReduceWindowValidGeneralStridesF64(self):
     input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1912,7 +1923,7 @@ class EmbeddedComputationsTest(ComputationTest):
         window_dimensions=(2, 1),
         window_strides=(1, 2),
         padding=xla_client.PaddingType.VALID)
-    self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
+    self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])

   def testWhileF32(self):
     cond = self._CreateTestF32Lt10Computation()
@@ -1920,7 +1931,7 @@ class EmbeddedComputationsTest(ComputationTest):
     c = self._NewComputation()
     init = c.ConstantF32Scalar(1.)
     c.While(cond, body, init)
-    self._ExecuteAndCompareClose(c, expected=16.)
+    self._ExecuteAndCompareClose(c, expected=[16.])

   def testWhileF64(self):
     cond = self._CreateTestF64Lt10Computation()
@@ -1928,7 +1939,7 @@ class EmbeddedComputationsTest(ComputationTest):
     c = self._NewComputation()
     init = c.ConstantF64Scalar(1.)
     c.While(cond, body, init)
-    self._ExecuteAndCompareClose(c, expected=16.)
+    self._ExecuteAndCompareClose(c, expected=[16.])

   def testConditionalTrue(self):
     c = self._NewComputation()
@@ -1939,7 +1950,7 @@ class EmbeddedComputationsTest(ComputationTest):
     false_computation = self._CreateConstantF32Computation()
     c.Conditional(pred, true_operand, true_computation, false_operand,
                   false_computation)
-    self._ExecuteAndCompareClose(c, expected=6.)
+    self._ExecuteAndCompareClose(c, expected=[6.])

   def testConditionalFalse(self):
     c = self._NewComputation()
@@ -1950,7 +1961,7 @@ class EmbeddedComputationsTest(ComputationTest):
     false_computation = self._CreateConstantF32Computation()
     c.Conditional(pred, true_operand, true_computation, false_operand,
                   false_computation)
-    self._ExecuteAndCompareClose(c, expected=1.)
+    self._ExecuteAndCompareClose(c, expected=[1.])

   def testInfeedS32Values(self):
     to_infeed = NumpyArrayS32([1, 2, 3, 4])
@@ -1961,7 +1972,7 @@ class EmbeddedComputationsTest(ComputationTest):
       xla_client.transfer_to_infeed(item)

     for item in to_infeed:
-      result = xla_client.execute_with_python_values(compiled_c)
+      result, = xla_client.execute_with_python_values(compiled_c)
       self.assertEqual(result, item)

   def testInfeedTuple(self):
@@ -1972,6 +1983,7 @@ class EmbeddedComputationsTest(ComputationTest):
     xla_client.transfer_to_infeed(to_infeed)

     result = xla_client.execute_with_python_values(compiled_c)
+    self.assertLen(result, 2)
     np.testing.assert_equal(result[0], to_infeed[0])
     np.testing.assert_equal(result[1], to_infeed[1])

@@ -1986,7 +1998,8 @@ class EmbeddedComputationsTest(ComputationTest):
     compiled_c = c.Build().Compile()

     for want in to_round_trip:
-      execution = threading.Thread(target=lambda: compiled_c.Execute([]))
+      execution = threading.Thread(
+          target=lambda: compiled_c.Execute([], tuple_arguments=False))
       execution.start()
       xla_client.transfer_to_infeed(want)
       got = xla_client.transfer_from_outfeed(
@@ -2010,7 +2023,7 @@ class EmbeddedComputationsTest(ComputationTest):
         c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
         self._CreateBinaryAddS32Computation(), dnums)
     expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32)
-    self._ExecuteAndCompareClose(c, expected=expected)
+    self._ExecuteAndCompareClose(c, expected=[expected])


 class ErrorTest(ComputationTest):
@@ -2063,7 +2076,7 @@ class ComputationRootTest(ComputationTest):
     arg = NumpyArrayF32(1.0)

     compiled_c = c.Build(result).Compile()
-    ans = xla_client.execute_with_python_values(compiled_c, [arg])
+    ans, = xla_client.execute_with_python_values(compiled_c, [arg])
     np.testing.assert_allclose(ans, 4.14)


@@ -2086,7 +2099,7 @@ class SetShardingTest(ComputationTest):
     extra = c.Add(result, c.ConstantF32Scalar(1.618))  # pylint: disable=unused-variable
     arg = NumpyArrayF32(1.0)
     compiled_c = c.Build(result).Compile()
-    ans = xla_client.execute_with_python_values(compiled_c, [arg])
+    ans, = xla_client.execute_with_python_values(compiled_c, [arg])
     np.testing.assert_allclose(ans, 4.14)
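
The test updates above all follow one convention: execution now returns a list
with one entry per computation output, so single-output expectations are
wrapped in a one-element list and callers unpack with `result, = ...`. Below is
a minimal Python sketch of that untupling rule; `untuple_result` is an
illustrative stand-in, not a function from this patch or from xla_client.

    # Sketch of the result-untupling convention (illustrative only).
    # The real code destructures a tuple-shaped result buffer into one
    # buffer per element; a non-tuple result becomes a one-element list.
    def untuple_result(result, is_tuple):
      """Normalize an execution result to a list of per-output values."""
      return list(result) if is_tuple else [result]

    # Single-output computation: callers unpack the one-element list.
    out, = untuple_result(3.14, is_tuple=False)
    assert out == 3.14

    # Tuple-output computation: one list entry per tuple element.
    x, y = untuple_result((1, 2), is_tuple=True)
    assert (x, y) == (1, 2)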