[XLA:Python] API cleanups.
Add or change method names on most classes to be lower_case instead of CamelCase to match the usual Python style. Renamed a few methods to read better. The old names aren't yet removed because callers (e.g., JAX) need to be updated. PiperOrigin-RevId: 308677107 Change-Id: I46993f093676402853e6032a822b308f59677e6f
This commit is contained in:
		
							parent
							
								
									71964116c5
								
							
						
					
					
						commit
						40d89f69e1
					
				| @ -98,7 +98,7 @@ class TpuBackend(xla_client.Backend): | ||||
|     options.debug_options.xla_cpu_fast_math_honor_division = True | ||||
|     options.debug_options.xla_cpu_fast_math_honor_functions = True | ||||
|     options.debug_options.xla_gpu_enable_fast_min_max = False | ||||
|     return _tpu_client.TpuExecutable.Compile(c_computation, | ||||
|     return _tpu_client.TpuExecutable.compile(c_computation, | ||||
|                                              compile_options.argument_layouts, | ||||
|                                              options, self.client, | ||||
|                                              compile_options.device_assignment, | ||||
| @ -106,14 +106,8 @@ class TpuBackend(xla_client.Backend): | ||||
| 
 | ||||
|   def get_default_device_assignment(self, num_replicas, num_partitions=None): | ||||
|     if num_partitions is not None: | ||||
|       return self.client.GetDefaultDeviceAssignment(num_replicas, | ||||
|                                                     num_partitions) | ||||
|       return self.client.get_default_device_assignment(num_replicas, | ||||
|                                                        num_partitions) | ||||
|     else: | ||||
|       # TODO(henrytan): delete this case after all callers can handle 2D output | ||||
|       return self.client.GetDefaultDeviceAssignment(num_replicas) | ||||
| 
 | ||||
|   def serialize(self, executable): | ||||
|     return self.client.SerializeExecutable(executable) | ||||
| 
 | ||||
|   def deserialize(self, serialized_executable): | ||||
|     return self.client.DeserializeExecutable(serialized_executable, self.client) | ||||
|       return self.client.get_default_device_assignment(num_replicas) | ||||
|  | ||||
| @ -37,7 +37,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { | ||||
|       .def("devices", &PyTpuClient::devices) | ||||
|       .def("local_devices", &PyTpuClient::local_devices) | ||||
|       .def("host_id", &PyTpuClient::host_id) | ||||
|       .def("GetDefaultDeviceAssignment", | ||||
|       .def("get_default_device_assignment", | ||||
|            [](PyTpuClient* client, int num_replicas, int num_partitions) | ||||
|                -> StatusOr<std::vector<std::vector<std::shared_ptr<Device>>>> { | ||||
|              TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, | ||||
| @ -57,7 +57,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { | ||||
|              return result; | ||||
|            }) | ||||
|       // TODO(skye): delete after all callers can handle 2D output
 | ||||
|       .def("GetDefaultDeviceAssignment", | ||||
|       .def("get_default_device_assignment", | ||||
|            [](PyTpuClient* client, int num_replicas) | ||||
|                -> StatusOr<std::vector<std::shared_ptr<Device>>> { | ||||
|              TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, | ||||
| @ -72,14 +72,14 @@ PYBIND11_MODULE(tpu_client_extension, m) { | ||||
|              } | ||||
|              return result; | ||||
|            }) | ||||
|       .def("TransferToInfeed", | ||||
|       .def("transfer_to_infeed", | ||||
|            [](PyTpuClient* client, const LiteralSlice& literal, | ||||
|               int device_ordinal) { | ||||
|              GlobalPyRefManager()->CollectGarbage(); | ||||
|              py::gil_scoped_release gil_release; | ||||
|              return client->TransferToInfeed(literal, device_ordinal); | ||||
|            }) | ||||
|       .def("TransferFromOutfeed", | ||||
|       .def("transfer_from_outfeed", | ||||
|            [](PyTpuClient* client, const Shape& shape, | ||||
|               int device_ordinal) -> StatusOr<py::object> { | ||||
|              GlobalPyRefManager()->CollectGarbage(); | ||||
| @ -159,9 +159,9 @@ PYBIND11_MODULE(tpu_client_extension, m) { | ||||
|       }); | ||||
| 
 | ||||
|   py::class_<PyTpuExecutable>(m, "TpuExecutable") | ||||
|       .def_static("Compile", &PyTpuExecutable::Compile, | ||||
|       .def_static("compile", &PyTpuExecutable::Compile, | ||||
|                   py::call_guard<py::gil_scoped_release>()) | ||||
|       .def_static("Compile", | ||||
|       .def_static("compile", | ||||
|                   [](const XlaComputation& computation, | ||||
|                      absl::optional<std::vector<Shape>> argument_layouts, | ||||
|                      const ExecutableBuildOptions* build_options, | ||||
| @ -184,12 +184,17 @@ PYBIND11_MODULE(tpu_client_extension, m) { | ||||
|       .def("local_logical_device_ids", | ||||
|            &PyTpuExecutable::local_logical_device_ids) | ||||
|       .def("local_devices", &PyTpuExecutable::local_devices) | ||||
|       .def("SizeOfGeneratedCodeInBytes", | ||||
|       .def("size_of_generated_code_in_bytes", | ||||
|            &PyTpuExecutable::SizeOfGeneratedCodeInBytes) | ||||
|       .def("Delete", &PyTpuExecutable::Delete) | ||||
|       .def("Execute", &PyTpuExecutable::Execute, | ||||
|            py::call_guard<py::gil_scoped_release>(), py::arg("arguments")) | ||||
|       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices, | ||||
|            py::call_guard<py::gil_scoped_release>(), py::arg("arguments")) | ||||
|       .def("delete", &PyTpuExecutable::Delete) | ||||
|       .def("execute", &PyTpuExecutable::Execute, | ||||
|            py::call_guard<py::gil_scoped_release>(), py::arg("arguments")) | ||||
|       .def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices, | ||||
|            py::call_guard<py::gil_scoped_release>(), py::arg("arguments")); | ||||
| 
 | ||||
|   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice") | ||||
|  | ||||
| @ -676,7 +676,7 @@ void BuildProfilerSubmodule(py::module* m) { | ||||
|   traceme_class.def(py::init<py::str, py::kwargs>()) | ||||
|       .def("__enter__", &TraceMeContextManager::Enter) | ||||
|       .def("__exit__", &TraceMeContextManager::Exit) | ||||
|       .def_static("IsEnabled", &TraceMeContextManager::IsEnabled); | ||||
|       .def_static("is_enabled", &TraceMeContextManager::IsEnabled); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| @ -880,6 +880,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|       .def_property_readonly("platform", &Device::platform_name) | ||||
|       .def_property_readonly("device_kind", &Device::device_kind) | ||||
|       .def("__str__", &Device::DebugString) | ||||
|       // TODO(phawkins): remove capitalized names after updating callers.
 | ||||
|       .def("TransferToInfeed", | ||||
|            [](const Device& device, const LiteralSlice& literal) { | ||||
|              GlobalPyRefManager()->CollectGarbage(); | ||||
| @ -891,6 +892,33 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|            }) | ||||
|       .def( | ||||
|           "TransferFromOutfeed", | ||||
|           [](const Device& device, const Shape& shape) -> StatusOr<py::object> { | ||||
|             GlobalPyRefManager()->CollectGarbage(); | ||||
|             std::shared_ptr<Literal> literal_shared; | ||||
|             { | ||||
|               py::gil_scoped_release gil_release; | ||||
|               TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, | ||||
|                                   device.GetLocalDeviceState()); | ||||
|               TF_ASSIGN_OR_RETURN( | ||||
|                   Literal literal, | ||||
|                   local_device->client()->TransferFromOutfeedLocal( | ||||
|                       shape, local_device->device_ordinal())); | ||||
| 
 | ||||
|               literal_shared = std::make_shared<Literal>(std::move(literal)); | ||||
|             } | ||||
|             return LiteralToPython(std::move(literal_shared)); | ||||
|           }) | ||||
|       .def("transfer_to_infeed", | ||||
|            [](const Device& device, const LiteralSlice& literal) { | ||||
|              GlobalPyRefManager()->CollectGarbage(); | ||||
|              py::gil_scoped_release gil_release; | ||||
|              TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, | ||||
|                                  device.GetLocalDeviceState()); | ||||
|              return local_device->client()->TransferToInfeedLocal( | ||||
|                  literal, local_device->device_ordinal()); | ||||
|            }) | ||||
|       .def( | ||||
|           "transfer_from_outfeed", | ||||
|           [](const Device& device, const Shape& shape) -> StatusOr<py::object> { | ||||
|             GlobalPyRefManager()->CollectGarbage(); | ||||
|             std::shared_ptr<Literal> literal_shared; | ||||
| @ -921,7 +949,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|   // Local XLA client methods.
 | ||||
| 
 | ||||
|   // Custom-call targets.
 | ||||
|   m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget); | ||||
|   m.def("register_custom_call_target", &PyRegisterCustomCallTarget); | ||||
| 
 | ||||
|   py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig"); | ||||
|   alloc_config.def(py::init<>()) | ||||
| @ -955,7 +983,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|              return devices; | ||||
|            }) | ||||
|       .def("host_id", &PyLocalClient::host_id) | ||||
|       .def("GetDefaultDeviceAssignment", | ||||
|       .def("get_default_device_assignment", | ||||
|            [](std::shared_ptr<PyLocalClient> client, int num_replicas, | ||||
|               int num_partitions) | ||||
|                -> StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>> { | ||||
| @ -976,7 +1004,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|              return result; | ||||
|            }) | ||||
|       // TODO(skye): delete after all callers can handle 2D output
 | ||||
|       .def("GetDefaultDeviceAssignment", | ||||
|       .def("get_default_device_assignment", | ||||
|            [](std::shared_ptr<PyLocalClient> client, | ||||
|               int num_replicas) -> StatusOr<std::vector<ClientAndPtr<Device>>> { | ||||
|              TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, | ||||
| @ -991,15 +1019,15 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|              } | ||||
|              return result; | ||||
|            }) | ||||
|       .def("CreateChannelHandle", | ||||
|       .def("create_channel_handle", | ||||
|            [](PyLocalClient* client) { | ||||
|              return client->client()->CreateChannelHandle(); | ||||
|            }) | ||||
|       .def("CreateDeviceToHostChannelHandle", | ||||
|       .def("create_device_to_host_channel_handle", | ||||
|            [](PyLocalClient* client) { | ||||
|              return client->client()->CreateDeviceToHostChannelHandle(); | ||||
|            }) | ||||
|       .def("CreateHostToDeviceChannelHandle", [](PyLocalClient* client) { | ||||
|       .def("create_host_to_device_channel_handle", [](PyLocalClient* client) { | ||||
|         return client->client()->CreateHostToDeviceChannelHandle(); | ||||
|       }); | ||||
| 
 | ||||
| @ -1119,7 +1147,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|   py::class_<PyLocalExecutable, ClientAndUniquePtr<PyLocalExecutable>> | ||||
|       executable(m, "LocalExecutable"); | ||||
|   executable | ||||
|       .def_static("Compile", | ||||
|       .def_static("compile", | ||||
|                   [](const XlaComputation& computation, | ||||
|                      absl::optional<std::vector<Shape>> argument_layouts, | ||||
|                      const ExecutableBuildOptions* build_options, | ||||
| @ -1146,7 +1174,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|                     return WrapWithClient(std::move(client), | ||||
|                                           std::move(executable)); | ||||
|                   }) | ||||
|       .def_static("Compile", | ||||
|       .def_static("compile", | ||||
|                   [](const XlaComputation& computation, | ||||
|                      absl::optional<std::vector<Shape>> argument_layouts, | ||||
|                      const ExecutableBuildOptions* build_options, | ||||
| @ -1189,8 +1217,10 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|              } | ||||
|              return devices; | ||||
|            }) | ||||
|       .def("SizeOfGeneratedCodeInBytes", | ||||
|       .def("size_of_generated_code_in_bytes", | ||||
|            &PyLocalExecutable::SizeOfGeneratedCodeInBytes) | ||||
|       .def("delete", &PyLocalExecutable::Delete) | ||||
|       // TODO(phawkins): delete capitalized methods after updating callers.
 | ||||
|       .def("Delete", &PyLocalExecutable::Delete) | ||||
|       .def( | ||||
|           "Execute", | ||||
| @ -1212,6 +1242,27 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|             return outputs; | ||||
|           }, | ||||
|           py::arg("arguments")) | ||||
|       .def( | ||||
|           "execute", | ||||
|           [](const PyLocalExecutable& executable, | ||||
|              absl::Span<PyLocalBuffer* const> args) | ||||
|               -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> { | ||||
|             py::gil_scoped_release gil_release; | ||||
|             ExecuteOptions options; | ||||
|             options.untuple_result = true; | ||||
|             TF_ASSIGN_OR_RETURN( | ||||
|                 std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers, | ||||
|                 executable.Execute(args, options)); | ||||
|             std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs; | ||||
|             outputs.reserve(output_buffers.size()); | ||||
|             for (auto& buffer : output_buffers) { | ||||
|               outputs.push_back(WrapWithClient( | ||||
|                   executable.client()->shared_from_this(), std::move(buffer))); | ||||
|             } | ||||
|             return outputs; | ||||
|           }, | ||||
|           py::arg("arguments")) | ||||
|       // TODO(phawkins): delete capitalized methods after updating callers.
 | ||||
|       .def( | ||||
|           "ExecuteOnLocalDevices", | ||||
|           [](const PyLocalExecutable& executable, | ||||
| @ -1239,7 +1290,33 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|           }, | ||||
|           py::arg("arguments")) | ||||
|       .def( | ||||
|           "get_hlo_modules", | ||||
|           "execute_on_local_devices", | ||||
|           [](const PyLocalExecutable& executable, | ||||
|              absl::Span<const std::vector<PyLocalBuffer*>> args) | ||||
|               -> StatusOr< | ||||
|                   std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> { | ||||
|             py::gil_scoped_release gil_release; | ||||
|             ExecuteOptions options; | ||||
|             options.untuple_result = true; | ||||
|             TF_ASSIGN_OR_RETURN( | ||||
|                 std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>> | ||||
|                     output_buffers, | ||||
|                 executable.ExecuteOnLocalDevices(args, options)); | ||||
|             std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> outputs; | ||||
|             outputs.resize(output_buffers.size()); | ||||
|             for (int computation = 0; computation < output_buffers.size(); | ||||
|                  ++computation) { | ||||
|               for (auto& buffer : output_buffers[computation]) { | ||||
|                 outputs[computation].push_back( | ||||
|                     WrapWithClient(executable.client()->shared_from_this(), | ||||
|                                    std::move(buffer))); | ||||
|               } | ||||
|             } | ||||
|             return outputs; | ||||
|           }, | ||||
|           py::arg("arguments")) | ||||
|       .def( | ||||
|           "hlo_modules", | ||||
|           [](const PyLocalExecutable& executable) | ||||
|               -> StatusOr<std::vector<std::shared_ptr<HloModule>>> { | ||||
|             std::vector<std::shared_ptr<HloModule>> modules; | ||||
| @ -1298,12 +1375,19 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|         proto.ParseFromString(serialized_hlo_module_proto); | ||||
|         return absl::make_unique<XlaComputation>(proto); | ||||
|       })) | ||||
|       // TODO(phawkins): delete capitalized names after updating callers.
 | ||||
|       .def("GetProgramShape", &XlaComputation::GetProgramShape) | ||||
|       .def("GetSerializedProto", &GetComputationSerializedProto) | ||||
|       .def("GetHloText", &GetComputationHloText) | ||||
|       .def("GetHloDotGraph", &GetComputationHloDotGraph) | ||||
|       .def("Hash", &HashComputation) | ||||
|       .def("get_hlo_module", &GetHloModule); | ||||
|       .def("get_hlo_module", &GetHloModule) | ||||
|       .def("program_shape", &XlaComputation::GetProgramShape) | ||||
|       .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto) | ||||
|       .def("as_hlo_text", &GetComputationHloText) | ||||
|       .def("as_hlo_dot_graph", &GetComputationHloDotGraph) | ||||
|       .def("hash", &HashComputation) | ||||
|       .def("as_hlo_module", &GetHloModule); | ||||
| 
 | ||||
|   py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions"); | ||||
|   hlo_print_options_class.def(py::init<>()) | ||||
| @ -1381,6 +1465,7 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|       .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> { | ||||
|         return absl::make_unique<XlaBuilder>(UniquifyName(name)); | ||||
|       })) | ||||
|       // TODO(phawkins): delete capitalized names after updating callers.
 | ||||
|       .def( | ||||
|           "Build", | ||||
|           [](XlaBuilder& builder, absl::optional<XlaOp> root) { | ||||
| @ -1403,6 +1488,35 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|       .def("SetSharding", &XlaBuilder::SetSharding) | ||||
|       .def("ClearSharding", &XlaBuilder::ClearSharding) | ||||
|       .def("SetUpAlias", | ||||
|            [](XlaBuilder& builder, const std::vector<int64>& output_index, | ||||
|               int64 param_number, const std::vector<int64>& param_index) { | ||||
|              builder.SetUpAlias( | ||||
|                  ShapeIndex(output_index.begin(), output_index.end()), | ||||
|                  param_number, | ||||
|                  ShapeIndex(param_index.begin(), param_index.end())); | ||||
|            }) | ||||
|       .def( | ||||
|           "build", | ||||
|           [](XlaBuilder& builder, absl::optional<XlaOp> root) { | ||||
|             return root ? builder.Build(*root) : builder.Build(); | ||||
|           }, | ||||
|           "Builds a computation from the contents of the builder.", | ||||
|           py::arg("root") = absl::nullopt) | ||||
|       .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) | ||||
|       .def("get_shape", &XlaBuilder::GetShape) | ||||
|       .def( | ||||
|           "get_program_shape", | ||||
|           [](const XlaBuilder& builder, | ||||
|              absl::optional<XlaOp> root) -> StatusOr<ProgramShape> { | ||||
|             return root ? builder.GetProgramShape(*root) | ||||
|                         : builder.GetProgramShape(); | ||||
|           }, | ||||
|           py::arg("root") = absl::nullopt) | ||||
|       .def("is_constant", &XlaBuilder::IsConstant) | ||||
|       .def("set_op_metadata", &XlaBuilder::SetOpMetadata) | ||||
|       .def("set_sharding", &XlaBuilder::SetSharding) | ||||
|       .def("clear_sharding", &XlaBuilder::ClearSharding) | ||||
|       .def("setup_alias", | ||||
|            [](XlaBuilder& builder, const std::vector<int64>& output_index, | ||||
|               int64 param_number, const std::vector<int64>& param_index) { | ||||
|              builder.SetUpAlias( | ||||
| @ -1411,7 +1525,9 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|                  ShapeIndex(param_index.begin(), param_index.end())); | ||||
|            }); | ||||
| 
 | ||||
|   // TODO(phawkins): delete capitalized names after updating callers
 | ||||
|   m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); | ||||
|   m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor); | ||||
|   m.def("DLPackManagedTensorToBuffer", | ||||
|         [](const py::capsule& tensor, std::shared_ptr<PyLocalClient> client) | ||||
|             -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> { | ||||
| @ -1420,6 +1536,14 @@ PYBIND11_MODULE(xla_extension, m) { | ||||
|               DLPackManagedTensorToBuffer(tensor, client.get())); | ||||
|           return WrapWithClient(std::move(client), std::move(buffer)); | ||||
|         }); | ||||
|   m.def("dlpack_managed_tensor_to_buffer", | ||||
|         [](const py::capsule& tensor, std::shared_ptr<PyLocalClient> client) | ||||
|             -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> { | ||||
|           TF_ASSIGN_OR_RETURN( | ||||
|               std::unique_ptr<PyLocalBuffer> buffer, | ||||
|               DLPackManagedTensorToBuffer(tensor, client.get())); | ||||
|           return WrapWithClient(std::move(client), std::move(buffer)); | ||||
|         }); | ||||
| 
 | ||||
|   py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision") | ||||
|       .value("DEFAULT", PrecisionConfig::DEFAULT) | ||||
|  | ||||
| @ -147,7 +147,7 @@ class LocalBackend(Backend): | ||||
|     options.debug_options.xla_cpu_fast_math_honor_division = True | ||||
|     options.debug_options.xla_cpu_fast_math_honor_functions = True | ||||
|     options.debug_options.xla_gpu_enable_fast_min_max = False | ||||
|     return _xla.LocalExecutable.Compile(c_computation, | ||||
|     return _xla.LocalExecutable.compile(c_computation, | ||||
|                                         compile_options.argument_layouts, | ||||
|                                         options, self.client, | ||||
|                                         compile_options.device_assignment, | ||||
| @ -155,11 +155,11 @@ class LocalBackend(Backend): | ||||
| 
 | ||||
|   def get_default_device_assignment(self, num_replicas, num_partitions=None): | ||||
|     if num_partitions is not None: | ||||
|       return self.client.GetDefaultDeviceAssignment(num_replicas, | ||||
|                                                     num_partitions) | ||||
|       return self.client.get_default_device_assignment(num_replicas, | ||||
|                                                        num_partitions) | ||||
|     else: | ||||
|       # TODO(skye): delete this case after all callers can handle 2D output | ||||
|       return self.client.GetDefaultDeviceAssignment(num_replicas) | ||||
|       return self.client.get_default_device_assignment(num_replicas) | ||||
| 
 | ||||
| 
 | ||||
| xla_platform_names = { | ||||
| @ -445,7 +445,7 @@ def transfer_to_infeed(value, device=None): | ||||
|   # TODO(phawkins): support non-default backends. | ||||
|   backend = get_local_backend() | ||||
|   device = device or backend.local_devices()[0] | ||||
|   device.TransferToInfeed(value) | ||||
|   device.transfer_to_infeed(value) | ||||
| 
 | ||||
| 
 | ||||
| def transfer_from_outfeed(shape, device=None): | ||||
| @ -462,7 +462,7 @@ def transfer_from_outfeed(shape, device=None): | ||||
|   # TODO(phawkins): support non-default backends. | ||||
|   backend = get_local_backend() | ||||
|   device = device or backend.local_devices()[0] | ||||
|   return device.TransferFromOutfeed( | ||||
|   return device.transfer_from_outfeed( | ||||
|       shape.with_major_to_minor_layout_if_absent()) | ||||
| 
 | ||||
| 
 | ||||
| @ -542,8 +542,7 @@ def execute_with_python_values(executable, arguments=(), backend=None): | ||||
|   backend = backend or get_local_backend() | ||||
| 
 | ||||
|   def put(arg): | ||||
|     return Buffer.from_pyval( | ||||
|         arg, device=executable.local_devices()[0], backend=backend) | ||||
|     return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) | ||||
| 
 | ||||
|   arguments = [put(arg) for arg in arguments] | ||||
|   outputs = executable.Execute(arguments) | ||||
| @ -629,7 +628,7 @@ def register_custom_call_target(name, fn, platform='cpu'): | ||||
|     fn: a PyCapsule object containing the function pointer. | ||||
|     platform: the target platform. | ||||
|   """ | ||||
|   _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform]) | ||||
|   _xla.register_custom_call_target(name, fn, xla_platform_names[platform]) | ||||
| 
 | ||||
| 
 | ||||
| # Deprecated. Use register_custom_call_target instead. | ||||
|  | ||||
| @ -82,7 +82,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       return xla_client.XlaBuilder(name) | ||||
| 
 | ||||
|     def _Execute(self, c, arguments): | ||||
|       compiled_c = self.backend.compile(c.Build()) | ||||
|       compiled_c = self.backend.compile(c.build()) | ||||
|       return xla_client.execute_with_python_values( | ||||
|           compiled_c, arguments, backend=self.backend) | ||||
| 
 | ||||
| @ -136,34 +136,34 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) | ||||
|       x = ops.Mul(p0, p1) | ||||
|       ops.Add(x, x) | ||||
|       return builder.Build() | ||||
|       return builder.build() | ||||
| 
 | ||||
|     def testComputationToHloText(self): | ||||
|       computation = self.ExampleComputation() | ||||
|       hlo_text = computation.GetHloText() | ||||
|       hlo_text = computation.as_hlo_text() | ||||
|       self.assertTrue(hlo_text.startswith("HloModule acomputation")) | ||||
| 
 | ||||
|     def testComputationToHloGraph(self): | ||||
|       computation = self.ExampleComputation() | ||||
|       hlo_dot_graph = computation.GetHloDotGraph() | ||||
|       hlo_dot_graph = computation.as_hlo_dot_graph() | ||||
|       self.assertTrue(hlo_dot_graph.startswith("digraph ")) | ||||
| 
 | ||||
|     def testHloModuleToHloText(self): | ||||
|       computation = self.ExampleComputation() | ||||
|       hlo_text = computation.get_hlo_module().to_string() | ||||
|       hlo_text = computation.as_hlo_module().to_string() | ||||
|       self.assertTrue(hlo_text.startswith("HloModule acomputation")) | ||||
| 
 | ||||
|     def testHloModuleToHloGraph(self): | ||||
|       computation = self.ExampleComputation() | ||||
|       hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( | ||||
|           computation.get_hlo_module()) | ||||
|           computation.as_hlo_module()) | ||||
|       self.assertTrue(hlo_dot_graph.startswith("digraph ")) | ||||
| 
 | ||||
|     @unittest.skipIf(cloud_tpu, "not implemented") | ||||
|     def testCompiledHloModuleToHloText(self): | ||||
|       computation = self.ExampleComputation() | ||||
|       executable = self.backend.compile(computation) | ||||
|       hlo_modules = executable.get_hlo_modules() | ||||
|       hlo_modules = executable.hlo_modules() | ||||
|       self.assertLen(hlo_modules, 1) | ||||
|       hlo_text = hlo_modules[0].to_string() | ||||
|       self.assertTrue(hlo_text.startswith("HloModule acomputation")) | ||||
| @ -180,7 +180,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       p1 = ops.Parameter( | ||||
|           builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) | ||||
|       ops.Mul(p0, p1) | ||||
|       computation0 = builder0.Build() | ||||
|       computation0 = builder0.build() | ||||
| 
 | ||||
|       builder1 = xla_client.XlaBuilder("computation1") | ||||
|       p0 = ops.Parameter(builder1, 0, | ||||
| @ -188,9 +188,9 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       p1 = ops.Parameter( | ||||
|           builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) | ||||
|       ops.Mul(p0, p1) | ||||
|       computation1 = builder1.Build() | ||||
|       computation1 = builder1.build() | ||||
| 
 | ||||
|       self.assertEqual(computation0.Hash(), computation1.Hash()) | ||||
|       self.assertEqual(computation0.hash(), computation1.hash()) | ||||
| 
 | ||||
|   tests.append(ComputationHashTest) | ||||
| 
 | ||||
| @ -396,7 +396,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       # Build the HLO proto | ||||
|       b = xla_client.XlaBuilder("computation") | ||||
|       ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) | ||||
|       serialized_proto = b.Build().GetSerializedProto() | ||||
|       serialized_proto = b.build().as_serialized_hlo_module_proto() | ||||
| 
 | ||||
|       # Load and execute the proto | ||||
|       c = xla_client.XlaComputation(serialized_proto) | ||||
| @ -478,22 +478,22 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), | ||||
|           ops.Constant(c, np.float32(3.14))) | ||||
|       arg = NumpyArrayF32(1.11) | ||||
|       compiled_c = self.backend.compile(c.Build()) | ||||
|       arg_buffer = xla_client.Buffer.from_pyval(arg, backend=self.backend) | ||||
|       compiled_c = self.backend.compile(c.build()) | ||||
|       arg_buffer = self.backend.buffer_from_pyval(arg) | ||||
|       arg_buffer.delete() | ||||
|       with self.assertRaises(RuntimeError): | ||||
|         compiled_c.Execute([arg_buffer]) | ||||
|         compiled_c.execute([arg_buffer]) | ||||
| 
 | ||||
|     def testShape(self): | ||||
|       pyval = np.array([[1., 2.]], np.float32) | ||||
|       local_buffer = xla_client.Buffer.from_pyval(pyval) | ||||
|       local_buffer = self.backend.buffer_from_pyval(pyval) | ||||
|       xla_shape = local_buffer.shape() | ||||
|       self.assertEqual(xla_shape.dimensions(), (1, 2)) | ||||
|       self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) | ||||
| 
 | ||||
|     def testBlockHostUntilReadyWorks(self): | ||||
|       arg = np.array([[1., 2.]], np.float32) | ||||
|       arg_buffer = xla_client.Buffer.from_pyval(arg) | ||||
|       arg_buffer = self.backend.buffer_from_pyval(arg) | ||||
|       arg_buffer.block_host_until_ready() | ||||
|       # This test merely checks that nothing goes awry when we call | ||||
|       # block_host_until_ready(); it's difficult to test anything else. | ||||
| @ -501,8 +501,8 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|     def testCopyToHost(self): | ||||
|       arg0 = np.array([[1., 2.]], np.float32) | ||||
|       arg1 = np.array([[3., 4.]], np.float32) | ||||
|       arg0_buffer = xla_client.Buffer.from_pyval(arg0) | ||||
|       arg1_buffer = xla_client.Buffer.from_pyval(arg1) | ||||
|       arg0_buffer = self.backend.buffer_from_pyval(arg0) | ||||
|       arg1_buffer = self.backend.buffer_from_pyval(arg1) | ||||
|       # Prefetch two buffers using copy_to_host_async, and then retrieve their | ||||
|       # values using to_py. | ||||
|       arg0_buffer.copy_to_host_async() | ||||
| @ -517,8 +517,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|     def testDevice(self): | ||||
|       x = np.arange(8, dtype=np.int32) | ||||
|       for device in self.backend.local_devices(): | ||||
|         buf = xla_client.Buffer.from_pyval( | ||||
|             x, device=device, backend=self.backend) | ||||
|         buf = self.backend.buffer_from_pyval(x, device=device) | ||||
|         self.assertEqual(buf.device(), device) | ||||
|         np.testing.assert_equal(x, buf.to_py()) | ||||
| 
 | ||||
| @ -564,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) | ||||
| 
 | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       self.assertLen(result, 1) | ||||
|       expected = np.array(x, dtype=dst_dtype) | ||||
| 
 | ||||
| @ -591,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) | ||||
| 
 | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       self.assertLen(result, 1) | ||||
|       expected = x.view(dst_dtype) | ||||
| 
 | ||||
| @ -1127,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           ops.Constant(c, NumpyArrayBool([True, False, False, True])) | ||||
|       ]) | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       self.assertLen(result, 3) | ||||
|       np.testing.assert_equal(result[0], 42) | ||||
|       np.testing.assert_allclose(result[1], [1.0, 2.0]) | ||||
| @ -1166,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, | ||||
|                                              shape)) | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       # since the result is random, we just check shape and uniqueness | ||||
|       self.assertLen(result, 1) | ||||
|       self.assertEqual(result[0].shape, shape) | ||||
| @ -1182,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, | ||||
|                                              shape)) | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       # since the result is random, we just check shape, uniqueness, and range | ||||
|       self.assertLen(result, 1) | ||||
|       self.assertEqual(result[0].shape, shape) | ||||
| @ -1200,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, | ||||
|                                              shape)) | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       # since the result is random, we just check shape, integrality, and range | ||||
|       self.assertLen(result, 1) | ||||
|       self.assertEqual(result[0].shape, shape) | ||||
| @ -1229,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       c = self._NewComputation() | ||||
|       ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build()), backend=self.backend) | ||||
|           self.backend.compile(c.build()), backend=self.backend) | ||||
|       self.assertLen(result, 2) | ||||
|       np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) | ||||
|       np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) | ||||
| @ -1241,7 +1240,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) | ||||
|       q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) | ||||
|       ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1))) | ||||
|       comparator = b.Build() | ||||
|       comparator = b.build() | ||||
| 
 | ||||
|       keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) | ||||
|       values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) | ||||
| @ -1251,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           dimension=1, | ||||
|           comparator=comparator) | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           self.backend.compile(c.Build())) | ||||
|           self.backend.compile(c.build())) | ||||
|       self.assertLen(result, 2) | ||||
|       np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) | ||||
|       np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) | ||||
| @ -1321,8 +1320,8 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) | ||||
|       const_expr = ops.Sub(b, a) | ||||
|       non_const_expr = ops.Mul(const_expr, x) | ||||
|       self.assertTrue(c.IsConstant(const_expr)) | ||||
|       self.assertFalse(c.IsConstant(non_const_expr)) | ||||
|       self.assertTrue(c.is_constant(const_expr)) | ||||
|       self.assertFalse(c.is_constant(non_const_expr)) | ||||
| 
 | ||||
|     def testGather(self): | ||||
|       a = np.arange(9).astype(np.int32).reshape((3, 3)) | ||||
| @ -1412,7 +1411,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       ops.Parameter(c, 0, | ||||
|                     xla_client.shape_from_pyval(np.array(0, dtype=in_dtype))) | ||||
|       ops.Constant(c, out_dtype(1)) | ||||
|       return c.Build() | ||||
|       return c.build() | ||||
| 
 | ||||
|     def _CreateMulBy2Computation(self, dtype): | ||||
|       """Computation (dtype) -> dtype that multiplies its parameter by 2.""" | ||||
| @ -1423,7 +1422,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|               xla_client.shape_from_pyval(np.array( | ||||
|                   0, dtype=dtype)).with_major_to_minor_layout_if_absent()), | ||||
|           ops.Constant(c, dtype(2.0))) | ||||
|       return c.Build() | ||||
|       return c.build() | ||||
| 
 | ||||
|     def _CreateMulF32ByParamComputation(self): | ||||
|       """Computation (f32) -> f32 that multiplies one parameter by the other.""" | ||||
| @ -1431,7 +1430,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       ops.Mul( | ||||
|           ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), | ||||
|           ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) | ||||
|       return c.Build() | ||||
|       return c.build() | ||||
| 
 | ||||
|     def _CreateBinaryAddComputation(self, dtype): | ||||
|       """Computation (dtype, dtype) -> dtype that adds its two parameters.""" | ||||
| @ -1439,7 +1438,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) | ||||
|       shape = shape.with_major_to_minor_layout_if_absent() | ||||
|       ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) | ||||
|       return c.Build() | ||||
|       return c.build() | ||||
| 
 | ||||
|     def _CreateBinaryGeComputation(self, dtype): | ||||
|       """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" | ||||
| @ -1447,7 +1446,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) | ||||
|       shape = shape.with_major_to_minor_layout_if_absent() | ||||
|       ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) | ||||
|       return c.Build() | ||||
|       return c.build() | ||||
| 
 | ||||
|     def _MakeSample3DArray(self, dtype): | ||||
|       return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], | ||||
| @ -1516,7 +1515,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|         c = self._NewComputation("div_param0_by_param1") | ||||
|         shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) | ||||
|         ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) | ||||
|         return c.Build() | ||||
|         return c.build() | ||||
| 
 | ||||
|       c = self._NewComputation() | ||||
|       ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), | ||||
| @ -1539,7 +1538,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       window_strides = (1, 2) | ||||
|       padding = xla_client.window_padding_type_to_pad_values( | ||||
|           xla_client.PaddingType.VALID, | ||||
|           c.GetShape(operand).dimensions(), window_dimensions, window_strides) | ||||
|           c.get_shape(operand).dimensions(), window_dimensions, window_strides) | ||||
|       ops.SelectAndScatterWithGeneralPadding( | ||||
|           operand, | ||||
|           select=self._CreateBinaryGeComputation(dtype), | ||||
| @ -1686,7 +1685,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|         c = self._NewComputation("test_lt_10") | ||||
|         shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) | ||||
|         ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) | ||||
|         return c.Build() | ||||
|         return c.build() | ||||
| 
 | ||||
|       cond = LessThan10Cond() | ||||
|       body = self._CreateMulBy2Computation(dtype) | ||||
| @ -1728,10 +1727,10 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|               ops.CreateToken(c), | ||||
|               xla_client.shape_from_pyval( | ||||
|                   to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) | ||||
|       compiled_c = self.backend.compile(c.Build()) | ||||
|       compiled_c = self.backend.compile(c.build()) | ||||
|       device = self.backend.local_devices()[0] | ||||
|       for item in to_infeed: | ||||
|         xla_client.transfer_to_infeed(item, device=device) | ||||
|         device.transfer_to_infeed(item) | ||||
| 
 | ||||
|       for item in to_infeed: | ||||
|         result, = xla_client.execute_with_python_values( | ||||
| @ -1747,9 +1746,9 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|               ops.CreateToken(c), | ||||
|               xla_client.shape_from_pyval( | ||||
|                   to_infeed).with_major_to_minor_layout_if_absent()), 0) | ||||
|       compiled_c = self.backend.compile(c.Build()) | ||||
|       compiled_c = self.backend.compile(c.build()) | ||||
|       device = self.backend.local_devices()[0] | ||||
|       xla_client.transfer_to_infeed(to_infeed, device=device) | ||||
|       device.transfer_to_infeed(to_infeed) | ||||
| 
 | ||||
|       result = xla_client.execute_with_python_values( | ||||
|           compiled_c, backend=self.backend) | ||||
| @ -1771,14 +1770,14 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           to_round_trip[0]).with_major_to_minor_layout_if_absent() | ||||
|       ops.OutfeedWithToken(x, token, outfeed_shape) | ||||
| 
 | ||||
|       compiled_c = self.backend.compile(c.Build()) | ||||
|       compiled_c = self.backend.compile(c.build()) | ||||
|       device = self.backend.local_devices()[0] | ||||
| 
 | ||||
|       for want in to_round_trip: | ||||
|         execution = threading.Thread(target=lambda: compiled_c.Execute([])) | ||||
|         execution = threading.Thread(target=lambda: compiled_c.execute([])) | ||||
|         execution.start() | ||||
|         xla_client.transfer_to_infeed(want, device=device) | ||||
|         got = xla_client.transfer_from_outfeed(outfeed_shape, device=device) | ||||
|         device.transfer_to_infeed(want) | ||||
|         got = device.transfer_from_outfeed(outfeed_shape) | ||||
|         execution.join() | ||||
|         self.assertEqual(want, got) | ||||
| 
 | ||||
| @ -1811,9 +1810,9 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
| 
 | ||||
|     def testCompileWithWrongElementTypeInLayout(self): | ||||
|       c = self._NewComputation() | ||||
|       c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) | ||||
|       c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) | ||||
|       ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) | ||||
|       c.ClearOpMetadata() | ||||
|       c.clear_op_metadata() | ||||
| 
 | ||||
|       options = xla_client.CompileOptions() | ||||
|       options.argument_layouts = [ | ||||
| @ -1821,7 +1820,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       ] | ||||
| 
 | ||||
|       def TestFun(): | ||||
|         return self.backend.compile(c.Build(), compile_options=options) | ||||
|         return self.backend.compile(c.build(), compile_options=options) | ||||
| 
 | ||||
|       self.assertRaisesRegex( | ||||
|           RuntimeError, r".*Invalid argument shape.*" | ||||
| @ -1829,13 +1828,13 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
| 
 | ||||
|     def testInvokeWithWrongElementType(self): | ||||
|       c = self._NewComputation() | ||||
|       c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) | ||||
|       c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) | ||||
|       ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) | ||||
|       c.ClearOpMetadata() | ||||
|       c.clear_op_metadata() | ||||
| 
 | ||||
