diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index e99ba05369d..8db66c24bd5 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -98,7 +98,7 @@ class TpuBackend(xla_client.Backend): options.debug_options.xla_cpu_fast_math_honor_division = True options.debug_options.xla_cpu_fast_math_honor_functions = True options.debug_options.xla_gpu_enable_fast_min_max = False - return _tpu_client.TpuExecutable.Compile(c_computation, + return _tpu_client.TpuExecutable.compile(c_computation, compile_options.argument_layouts, options, self.client, compile_options.device_assignment, @@ -106,14 +106,8 @@ class TpuBackend(xla_client.Backend): def get_default_device_assignment(self, num_replicas, num_partitions=None): if num_partitions is not None: - return self.client.GetDefaultDeviceAssignment(num_replicas, - num_partitions) + return self.client.get_default_device_assignment(num_replicas, + num_partitions) else: # TODO(henrytan): delete this case after all callers can handle 2D output - return self.client.GetDefaultDeviceAssignment(num_replicas) - - def serialize(self, executable): - return self.client.SerializeExecutable(executable) - - def deserialize(self, serialized_executable): - return self.client.DeserializeExecutable(serialized_executable, self.client) + return self.client.get_default_device_assignment(num_replicas) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 83a3e5b3db9..e9a0d2df592 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -37,7 +37,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("devices", &PyTpuClient::devices) .def("local_devices", &PyTpuClient::local_devices) .def("host_id", &PyTpuClient::host_id) - .def("GetDefaultDeviceAssignment", + .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas, int num_partitions) -> StatusOr>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, @@ -57,7 +57,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { return result; }) // TODO(skye): delete after all callers can handle 2D output - .def("GetDefaultDeviceAssignment", + .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas) -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, @@ -72,14 +72,14 @@ PYBIND11_MODULE(tpu_client_extension, m) { } return result; }) - .def("TransferToInfeed", + .def("transfer_to_infeed", [](PyTpuClient* client, const LiteralSlice& literal, int device_ordinal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; return client->TransferToInfeed(literal, device_ordinal); }) - .def("TransferFromOutfeed", + .def("transfer_from_outfeed", [](PyTpuClient* client, const Shape& shape, int device_ordinal) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); @@ -159,9 +159,9 @@ PYBIND11_MODULE(tpu_client_extension, m) { }); py::class_(m, "TpuExecutable") - .def_static("Compile", &PyTpuExecutable::Compile, + .def_static("compile", &PyTpuExecutable::Compile, py::call_guard()) - .def_static("Compile", + .def_static("compile", [](const XlaComputation& computation, absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, @@ -184,12 +184,17 @@ 
PYBIND11_MODULE(tpu_client_extension, m) { .def("local_logical_device_ids", &PyTpuExecutable::local_logical_device_ids) .def("local_devices", &PyTpuExecutable::local_devices) - .def("SizeOfGeneratedCodeInBytes", + .def("size_of_generated_code_in_bytes", &PyTpuExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyTpuExecutable::Delete) .def("Execute", &PyTpuExecutable::Execute, py::call_guard(), py::arg("arguments")) .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices, + py::call_guard(), py::arg("arguments")) + .def("delete", &PyTpuExecutable::Delete) + .def("execute", &PyTpuExecutable::Execute, + py::call_guard(), py::arg("arguments")) + .def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices, py::call_guard(), py::arg("arguments")); py::class_>(m, "TpuDevice") diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 914c4deb6d2..34d79c30125 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -676,7 +676,7 @@ void BuildProfilerSubmodule(py::module* m) { traceme_class.def(py::init()) .def("__enter__", &TraceMeContextManager::Enter) .def("__exit__", &TraceMeContextManager::Exit) - .def_static("IsEnabled", &TraceMeContextManager::IsEnabled); + .def_static("is_enabled", &TraceMeContextManager::IsEnabled); } } // namespace @@ -880,6 +880,7 @@ PYBIND11_MODULE(xla_extension, m) { .def_property_readonly("platform", &Device::platform_name) .def_property_readonly("device_kind", &Device::device_kind) .def("__str__", &Device::DebugString) + // TODO(phawkins): remove capitalized names after updating callers. .def("TransferToInfeed", [](const Device& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); @@ -891,6 +892,33 @@ PYBIND11_MODULE(xla_extension, m) { }) .def( "TransferFromOutfeed", + [](const Device& device, const Shape& shape) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal_shared; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device.GetLocalDeviceState()); + TF_ASSIGN_OR_RETURN( + Literal literal, + local_device->client()->TransferFromOutfeedLocal( + shape, local_device->device_ordinal())); + + literal_shared = std::make_shared(std::move(literal)); + } + return LiteralToPython(std::move(literal_shared)); + }) + .def("transfer_to_infeed", + [](const Device& device, const LiteralSlice& literal) { + GlobalPyRefManager()->CollectGarbage(); + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device.GetLocalDeviceState()); + return local_device->client()->TransferToInfeedLocal( + literal, local_device->device_ordinal()); + }) + .def( + "transfer_from_outfeed", [](const Device& device, const Shape& shape) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); std::shared_ptr literal_shared; @@ -921,7 +949,7 @@ PYBIND11_MODULE(xla_extension, m) { // Local XLA client methods. // Custom-call targets. 
- m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget); + m.def("register_custom_call_target", &PyRegisterCustomCallTarget); py::class_ alloc_config(m, "GpuAllocatorConfig"); alloc_config.def(py::init<>()) @@ -955,7 +983,7 @@ PYBIND11_MODULE(xla_extension, m) { return devices; }) .def("host_id", &PyLocalClient::host_id) - .def("GetDefaultDeviceAssignment", + .def("get_default_device_assignment", [](std::shared_ptr client, int num_replicas, int num_partitions) -> StatusOr>>> { @@ -976,7 +1004,7 @@ PYBIND11_MODULE(xla_extension, m) { return result; }) // TODO(skye): delete after all callers can handle 2D output - .def("GetDefaultDeviceAssignment", + .def("get_default_device_assignment", [](std::shared_ptr client, int num_replicas) -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, @@ -991,15 +1019,15 @@ PYBIND11_MODULE(xla_extension, m) { } return result; }) - .def("CreateChannelHandle", + .def("create_channel_handle", [](PyLocalClient* client) { return client->client()->CreateChannelHandle(); }) - .def("CreateDeviceToHostChannelHandle", + .def("create_device_to_host_channel_handle", [](PyLocalClient* client) { return client->client()->CreateDeviceToHostChannelHandle(); }) - .def("CreateHostToDeviceChannelHandle", [](PyLocalClient* client) { + .def("create_host_to_device_channel_handle", [](PyLocalClient* client) { return client->client()->CreateHostToDeviceChannelHandle(); }); @@ -1119,7 +1147,7 @@ PYBIND11_MODULE(xla_extension, m) { py::class_> executable(m, "LocalExecutable"); executable - .def_static("Compile", + .def_static("compile", [](const XlaComputation& computation, absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, @@ -1146,7 +1174,7 @@ PYBIND11_MODULE(xla_extension, m) { return WrapWithClient(std::move(client), std::move(executable)); }) - .def_static("Compile", + .def_static("compile", [](const XlaComputation& computation, absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, @@ -1189,8 +1217,10 @@ PYBIND11_MODULE(xla_extension, m) { } return devices; }) - .def("SizeOfGeneratedCodeInBytes", + .def("size_of_generated_code_in_bytes", &PyLocalExecutable::SizeOfGeneratedCodeInBytes) + .def("delete", &PyLocalExecutable::Delete) + // TODO(phawkins): delete capitalized methods after updating callers. .def("Delete", &PyLocalExecutable::Delete) .def( "Execute", @@ -1212,6 +1242,27 @@ PYBIND11_MODULE(xla_extension, m) { return outputs; }, py::arg("arguments")) + .def( + "execute", + [](const PyLocalExecutable& executable, + absl::Span args) + -> StatusOr>> { + py::gil_scoped_release gil_release; + ExecuteOptions options; + options.untuple_result = true; + TF_ASSIGN_OR_RETURN( + std::vector> output_buffers, + executable.Execute(args, options)); + std::vector> outputs; + outputs.reserve(output_buffers.size()); + for (auto& buffer : output_buffers) { + outputs.push_back(WrapWithClient( + executable.client()->shared_from_this(), std::move(buffer))); + } + return outputs; + }, + py::arg("arguments")) + // TODO(phawkins): delete capitalized methods after updating callers. 
.def( "ExecuteOnLocalDevices", [](const PyLocalExecutable& executable, @@ -1239,7 +1290,33 @@ PYBIND11_MODULE(xla_extension, m) { }, py::arg("arguments")) .def( - "get_hlo_modules", + "execute_on_local_devices", + [](const PyLocalExecutable& executable, + absl::Span> args) + -> StatusOr< + std::vector>>> { + py::gil_scoped_release gil_release; + ExecuteOptions options; + options.untuple_result = true; + TF_ASSIGN_OR_RETURN( + std::vector>> + output_buffers, + executable.ExecuteOnLocalDevices(args, options)); + std::vector>> outputs; + outputs.resize(output_buffers.size()); + for (int computation = 0; computation < output_buffers.size(); + ++computation) { + for (auto& buffer : output_buffers[computation]) { + outputs[computation].push_back( + WrapWithClient(executable.client()->shared_from_this(), + std::move(buffer))); + } + } + return outputs; + }, + py::arg("arguments")) + .def( + "hlo_modules", [](const PyLocalExecutable& executable) -> StatusOr>> { std::vector> modules; @@ -1298,12 +1375,19 @@ PYBIND11_MODULE(xla_extension, m) { proto.ParseFromString(serialized_hlo_module_proto); return absl::make_unique(proto); })) + // TODO(phawkins): delete capitalized names after updating callers. .def("GetProgramShape", &XlaComputation::GetProgramShape) .def("GetSerializedProto", &GetComputationSerializedProto) .def("GetHloText", &GetComputationHloText) .def("GetHloDotGraph", &GetComputationHloDotGraph) .def("Hash", &HashComputation) - .def("get_hlo_module", &GetHloModule); + .def("get_hlo_module", &GetHloModule) + .def("program_shape", &XlaComputation::GetProgramShape) + .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto) + .def("as_hlo_text", &GetComputationHloText) + .def("as_hlo_dot_graph", &GetComputationHloDotGraph) + .def("hash", &HashComputation) + .def("as_hlo_module", &GetHloModule); py::class_ hlo_print_options_class(m, "HloPrintOptions"); hlo_print_options_class.def(py::init<>()) @@ -1381,6 +1465,7 @@ PYBIND11_MODULE(xla_extension, m) { .def(py::init([](const std::string& name) -> std::unique_ptr { return absl::make_unique(UniquifyName(name)); })) + // TODO(phawkins): delete capitalized names after updating callers. .def( "Build", [](XlaBuilder& builder, absl::optional root) { @@ -1403,6 +1488,35 @@ PYBIND11_MODULE(xla_extension, m) { .def("SetSharding", &XlaBuilder::SetSharding) .def("ClearSharding", &XlaBuilder::ClearSharding) .def("SetUpAlias", + [](XlaBuilder& builder, const std::vector& output_index, + int64 param_number, const std::vector& param_index) { + builder.SetUpAlias( + ShapeIndex(output_index.begin(), output_index.end()), + param_number, + ShapeIndex(param_index.begin(), param_index.end())); + }) + .def( + "build", + [](XlaBuilder& builder, absl::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }, + "Builds a computation from the contents of the builder.", + py::arg("root") = absl::nullopt) + .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) + .def("get_shape", &XlaBuilder::GetShape) + .def( + "get_program_shape", + [](const XlaBuilder& builder, + absl::optional root) -> StatusOr { + return root ? 
builder.GetProgramShape(*root) + : builder.GetProgramShape(); + }, + py::arg("root") = absl::nullopt) + .def("is_constant", &XlaBuilder::IsConstant) + .def("set_op_metadata", &XlaBuilder::SetOpMetadata) + .def("set_sharding", &XlaBuilder::SetSharding) + .def("clear_sharding", &XlaBuilder::ClearSharding) + .def("setup_alias", [](XlaBuilder& builder, const std::vector& output_index, int64 param_number, const std::vector& param_index) { builder.SetUpAlias( @@ -1411,7 +1525,9 @@ PYBIND11_MODULE(xla_extension, m) { ShapeIndex(param_index.begin(), param_index.end())); }); + // TODO(phawkins): delete capitalized names after updating callers m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); + m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor); m.def("DLPackManagedTensorToBuffer", [](const py::capsule& tensor, std::shared_ptr client) -> StatusOr> { @@ -1420,6 +1536,14 @@ PYBIND11_MODULE(xla_extension, m) { DLPackManagedTensorToBuffer(tensor, client.get())); return WrapWithClient(std::move(client), std::move(buffer)); }); + m.def("dlpack_managed_tensor_to_buffer", + [](const py::capsule& tensor, std::shared_ptr client) + -> StatusOr> { + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer, + DLPackManagedTensorToBuffer(tensor, client.get())); + return WrapWithClient(std::move(client), std::move(buffer)); + }); py::enum_(m, "PrecisionConfig_Precision") .value("DEFAULT", PrecisionConfig::DEFAULT) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 6834dd2108d..d5b2663af53 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -147,7 +147,7 @@ class LocalBackend(Backend): options.debug_options.xla_cpu_fast_math_honor_division = True options.debug_options.xla_cpu_fast_math_honor_functions = True options.debug_options.xla_gpu_enable_fast_min_max = False - return _xla.LocalExecutable.Compile(c_computation, + return _xla.LocalExecutable.compile(c_computation, compile_options.argument_layouts, options, self.client, compile_options.device_assignment, @@ -155,11 +155,11 @@ class LocalBackend(Backend): def get_default_device_assignment(self, num_replicas, num_partitions=None): if num_partitions is not None: - return self.client.GetDefaultDeviceAssignment(num_replicas, - num_partitions) + return self.client.get_default_device_assignment(num_replicas, + num_partitions) else: # TODO(skye): delete this case after all callers can handle 2D output - return self.client.GetDefaultDeviceAssignment(num_replicas) + return self.client.get_default_device_assignment(num_replicas) xla_platform_names = { @@ -445,7 +445,7 @@ def transfer_to_infeed(value, device=None): # TODO(phawkins): support non-default backends. backend = get_local_backend() device = device or backend.local_devices()[0] - device.TransferToInfeed(value) + device.transfer_to_infeed(value) def transfer_from_outfeed(shape, device=None): @@ -462,7 +462,7 @@ def transfer_from_outfeed(shape, device=None): # TODO(phawkins): support non-default backends. 
backend = get_local_backend() device = device or backend.local_devices()[0] - return device.TransferFromOutfeed( + return device.transfer_from_outfeed( shape.with_major_to_minor_layout_if_absent()) @@ -542,8 +542,7 @@ def execute_with_python_values(executable, arguments=(), backend=None): backend = backend or get_local_backend() def put(arg): - return Buffer.from_pyval( - arg, device=executable.local_devices()[0], backend=backend) + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) arguments = [put(arg) for arg in arguments] outputs = executable.Execute(arguments) @@ -629,7 +628,7 @@ def register_custom_call_target(name, fn, platform='cpu'): fn: a PyCapsule object containing the function pointer. platform: the target platform. """ - _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform]) + _xla.register_custom_call_target(name, fn, xla_platform_names[platform]) # Deprecated. Use register_custom_call_target instead. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 14f87c5ebe9..a0553c6a8e9 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -82,7 +82,7 @@ def TestFactory(xla_backend, cloud_tpu=False): return xla_client.XlaBuilder(name) def _Execute(self, c, arguments): - compiled_c = self.backend.compile(c.Build()) + compiled_c = self.backend.compile(c.build()) return xla_client.execute_with_python_values( compiled_c, arguments, backend=self.backend) @@ -136,34 +136,34 @@ def TestFactory(xla_backend, cloud_tpu=False): builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) x = ops.Mul(p0, p1) ops.Add(x, x) - return builder.Build() + return builder.build() def testComputationToHloText(self): computation = self.ExampleComputation() - hlo_text = computation.GetHloText() + hlo_text = computation.as_hlo_text() self.assertTrue(hlo_text.startswith("HloModule acomputation")) def testComputationToHloGraph(self): computation = self.ExampleComputation() - hlo_dot_graph = computation.GetHloDotGraph() + hlo_dot_graph = computation.as_hlo_dot_graph() self.assertTrue(hlo_dot_graph.startswith("digraph ")) def testHloModuleToHloText(self): computation = self.ExampleComputation() - hlo_text = computation.get_hlo_module().to_string() + hlo_text = computation.as_hlo_module().to_string() self.assertTrue(hlo_text.startswith("HloModule acomputation")) def testHloModuleToHloGraph(self): computation = self.ExampleComputation() hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( - computation.get_hlo_module()) + computation.as_hlo_module()) self.assertTrue(hlo_dot_graph.startswith("digraph ")) @unittest.skipIf(cloud_tpu, "not implemented") def testCompiledHloModuleToHloText(self): computation = self.ExampleComputation() executable = self.backend.compile(computation) - hlo_modules = executable.get_hlo_modules() + hlo_modules = executable.hlo_modules() self.assertLen(hlo_modules, 1) hlo_text = hlo_modules[0].to_string() self.assertTrue(hlo_text.startswith("HloModule acomputation")) @@ -180,7 +180,7 @@ def TestFactory(xla_backend, cloud_tpu=False): p1 = ops.Parameter( builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) ops.Mul(p0, p1) - computation0 = builder0.Build() + computation0 = builder0.build() builder1 = xla_client.XlaBuilder("computation1") p0 = ops.Parameter(builder1, 0, @@ -188,9 +188,9 @@ def TestFactory(xla_backend, cloud_tpu=False): p1 = ops.Parameter( builder1, 1, 
xla_client.shape_from_pyval(np.zeros((4,), np.float32))) ops.Mul(p0, p1) - computation1 = builder1.Build() + computation1 = builder1.build() - self.assertEqual(computation0.Hash(), computation1.Hash()) + self.assertEqual(computation0.hash(), computation1.hash()) tests.append(ComputationHashTest) @@ -396,7 +396,7 @@ def TestFactory(xla_backend, cloud_tpu=False): # Build the HLO proto b = xla_client.XlaBuilder("computation") ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) - serialized_proto = b.Build().GetSerializedProto() + serialized_proto = b.build().as_serialized_hlo_module_proto() # Load and execute the proto c = xla_client.XlaComputation(serialized_proto) @@ -478,22 +478,22 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), ops.Constant(c, np.float32(3.14))) arg = NumpyArrayF32(1.11) - compiled_c = self.backend.compile(c.Build()) - arg_buffer = xla_client.Buffer.from_pyval(arg, backend=self.backend) + compiled_c = self.backend.compile(c.build()) + arg_buffer = self.backend.buffer_from_pyval(arg) arg_buffer.delete() with self.assertRaises(RuntimeError): - compiled_c.Execute([arg_buffer]) + compiled_c.execute([arg_buffer]) def testShape(self): pyval = np.array([[1., 2.]], np.float32) - local_buffer = xla_client.Buffer.from_pyval(pyval) + local_buffer = self.backend.buffer_from_pyval(pyval) xla_shape = local_buffer.shape() self.assertEqual(xla_shape.dimensions(), (1, 2)) self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) def testBlockHostUntilReadyWorks(self): arg = np.array([[1., 2.]], np.float32) - arg_buffer = xla_client.Buffer.from_pyval(arg) + arg_buffer = self.backend.buffer_from_pyval(arg) arg_buffer.block_host_until_ready() # This test merely checks that nothing goes awry when we call # block_host_until_ready(); it's difficult to test anything else. @@ -501,8 +501,8 @@ def TestFactory(xla_backend, cloud_tpu=False): def testCopyToHost(self): arg0 = np.array([[1., 2.]], np.float32) arg1 = np.array([[3., 4.]], np.float32) - arg0_buffer = xla_client.Buffer.from_pyval(arg0) - arg1_buffer = xla_client.Buffer.from_pyval(arg1) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) # Prefetch two buffers using copy_to_host_async, and then retrieve their # values using to_py. 
arg0_buffer.copy_to_host_async() @@ -517,8 +517,7 @@ def TestFactory(xla_backend, cloud_tpu=False): def testDevice(self): x = np.arange(8, dtype=np.int32) for device in self.backend.local_devices(): - buf = xla_client.Buffer.from_pyval( - x, device=device, backend=self.backend) + buf = self.backend.buffer_from_pyval(x, device=device) self.assertEqual(buf.device(), device) np.testing.assert_equal(x, buf.to_py()) @@ -564,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) self.assertLen(result, 1) expected = np.array(x, dtype=dst_dtype) @@ -591,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) self.assertLen(result, 1) expected = x.view(dst_dtype) @@ -1127,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Constant(c, NumpyArrayBool([True, False, False, True])) ]) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) self.assertLen(result, 3) np.testing.assert_equal(result[0], 42) np.testing.assert_allclose(result[1], [1.0, 2.0]) @@ -1166,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) # since the result is random, we just check shape and uniqueness self.assertLen(result, 1) self.assertEqual(result[0].shape, shape) @@ -1182,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) # since the result is random, we just check shape, uniqueness, and range self.assertLen(result, 1) self.assertEqual(result[0].shape, shape) @@ -1200,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, shape)) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) # since the result is random, we just check shape, integrality, and range self.assertLen(result, 1) self.assertEqual(result[0].shape, shape) @@ -1229,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False): c = self._NewComputation() ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build()), backend=self.backend) + self.backend.compile(c.build()), backend=self.backend) self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) @@ -1241,7 +1240,7 @@ def TestFactory(xla_backend, cloud_tpu=False): p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) 
ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1))) - comparator = b.Build() + comparator = b.build() keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) @@ -1251,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False): dimension=1, comparator=comparator) result = xla_client.execute_with_python_values( - self.backend.compile(c.Build())) + self.backend.compile(c.build())) self.assertLen(result, 2) np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) @@ -1321,8 +1320,8 @@ def TestFactory(xla_backend, cloud_tpu=False): x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) const_expr = ops.Sub(b, a) non_const_expr = ops.Mul(const_expr, x) - self.assertTrue(c.IsConstant(const_expr)) - self.assertFalse(c.IsConstant(non_const_expr)) + self.assertTrue(c.is_constant(const_expr)) + self.assertFalse(c.is_constant(non_const_expr)) def testGather(self): a = np.arange(9).astype(np.int32).reshape((3, 3)) @@ -1412,7 +1411,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Parameter(c, 0, xla_client.shape_from_pyval(np.array(0, dtype=in_dtype))) ops.Constant(c, out_dtype(1)) - return c.Build() + return c.build() def _CreateMulBy2Computation(self, dtype): """Computation (dtype) -> dtype that multiplies its parameter by 2.""" @@ -1423,7 +1422,7 @@ def TestFactory(xla_backend, cloud_tpu=False): xla_client.shape_from_pyval(np.array( 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), ops.Constant(c, dtype(2.0))) - return c.Build() + return c.build() def _CreateMulF32ByParamComputation(self): """Computation (f32) -> f32 that multiplies one parameter by the other.""" @@ -1431,7 +1430,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Mul( ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) - return c.Build() + return c.build() def _CreateBinaryAddComputation(self, dtype): """Computation (dtype, dtype) -> dtype that adds its two parameters.""" @@ -1439,7 +1438,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) shape = shape.with_major_to_minor_layout_if_absent() ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) - return c.Build() + return c.build() def _CreateBinaryGeComputation(self, dtype): """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" @@ -1447,7 +1446,7 @@ def TestFactory(xla_backend, cloud_tpu=False): shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) shape = shape.with_major_to_minor_layout_if_absent() ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) - return c.Build() + return c.build() def _MakeSample3DArray(self, dtype): return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], @@ -1516,7 +1515,7 @@ def TestFactory(xla_backend, cloud_tpu=False): c = self._NewComputation("div_param0_by_param1") shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) - return c.Build() + return c.build() c = self._NewComputation() ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), @@ -1539,7 +1538,7 @@ def TestFactory(xla_backend, cloud_tpu=False): window_strides = (1, 2) padding = xla_client.window_padding_type_to_pad_values( xla_client.PaddingType.VALID, - c.GetShape(operand).dimensions(), window_dimensions, 
window_strides) + c.get_shape(operand).dimensions(), window_dimensions, window_strides) ops.SelectAndScatterWithGeneralPadding( operand, select=self._CreateBinaryGeComputation(dtype), @@ -1686,7 +1685,7 @@ def TestFactory(xla_backend, cloud_tpu=False): c = self._NewComputation("test_lt_10") shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) - return c.Build() + return c.build() cond = LessThan10Cond() body = self._CreateMulBy2Computation(dtype) @@ -1728,10 +1727,10 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.CreateToken(c), xla_client.shape_from_pyval( to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) - compiled_c = self.backend.compile(c.Build()) + compiled_c = self.backend.compile(c.build()) device = self.backend.local_devices()[0] for item in to_infeed: - xla_client.transfer_to_infeed(item, device=device) + device.transfer_to_infeed(item) for item in to_infeed: result, = xla_client.execute_with_python_values( @@ -1747,9 +1746,9 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.CreateToken(c), xla_client.shape_from_pyval( to_infeed).with_major_to_minor_layout_if_absent()), 0) - compiled_c = self.backend.compile(c.Build()) + compiled_c = self.backend.compile(c.build()) device = self.backend.local_devices()[0] - xla_client.transfer_to_infeed(to_infeed, device=device) + device.transfer_to_infeed(to_infeed) result = xla_client.execute_with_python_values( compiled_c, backend=self.backend) @@ -1771,14 +1770,14 @@ def TestFactory(xla_backend, cloud_tpu=False): to_round_trip[0]).with_major_to_minor_layout_if_absent() ops.OutfeedWithToken(x, token, outfeed_shape) - compiled_c = self.backend.compile(c.Build()) + compiled_c = self.backend.compile(c.build()) device = self.backend.local_devices()[0] for want in to_round_trip: - execution = threading.Thread(target=lambda: compiled_c.Execute([])) + execution = threading.Thread(target=lambda: compiled_c.execute([])) execution.start() - xla_client.transfer_to_infeed(want, device=device) - got = xla_client.transfer_from_outfeed(outfeed_shape, device=device) + device.transfer_to_infeed(want) + got = device.transfer_from_outfeed(outfeed_shape) execution.join() self.assertEqual(want, got) @@ -1811,9 +1810,9 @@ def TestFactory(xla_backend, cloud_tpu=False): def testCompileWithWrongElementTypeInLayout(self): c = self._NewComputation() - c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) - c.ClearOpMetadata() + c.clear_op_metadata() options = xla_client.CompileOptions() options.argument_layouts = [ @@ -1821,7 +1820,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ] def TestFun(): - return self.backend.compile(c.Build(), compile_options=options) + return self.backend.compile(c.build(), compile_options=options) self.assertRaisesRegex( RuntimeError, r".*Invalid argument shape.*" @@ -1829,13 +1828,13 @@ def TestFactory(xla_backend, cloud_tpu=False): def testInvokeWithWrongElementType(self): c = self._NewComputation() - c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) - c.ClearOpMetadata() + c.clear_op_metadata() def TestFun(): return xla_client.execute_with_python_values( - self.backend.compile(c.Build()), [self.f32_scalar_2]) + self.backend.compile(c.build()), [self.f32_scalar_2]) 
self.assertRaisesRegex( RuntimeError, r"Invalid argument: Argument does not match.*" @@ -1853,7 +1852,7 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Add(result, ops.Constant(c, np.float32(1.618))) arg = NumpyArrayF32(1.0) - compiled_c = self.backend.compile(c.Build(result)) + compiled_c = self.backend.compile(c.build(result)) ans, = xla_client.execute_with_python_values( compiled_c, [arg], backend=self.backend) np.testing.assert_allclose(ans, 4.14) @@ -1869,16 +1868,14 @@ def TestFactory(xla_backend, cloud_tpu=False): sharding.type = sharding.type.REPLICATED sharding.tile_assignment_dimensions.extend([1]) sharding.tile_assignment_devices.extend([0]) - # Set Sharding. - c.SetSharding(sharding) + c.set_sharding(sharding) x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) - # Clear Sharding. - c.ClearSharding() + c.clear_sharding() result = ops.Add(x, ops.Constant(c, np.float32(3.14))) ops.Add(result, ops.Constant(c, np.float32(1.618))) arg = NumpyArrayF32(1.0) - compiled_c = self.backend.compile(c.Build(result)) + compiled_c = self.backend.compile(c.build(result)) ans, = xla_client.execute_with_python_values( compiled_c, [arg], backend=self.backend) np.testing.assert_allclose(ans, 4.14) @@ -1898,8 +1895,8 @@ def TestFactory(xla_backend, cloud_tpu=False): xla_client.shape_from_pyval( NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) out = ops.Add(p1, p2) - c.SetUpAlias([], 0, []) - c = c.Build(out) + c.setup_alias([], 0, []) + c = c.build(out) if self.backend.platform != "tpu": with self.assertRaisesRegex( RuntimeError, "Buffer aliasing is not supported " @@ -1940,21 +1937,22 @@ def TestFactory(xla_backend, cloud_tpu=False): } for dtype in dlpack_dtypes for shape in testcase_shapes) def testRoundTrip(self, dtype, shape): x = np.array(np.random.rand(*shape) * 100, dtype=dtype) - buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) - dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) del buffer # Free "buffer" to make sure dlt retains ownership. 
self.assertEqual(type(dlt).__name__, "PyCapsule") - y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, self.backend.client) + y = xla_client._xla.dlpack_managed_tensor_to_buffer( + dlt, self.backend.client) np.testing.assert_array_equal(x, y.to_py()) def testTensorsCanBeConsumedOnceOnly(self): x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) - buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) - dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) def ConsumeDLPackTensor(): - _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, - self.backend.client) + _ = xla_client._xla.dlpack_managed_tensor_to_buffer( + dlt, self.backend.client) ConsumeDLPackTensor() self.assertRaisesRegex( @@ -1981,7 +1979,7 @@ def TestFactory(xla_backend, cloud_tpu=False): def testRoundTrip(self, dtype, shape): x = np.array(np.random.rand(*shape) * 100, dtype=dtype) x_ptr = x.__array_interface__["data"][0] - buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) + buffer = self.backend.buffer_from_pyval(x) y = np.array(buffer, copy=False) y_ptr = y.__array_interface__["data"][0] np.testing.assert_array_equal(x, y) @@ -1990,15 +1988,14 @@ def TestFactory(xla_backend, cloud_tpu=False): self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr) self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) - buffer2 = xla_client.Buffer.from_pyval( - x, backend=self.backend, force_copy=True) + buffer2 = self.backend.buffer_from_pyval(x, force_copy=True) z = np.array(buffer2, copy=False) self.assertNotEqual(x.__array_interface__["data"][0], z.__array_interface__["data"][0]) def testDeleteWithActiveView(self): x = np.random.randn(20, 10) - buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) + buffer = self.backend.buffer_from_pyval(x) buffer_ptr = buffer.unsafe_buffer_pointer() y = np.array(buffer, copy=False) buffer.delete()
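
Note (not part of the patch): the snippet below is a minimal caller-side migration sketch, assuming an xla_extension/xla_client built from this change. It uses only spellings introduced above (builder.build, XlaComputation.as_hlo_text, Backend.compile, Backend.buffer_from_pyval, LocalExecutable.execute); the "ops" alias and get_local_backend helper are taken from the surrounding client/test code and are otherwise assumed.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops  # op-building module, as used by xla_client_test (assumed alias)

backend = xla_client.get_local_backend()
builder = xla_client.XlaBuilder("add_one")
p = ops.Parameter(builder, 0,
                  xla_client.shape_from_pyval(np.array(0, dtype=np.float32)))
ops.Add(p, ops.Constant(builder, np.float32(1.0)))

computation = builder.build()            # was: builder.Build()
print(computation.as_hlo_text())         # was: computation.GetHloText()

executable = backend.compile(computation)
arg = backend.buffer_from_pyval(np.array(41.0, dtype=np.float32))  # was: xla_client.Buffer.from_pyval(...)
out, = executable.execute([arg])         # was: executable.Execute([arg])
print(out.to_py())

The same pattern applies to the other renames in this change (device.transfer_to_infeed / transfer_from_outfeed, _xla.register_custom_call_target, client.get_default_device_assignment, _xla.buffer_to_dlpack_managed_tensor / dlpack_managed_tensor_to_buffer); the old CamelCase spellings keep working until the TODO(phawkins)/TODO(skye) cleanups above land, so callers can migrate incrementally.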