|       def TestFun(): | ||||
|         return xla_client.execute_with_python_values( | ||||
|             self.backend.compile(c.Build()), [self.f32_scalar_2]) | ||||
|             self.backend.compile(c.build()), [self.f32_scalar_2]) | ||||
| 
 | ||||
|       self.assertRaisesRegex( | ||||
|           RuntimeError, r"Invalid argument: Argument does not match.*" | ||||
| @ -1853,7 +1852,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       ops.Add(result, ops.Constant(c, np.float32(1.618))) | ||||
| 
 | ||||
|       arg = NumpyArrayF32(1.0) | ||||
|       compiled_c = self.backend.compile(c.Build(result)) | ||||
|       compiled_c = self.backend.compile(c.build(result)) | ||||
|       ans, = xla_client.execute_with_python_values( | ||||
|           compiled_c, [arg], backend=self.backend) | ||||
|       np.testing.assert_allclose(ans, 4.14) | ||||
| @ -1869,16 +1868,14 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       sharding.type = sharding.type.REPLICATED | ||||
|       sharding.tile_assignment_dimensions.extend([1]) | ||||
|       sharding.tile_assignment_devices.extend([0]) | ||||
|       # Set Sharding. | ||||
|       c.SetSharding(sharding) | ||||
|       c.set_sharding(sharding) | ||||
|       x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) | ||||
|       # Clear Sharding. | ||||
|       c.ClearSharding() | ||||
|       c.clear_sharding() | ||||
| 
 | ||||
|       result = ops.Add(x, ops.Constant(c, np.float32(3.14))) | ||||
|       ops.Add(result, ops.Constant(c, np.float32(1.618))) | ||||
|       arg = NumpyArrayF32(1.0) | ||||
|       compiled_c = self.backend.compile(c.Build(result)) | ||||
|       compiled_c = self.backend.compile(c.build(result)) | ||||
|       ans, = xla_client.execute_with_python_values( | ||||
|           compiled_c, [arg], backend=self.backend) | ||||
|       np.testing.assert_allclose(ans, 4.14) | ||||
| @ -1898,8 +1895,8 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|           xla_client.shape_from_pyval( | ||||
|               NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) | ||||
|       out = ops.Add(p1, p2) | ||||
|       c.SetUpAlias([], 0, []) | ||||
|       c = c.Build(out) | ||||
|       c.setup_alias([], 0, []) | ||||
|       c = c.build(out) | ||||
|       if self.backend.platform != "tpu": | ||||
|         with self.assertRaisesRegex( | ||||
|             RuntimeError, "Buffer aliasing is not supported " | ||||
| @ -1940,21 +1937,22 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|     } for dtype in dlpack_dtypes for shape in testcase_shapes) | ||||
|     def testRoundTrip(self, dtype, shape): | ||||
|       x = np.array(np.random.rand(*shape) * 100, dtype=dtype) | ||||
|       buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) | ||||
|       dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) | ||||
|       buffer = self.backend.buffer_from_pyval(x) | ||||
|       dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) | ||||
|       del buffer  # Free "buffer" to make sure dlt retains ownership. | ||||
|       self.assertEqual(type(dlt).__name__, "PyCapsule") | ||||
|       y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, self.backend.client) | ||||
|       y = xla_client._xla.dlpack_managed_tensor_to_buffer( | ||||
|           dlt, self.backend.client) | ||||
|       np.testing.assert_array_equal(x, y.to_py()) | ||||
| 
 | ||||
|     def testTensorsCanBeConsumedOnceOnly(self): | ||||
|       x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) | ||||
|       buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) | ||||
|       dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) | ||||
|       buffer = self.backend.buffer_from_pyval(x) | ||||
|       dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) | ||||
| 
 | ||||
|       def ConsumeDLPackTensor(): | ||||
|         _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, | ||||
|                                                         self.backend.client) | ||||
|         _ = xla_client._xla.dlpack_managed_tensor_to_buffer( | ||||
|             dlt, self.backend.client) | ||||
| 
 | ||||
|       ConsumeDLPackTensor() | ||||
|       self.assertRaisesRegex( | ||||
| @ -1981,7 +1979,7 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|     def testRoundTrip(self, dtype, shape): | ||||
|       x = np.array(np.random.rand(*shape) * 100, dtype=dtype) | ||||
|       x_ptr = x.__array_interface__["data"][0] | ||||
|       buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) | ||||
|       buffer = self.backend.buffer_from_pyval(x) | ||||
|       y = np.array(buffer, copy=False) | ||||
|       y_ptr = y.__array_interface__["data"][0] | ||||
|       np.testing.assert_array_equal(x, y) | ||||
| @ -1990,15 +1988,14 @@ def TestFactory(xla_backend, cloud_tpu=False): | ||||
|       self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr) | ||||
|       self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) | ||||
| 
 | ||||
|       buffer2 = xla_client.Buffer.from_pyval( | ||||
|           x, backend=self.backend, force_copy=True) | ||||
|       buffer2 = self.backend.buffer_from_pyval(x, force_copy=True) | ||||
|       z = np.array(buffer2, copy=False) | ||||
|       self.assertNotEqual(x.__array_interface__["data"][0], | ||||
|                           z.__array_interface__["data"][0]) | ||||
| 
 | ||||
|     def testDeleteWithActiveView(self): | ||||
|       x = np.random.randn(20, 10) | ||||
|       buffer = xla_client.Buffer.from_pyval(x, backend=self.backend) | ||||
|       buffer = self.backend.buffer_from_pyval(x) | ||||
|       buffer_ptr = buffer.unsafe_buffer_pointer() | ||||
|       y = np.array(buffer, copy=False) | ||||
|       buffer.delete() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user