[XLA:Python] API cleanups.
Add or change method names on most classes to be lower_case instead of
CamelCase, matching the usual Python style, and rename a few methods to read
better. The old names are not yet removed because callers (e.g., JAX) still
need to be updated.

PiperOrigin-RevId: 308677107
Change-Id: I46993f093676402853e6032a822b308f59677e6f
commit 40d89f69e1 (parent 71964116c5)
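
As an illustration of the caller-side migration (a sketch, not part of this commit; the import path and the `ops` alias are copied from xla_client_test.py, and both spellings keep working until the CamelCase names are deleted):

    # Hypothetical migration sketch; everything other than the renamed
    # methods is an assumption taken from the surrounding files.
    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    ops = xla_client.ops
    backend = xla_client.get_local_backend()

    c = xla_client.XlaBuilder("add_one")
    p = ops.Parameter(
        c, 0, xla_client.shape_from_pyval(np.array(0, dtype=np.float32)))
    ops.Add(p, ops.Constant(c, np.float32(1.0)))

    computation = c.build()  # was: c.Build()
    compiled = backend.compile(computation)
    result, = xla_client.execute_with_python_values(
        compiled, [np.array(41.0, dtype=np.float32)], backend=backend)
    print(result)  # -> 42.0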
@@ -98,7 +98,7 @@ class TpuBackend(xla_client.Backend):
     options.debug_options.xla_cpu_fast_math_honor_division = True
     options.debug_options.xla_cpu_fast_math_honor_functions = True
     options.debug_options.xla_gpu_enable_fast_min_max = False
-    return _tpu_client.TpuExecutable.Compile(c_computation,
+    return _tpu_client.TpuExecutable.compile(c_computation,
                                              compile_options.argument_layouts,
                                              options, self.client,
                                              compile_options.device_assignment,
@@ -106,14 +106,8 @@ class TpuBackend(xla_client.Backend):

   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
-      return self.client.GetDefaultDeviceAssignment(num_replicas,
+      return self.client.get_default_device_assignment(num_replicas,
                                                     num_partitions)
     else:
       # TODO(henrytan): delete this case after all callers can handle 2D output
-      return self.client.GetDefaultDeviceAssignment(num_replicas)
-
-  def serialize(self, executable):
-    return self.client.SerializeExecutable(executable)
-
-  def deserialize(self, serialized_executable):
-    return self.client.DeserializeExecutable(serialized_executable, self.client)
+      return self.client.get_default_device_assignment(num_replicas)
@@ -37,7 +37,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("devices", &PyTpuClient::devices)
       .def("local_devices", &PyTpuClient::local_devices)
       .def("host_id", &PyTpuClient::host_id)
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](PyTpuClient* client, int num_replicas, int num_partitions)
               -> StatusOr<std::vector<std::vector<std::shared_ptr<Device>>>> {
             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
@@ -57,7 +57,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             return result;
           })
       // TODO(skye): delete after all callers can handle 2D output
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](PyTpuClient* client, int num_replicas)
               -> StatusOr<std::vector<std::shared_ptr<Device>>> {
             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
@@ -72,14 +72,14 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             }
             return result;
           })
-      .def("TransferToInfeed",
+      .def("transfer_to_infeed",
           [](PyTpuClient* client, const LiteralSlice& literal,
              int device_ordinal) {
             GlobalPyRefManager()->CollectGarbage();
             py::gil_scoped_release gil_release;
             return client->TransferToInfeed(literal, device_ordinal);
           })
-      .def("TransferFromOutfeed",
+      .def("transfer_from_outfeed",
           [](PyTpuClient* client, const Shape& shape,
              int device_ordinal) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
@@ -159,9 +159,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       });

   py::class_<PyTpuExecutable>(m, "TpuExecutable")
-      .def_static("Compile", &PyTpuExecutable::Compile,
+      .def_static("compile", &PyTpuExecutable::Compile,
                   py::call_guard<py::gil_scoped_release>())
-      .def_static("Compile",
+      .def_static("compile",
                   [](const XlaComputation& computation,
                      absl::optional<std::vector<Shape>> argument_layouts,
                      const ExecutableBuildOptions* build_options,
@@ -184,12 +184,17 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("local_logical_device_ids",
           &PyTpuExecutable::local_logical_device_ids)
       .def("local_devices", &PyTpuExecutable::local_devices)
-      .def("SizeOfGeneratedCodeInBytes",
+      .def("size_of_generated_code_in_bytes",
           &PyTpuExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyTpuExecutable::Delete)
       .def("Execute", &PyTpuExecutable::Execute,
           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
+          py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
+      .def("delete", &PyTpuExecutable::Delete)
+      .def("execute", &PyTpuExecutable::Execute,
+          py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
+      .def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices,
           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));

   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
@@ -676,7 +676,7 @@ void BuildProfilerSubmodule(py::module* m) {
   traceme_class.def(py::init<py::str, py::kwargs>())
       .def("__enter__", &TraceMeContextManager::Enter)
       .def("__exit__", &TraceMeContextManager::Exit)
-      .def_static("IsEnabled", &TraceMeContextManager::IsEnabled);
+      .def_static("is_enabled", &TraceMeContextManager::IsEnabled);
 }

 }  // namespace
@@ -880,6 +880,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def_property_readonly("platform", &Device::platform_name)
       .def_property_readonly("device_kind", &Device::device_kind)
       .def("__str__", &Device::DebugString)
+      // TODO(phawkins): remove capitalized names after updating callers.
       .def("TransferToInfeed",
           [](const Device& device, const LiteralSlice& literal) {
             GlobalPyRefManager()->CollectGarbage();
@@ -891,6 +892,33 @@ PYBIND11_MODULE(xla_extension, m) {
           })
       .def(
           "TransferFromOutfeed",
+          [](const Device& device, const Shape& shape) -> StatusOr<py::object> {
+            GlobalPyRefManager()->CollectGarbage();
+            std::shared_ptr<Literal> literal_shared;
+            {
+              py::gil_scoped_release gil_release;
+              TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                                  device.GetLocalDeviceState());
+              TF_ASSIGN_OR_RETURN(
+                  Literal literal,
+                  local_device->client()->TransferFromOutfeedLocal(
+                      shape, local_device->device_ordinal()));
+
+              literal_shared = std::make_shared<Literal>(std::move(literal));
+            }
+            return LiteralToPython(std::move(literal_shared));
+          })
+      .def("transfer_to_infeed",
+          [](const Device& device, const LiteralSlice& literal) {
+            GlobalPyRefManager()->CollectGarbage();
+            py::gil_scoped_release gil_release;
+            TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                                device.GetLocalDeviceState());
+            return local_device->client()->TransferToInfeedLocal(
+                literal, local_device->device_ordinal());
+          })
+      .def(
+          "transfer_from_outfeed",
           [](const Device& device, const Shape& shape) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
             std::shared_ptr<Literal> literal_shared;
@@ -921,7 +949,7 @@ PYBIND11_MODULE(xla_extension, m) {
   // Local XLA client methods.

   // Custom-call targets.
-  m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
+  m.def("register_custom_call_target", &PyRegisterCustomCallTarget);

   py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig");
   alloc_config.def(py::init<>())
@@ -955,7 +983,7 @@ PYBIND11_MODULE(xla_extension, m) {
         return devices;
       })
       .def("host_id", &PyLocalClient::host_id)
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](std::shared_ptr<PyLocalClient> client, int num_replicas,
              int num_partitions)
               -> StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>> {
@@ -976,7 +1004,7 @@ PYBIND11_MODULE(xla_extension, m) {
             return result;
           })
       // TODO(skye): delete after all callers can handle 2D output
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](std::shared_ptr<PyLocalClient> client,
              int num_replicas) -> StatusOr<std::vector<ClientAndPtr<Device>>> {
             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
@@ -991,15 +1019,15 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return result;
           })
-      .def("CreateChannelHandle",
+      .def("create_channel_handle",
           [](PyLocalClient* client) {
             return client->client()->CreateChannelHandle();
           })
-      .def("CreateDeviceToHostChannelHandle",
+      .def("create_device_to_host_channel_handle",
           [](PyLocalClient* client) {
             return client->client()->CreateDeviceToHostChannelHandle();
           })
-      .def("CreateHostToDeviceChannelHandle", [](PyLocalClient* client) {
+      .def("create_host_to_device_channel_handle", [](PyLocalClient* client) {
         return client->client()->CreateHostToDeviceChannelHandle();
       });

@@ -1119,7 +1147,7 @@ PYBIND11_MODULE(xla_extension, m) {
   py::class_<PyLocalExecutable, ClientAndUniquePtr<PyLocalExecutable>>
       executable(m, "LocalExecutable");
   executable
-      .def_static("Compile",
+      .def_static("compile",
                   [](const XlaComputation& computation,
                      absl::optional<std::vector<Shape>> argument_layouts,
                      const ExecutableBuildOptions* build_options,
@@ -1146,7 +1174,7 @@ PYBIND11_MODULE(xla_extension, m) {
                     return WrapWithClient(std::move(client),
                                           std::move(executable));
                   })
-      .def_static("Compile",
+      .def_static("compile",
                   [](const XlaComputation& computation,
                      absl::optional<std::vector<Shape>> argument_layouts,
                      const ExecutableBuildOptions* build_options,
@@ -1189,8 +1217,10 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return devices;
           })
-      .def("SizeOfGeneratedCodeInBytes",
+      .def("size_of_generated_code_in_bytes",
           &PyLocalExecutable::SizeOfGeneratedCodeInBytes)
+      .def("delete", &PyLocalExecutable::Delete)
+      // TODO(phawkins): delete capitalized methods after updating callers.
       .def("Delete", &PyLocalExecutable::Delete)
       .def(
           "Execute",
@@ -1212,6 +1242,27 @@ PYBIND11_MODULE(xla_extension, m) {
             return outputs;
           },
           py::arg("arguments"))
+      .def(
+          "execute",
+          [](const PyLocalExecutable& executable,
+             absl::Span<PyLocalBuffer* const> args)
+              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
+            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.untuple_result = true;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
+                executable.Execute(args, options));
+            std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
+            outputs.reserve(output_buffers.size());
+            for (auto& buffer : output_buffers) {
+              outputs.push_back(WrapWithClient(
+                  executable.client()->shared_from_this(), std::move(buffer)));
+            }
+            return outputs;
+          },
+          py::arg("arguments"))
+      // TODO(phawkins): delete capitalized methods after updating callers.
       .def(
           "ExecuteOnLocalDevices",
           [](const PyLocalExecutable& executable,
@@ -1239,7 +1290,33 @@ PYBIND11_MODULE(xla_extension, m) {
           },
           py::arg("arguments"))
       .def(
-          "get_hlo_modules",
+          "execute_on_local_devices",
+          [](const PyLocalExecutable& executable,
+             absl::Span<const std::vector<PyLocalBuffer*>> args)
+              -> StatusOr<
+                  std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
+            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.untuple_result = true;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
+                    output_buffers,
+                executable.ExecuteOnLocalDevices(args, options));
+            std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> outputs;
+            outputs.resize(output_buffers.size());
+            for (int computation = 0; computation < output_buffers.size();
+                 ++computation) {
+              for (auto& buffer : output_buffers[computation]) {
+                outputs[computation].push_back(
+                    WrapWithClient(executable.client()->shared_from_this(),
+                                   std::move(buffer)));
+              }
+            }
+            return outputs;
+          },
+          py::arg("arguments"))
+      .def(
+          "hlo_modules",
           [](const PyLocalExecutable& executable)
               -> StatusOr<std::vector<std::shared_ptr<HloModule>>> {
             std::vector<std::shared_ptr<HloModule>> modules;
@@ -1298,12 +1375,19 @@ PYBIND11_MODULE(xla_extension, m) {
         proto.ParseFromString(serialized_hlo_module_proto);
         return absl::make_unique<XlaComputation>(proto);
       }))
+      // TODO(phawkins): delete capitalized names after updating callers.
       .def("GetProgramShape", &XlaComputation::GetProgramShape)
       .def("GetSerializedProto", &GetComputationSerializedProto)
       .def("GetHloText", &GetComputationHloText)
       .def("GetHloDotGraph", &GetComputationHloDotGraph)
       .def("Hash", &HashComputation)
-      .def("get_hlo_module", &GetHloModule);
+      .def("get_hlo_module", &GetHloModule)
+      .def("program_shape", &XlaComputation::GetProgramShape)
+      .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto)
+      .def("as_hlo_text", &GetComputationHloText)
+      .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
+      .def("hash", &HashComputation)
+      .def("as_hlo_module", &GetHloModule);

   py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
   hlo_print_options_class.def(py::init<>())
@@ -1381,6 +1465,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> {
         return absl::make_unique<XlaBuilder>(UniquifyName(name));
       }))
+      // TODO(phawkins): delete capitalized names after updating callers.
       .def(
           "Build",
           [](XlaBuilder& builder, absl::optional<XlaOp> root) {
@@ -1403,6 +1488,35 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("SetSharding", &XlaBuilder::SetSharding)
       .def("ClearSharding", &XlaBuilder::ClearSharding)
       .def("SetUpAlias",
+          [](XlaBuilder& builder, const std::vector<int64>& output_index,
+             int64 param_number, const std::vector<int64>& param_index) {
+            builder.SetUpAlias(
+                ShapeIndex(output_index.begin(), output_index.end()),
+                param_number,
+                ShapeIndex(param_index.begin(), param_index.end()));
+          })
+      .def(
+          "build",
+          [](XlaBuilder& builder, absl::optional<XlaOp> root) {
+            return root ? builder.Build(*root) : builder.Build();
+          },
+          "Builds a computation from the contents of the builder.",
+          py::arg("root") = absl::nullopt)
+      .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata)
+      .def("get_shape", &XlaBuilder::GetShape)
+      .def(
+          "get_program_shape",
+          [](const XlaBuilder& builder,
+             absl::optional<XlaOp> root) -> StatusOr<ProgramShape> {
+            return root ? builder.GetProgramShape(*root)
+                        : builder.GetProgramShape();
+          },
+          py::arg("root") = absl::nullopt)
+      .def("is_constant", &XlaBuilder::IsConstant)
+      .def("set_op_metadata", &XlaBuilder::SetOpMetadata)
+      .def("set_sharding", &XlaBuilder::SetSharding)
+      .def("clear_sharding", &XlaBuilder::ClearSharding)
+      .def("setup_alias",
           [](XlaBuilder& builder, const std::vector<int64>& output_index,
              int64 param_number, const std::vector<int64>& param_index) {
             builder.SetUpAlias(
@@ -1411,7 +1525,9 @@ PYBIND11_MODULE(xla_extension, m) {
                 ShapeIndex(param_index.begin(), param_index.end()));
           });

+  // TODO(phawkins): delete capitalized names after updating callers
   m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor);
+  m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor);
   m.def("DLPackManagedTensorToBuffer",
         [](const py::capsule& tensor, std::shared_ptr<PyLocalClient> client)
             -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
@@ -1420,6 +1536,14 @@ PYBIND11_MODULE(xla_extension, m) {
               DLPackManagedTensorToBuffer(tensor, client.get()));
           return WrapWithClient(std::move(client), std::move(buffer));
         });
+  m.def("dlpack_managed_tensor_to_buffer",
+        [](const py::capsule& tensor, std::shared_ptr<PyLocalClient> client)
+            -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
+          TF_ASSIGN_OR_RETURN(
+              std::unique_ptr<PyLocalBuffer> buffer,
+              DLPackManagedTensorToBuffer(tensor, client.get()));
+          return WrapWithClient(std::move(client), std::move(buffer));
+        });

   py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
       .value("DEFAULT", PrecisionConfig::DEFAULT)
@@ -147,7 +147,7 @@ class LocalBackend(Backend):
     options.debug_options.xla_cpu_fast_math_honor_division = True
     options.debug_options.xla_cpu_fast_math_honor_functions = True
     options.debug_options.xla_gpu_enable_fast_min_max = False
-    return _xla.LocalExecutable.Compile(c_computation,
+    return _xla.LocalExecutable.compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,
                                         compile_options.device_assignment,
@@ -155,11 +155,11 @@ class LocalBackend(Backend):

   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
-      return self.client.GetDefaultDeviceAssignment(num_replicas,
+      return self.client.get_default_device_assignment(num_replicas,
                                                     num_partitions)
     else:
       # TODO(skye): delete this case after all callers can handle 2D output
-      return self.client.GetDefaultDeviceAssignment(num_replicas)
+      return self.client.get_default_device_assignment(num_replicas)


 xla_platform_names = {
@@ -445,7 +445,7 @@ def transfer_to_infeed(value, device=None):
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
   device = device or backend.local_devices()[0]
-  device.TransferToInfeed(value)
+  device.transfer_to_infeed(value)


 def transfer_from_outfeed(shape, device=None):
@@ -462,7 +462,7 @@ def transfer_from_outfeed(shape, device=None):
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
   device = device or backend.local_devices()[0]
-  return device.TransferFromOutfeed(
+  return device.transfer_from_outfeed(
       shape.with_major_to_minor_layout_if_absent())


@@ -542,8 +542,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
   backend = backend or get_local_backend()

   def put(arg):
-    return Buffer.from_pyval(
-        arg, device=executable.local_devices()[0], backend=backend)
+    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

   arguments = [put(arg) for arg in arguments]
   outputs = executable.Execute(arguments)
@@ -629,7 +628,7 @@ def register_custom_call_target(name, fn, platform='cpu'):
     fn: a PyCapsule object containing the function pointer.
     platform: the target platform.
   """
-  _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
+  _xla.register_custom_call_target(name, fn, xla_platform_names[platform])


 # Deprecated. Use register_custom_call_target instead.
@@ -82,7 +82,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       return xla_client.XlaBuilder(name)

     def _Execute(self, c, arguments):
-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       return xla_client.execute_with_python_values(
           compiled_c, arguments, backend=self.backend)

@@ -136,34 +136,34 @@ def TestFactory(xla_backend, cloud_tpu=False):
           builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
       x = ops.Mul(p0, p1)
       ops.Add(x, x)
-      return builder.Build()
+      return builder.build()

     def testComputationToHloText(self):
       computation = self.ExampleComputation()
-      hlo_text = computation.GetHloText()
+      hlo_text = computation.as_hlo_text()
       self.assertTrue(hlo_text.startswith("HloModule acomputation"))

     def testComputationToHloGraph(self):
       computation = self.ExampleComputation()
-      hlo_dot_graph = computation.GetHloDotGraph()
+      hlo_dot_graph = computation.as_hlo_dot_graph()
       self.assertTrue(hlo_dot_graph.startswith("digraph "))

     def testHloModuleToHloText(self):
       computation = self.ExampleComputation()
-      hlo_text = computation.get_hlo_module().to_string()
+      hlo_text = computation.as_hlo_module().to_string()
       self.assertTrue(hlo_text.startswith("HloModule acomputation"))

     def testHloModuleToHloGraph(self):
       computation = self.ExampleComputation()
       hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph(
-          computation.get_hlo_module())
+          computation.as_hlo_module())
       self.assertTrue(hlo_dot_graph.startswith("digraph "))

     @unittest.skipIf(cloud_tpu, "not implemented")
     def testCompiledHloModuleToHloText(self):
       computation = self.ExampleComputation()
       executable = self.backend.compile(computation)
-      hlo_modules = executable.get_hlo_modules()
+      hlo_modules = executable.hlo_modules()
       self.assertLen(hlo_modules, 1)
       hlo_text = hlo_modules[0].to_string()
       self.assertTrue(hlo_text.startswith("HloModule acomputation"))
@@ -180,7 +180,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       p1 = ops.Parameter(
           builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
       ops.Mul(p0, p1)
-      computation0 = builder0.Build()
+      computation0 = builder0.build()

       builder1 = xla_client.XlaBuilder("computation1")
       p0 = ops.Parameter(builder1, 0,
@@ -188,9 +188,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
       p1 = ops.Parameter(
           builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
       ops.Mul(p0, p1)
-      computation1 = builder1.Build()
+      computation1 = builder1.build()

-      self.assertEqual(computation0.Hash(), computation1.Hash())
+      self.assertEqual(computation0.hash(), computation1.hash())

   tests.append(ComputationHashTest)

@@ -396,7 +396,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       # Build the HLO proto
       b = xla_client.XlaBuilder("computation")
       ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
-      serialized_proto = b.Build().GetSerializedProto()
+      serialized_proto = b.build().as_serialized_hlo_module_proto()

       # Load and execute the proto
       c = xla_client.XlaComputation(serialized_proto)
@@ -478,22 +478,22 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
           ops.Constant(c, np.float32(3.14)))
       arg = NumpyArrayF32(1.11)
-      compiled_c = self.backend.compile(c.Build())
-      arg_buffer = xla_client.Buffer.from_pyval(arg, backend=self.backend)
+      compiled_c = self.backend.compile(c.build())
+      arg_buffer = self.backend.buffer_from_pyval(arg)
       arg_buffer.delete()
       with self.assertRaises(RuntimeError):
-        compiled_c.Execute([arg_buffer])
+        compiled_c.execute([arg_buffer])

     def testShape(self):
       pyval = np.array([[1., 2.]], np.float32)
-      local_buffer = xla_client.Buffer.from_pyval(pyval)
+      local_buffer = self.backend.buffer_from_pyval(pyval)
       xla_shape = local_buffer.shape()
       self.assertEqual(xla_shape.dimensions(), (1, 2))
       self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

     def testBlockHostUntilReadyWorks(self):
       arg = np.array([[1., 2.]], np.float32)
-      arg_buffer = xla_client.Buffer.from_pyval(arg)
+      arg_buffer = self.backend.buffer_from_pyval(arg)
       arg_buffer.block_host_until_ready()
       # This test merely checks that nothing goes awry when we call
       # block_host_until_ready(); it's difficult to test anything else.
@@ -501,8 +501,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testCopyToHost(self):
       arg0 = np.array([[1., 2.]], np.float32)
       arg1 = np.array([[3., 4.]], np.float32)
-      arg0_buffer = xla_client.Buffer.from_pyval(arg0)
-      arg1_buffer = xla_client.Buffer.from_pyval(arg1)
+      arg0_buffer = self.backend.buffer_from_pyval(arg0)
+      arg1_buffer = self.backend.buffer_from_pyval(arg1)
       # Prefetch two buffers using copy_to_host_async, and then retrieve their
       # values using to_py.
       arg0_buffer.copy_to_host_async()
@@ -517,8 +517,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testDevice(self):
       x = np.arange(8, dtype=np.int32)
       for device in self.backend.local_devices():
-        buf = xla_client.Buffer.from_pyval(
-            x, device=device, backend=self.backend)
+        buf = self.backend.buffer_from_pyval(x, device=device)
         self.assertEqual(buf.device(), device)
         np.testing.assert_equal(x, buf.to_py())

@@ -564,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       self.assertLen(result, 1)
       expected = np.array(x, dtype=dst_dtype)

@@ -591,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       self.assertLen(result, 1)
       expected = x.view(dst_dtype)

@@ -1127,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Constant(c, NumpyArrayBool([True, False, False, True]))
       ])
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       self.assertLen(result, 3)
       np.testing.assert_equal(result[0], 42)
       np.testing.assert_allclose(result[1], [1.0, 2.0])
@@ -1166,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                              shape))
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       # since the result is random, we just check shape and uniqueness
       self.assertLen(result, 1)
       self.assertEqual(result[0].shape, shape)
@@ -1182,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                              shape))
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       # since the result is random, we just check shape, uniqueness, and range
       self.assertLen(result, 1)
       self.assertEqual(result[0].shape, shape)
@@ -1200,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
                                              shape))
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       # since the result is random, we just check shape, integrality, and range
       self.assertLen(result, 1)
       self.assertEqual(result[0].shape, shape)
@@ -1229,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       c = self._NewComputation()
       ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       self.assertLen(result, 2)
       np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
       np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
@@ -1241,7 +1240,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
       q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
       ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
-      comparator = b.Build()
+      comparator = b.build()

       keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
       values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
@@ -1251,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           dimension=1,
           comparator=comparator)
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()))
+          self.backend.compile(c.build()))
       self.assertLen(result, 2)
       np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
       np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
@@ -1321,8 +1320,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
       x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
       const_expr = ops.Sub(b, a)
       non_const_expr = ops.Mul(const_expr, x)
-      self.assertTrue(c.IsConstant(const_expr))
-      self.assertFalse(c.IsConstant(non_const_expr))
+      self.assertTrue(c.is_constant(const_expr))
+      self.assertFalse(c.is_constant(non_const_expr))

     def testGather(self):
       a = np.arange(9).astype(np.int32).reshape((3, 3))
@@ -1412,7 +1411,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ops.Parameter(c, 0,
                     xla_client.shape_from_pyval(np.array(0, dtype=in_dtype)))
       ops.Constant(c, out_dtype(1))
-      return c.Build()
+      return c.build()

     def _CreateMulBy2Computation(self, dtype):
       """Computation (dtype) -> dtype that multiplies its parameter by 2."""
@@ -1423,7 +1422,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
               xla_client.shape_from_pyval(np.array(
                   0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
           ops.Constant(c, dtype(2.0)))
-      return c.Build()
+      return c.build()

     def _CreateMulF32ByParamComputation(self):
       """Computation (f32) -> f32 that multiplies one parameter by the other."""
@@ -1431,7 +1430,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ops.Mul(
           ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
           ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
-      return c.Build()
+      return c.build()

     def _CreateBinaryAddComputation(self, dtype):
       """Computation (dtype, dtype) -> dtype that adds its two parameters."""
@@ -1439,7 +1438,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
       shape = shape.with_major_to_minor_layout_if_absent()
       ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-      return c.Build()
+      return c.build()

     def _CreateBinaryGeComputation(self, dtype):
       """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
@@ -1447,7 +1446,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
       shape = shape.with_major_to_minor_layout_if_absent()
       ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-      return c.Build()
+      return c.build()

     def _MakeSample3DArray(self, dtype):
       return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
@@ -1516,7 +1515,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
         c = self._NewComputation("div_param0_by_param1")
         shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
         ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-        return c.Build()
+        return c.build()

       c = self._NewComputation()
       ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
@@ -1539,7 +1538,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       window_strides = (1, 2)
       padding = xla_client.window_padding_type_to_pad_values(
           xla_client.PaddingType.VALID,
-          c.GetShape(operand).dimensions(), window_dimensions, window_strides)
+          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
       ops.SelectAndScatterWithGeneralPadding(
           operand,
           select=self._CreateBinaryGeComputation(dtype),
@@ -1686,7 +1685,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
         c = self._NewComputation("test_lt_10")
         shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
         ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
-        return c.Build()
+        return c.build()

       cond = LessThan10Cond()
       body = self._CreateMulBy2Computation(dtype)
@@ -1728,10 +1727,10 @@ def TestFactory(xla_backend, cloud_tpu=False):
               ops.CreateToken(c),
               xla_client.shape_from_pyval(
                   to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       device = self.backend.local_devices()[0]
       for item in to_infeed:
-        xla_client.transfer_to_infeed(item, device=device)
+        device.transfer_to_infeed(item)

       for item in to_infeed:
         result, = xla_client.execute_with_python_values(
@@ -1747,9 +1746,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
               ops.CreateToken(c),
               xla_client.shape_from_pyval(
                   to_infeed).with_major_to_minor_layout_if_absent()), 0)
-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       device = self.backend.local_devices()[0]
-      xla_client.transfer_to_infeed(to_infeed, device=device)
+      device.transfer_to_infeed(to_infeed)

       result = xla_client.execute_with_python_values(
           compiled_c, backend=self.backend)
@@ -1771,14 +1770,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
               to_round_trip[0]).with_major_to_minor_layout_if_absent()
       ops.OutfeedWithToken(x, token, outfeed_shape)

-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       device = self.backend.local_devices()[0]

       for want in to_round_trip:
-        execution = threading.Thread(target=lambda: compiled_c.Execute([]))
+        execution = threading.Thread(target=lambda: compiled_c.execute([]))
         execution.start()
-        xla_client.transfer_to_infeed(want, device=device)
-        got = xla_client.transfer_from_outfeed(outfeed_shape, device=device)
+        device.transfer_to_infeed(want)
+        got = device.transfer_from_outfeed(outfeed_shape)
         execution.join()
         self.assertEqual(want, got)

@@ -1811,9 +1810,9 @@ def TestFactory(xla_backend, cloud_tpu=False):

     def testCompileWithWrongElementTypeInLayout(self):
       c = self._NewComputation()
-      c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
+      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
       ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
-      c.ClearOpMetadata()
+      c.clear_op_metadata()

       options = xla_client.CompileOptions()
       options.argument_layouts = [
@@ -1821,7 +1820,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ]

       def TestFun():
-        return self.backend.compile(c.Build(), compile_options=options)
+        return self.backend.compile(c.build(), compile_options=options)

       self.assertRaisesRegex(
           RuntimeError, r".*Invalid argument shape.*"
@@ -1829,13 +1828,13 @@ def TestFactory(xla_backend, cloud_tpu=False):

     def testInvokeWithWrongElementType(self):
       c = self._NewComputation()
-      c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
+      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
       ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
-      c.ClearOpMetadata()
+      c.clear_op_metadata()

       def TestFun():
         return xla_client.execute_with_python_values(
-            self.backend.compile(c.Build()), [self.f32_scalar_2])
+            self.backend.compile(c.build()), [self.f32_scalar_2])

       self.assertRaisesRegex(
           RuntimeError, r"Invalid argument: Argument does not match.*"
@@ -1853,7 +1852,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ops.Add(result, ops.Constant(c, np.float32(1.618)))

       arg = NumpyArrayF32(1.0)
-      compiled_c = self.backend.compile(c.Build(result))
+      compiled_c = self.backend.compile(c.build(result))
       ans, = xla_client.execute_with_python_values(
           compiled_c, [arg], backend=self.backend)
       np.testing.assert_allclose(ans, 4.14)
@@ -1869,16 +1868,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
       sharding.type = sharding.type.REPLICATED
       sharding.tile_assignment_dimensions.extend([1])
       sharding.tile_assignment_devices.extend([0])
-      # Set Sharding.
-      c.SetSharding(sharding)
+      c.set_sharding(sharding)
       x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
-      # Clear Sharding.
-      c.ClearSharding()
+      c.clear_sharding()

       result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
       ops.Add(result, ops.Constant(c, np.float32(1.618)))
       arg = NumpyArrayF32(1.0)
-      compiled_c = self.backend.compile(c.Build(result))
+      compiled_c = self.backend.compile(c.build(result))
       ans, = xla_client.execute_with_python_values(
           compiled_c, [arg], backend=self.backend)
       np.testing.assert_allclose(ans, 4.14)
@@ -1898,8 +1895,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
           xla_client.shape_from_pyval(
               NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent())
       out = ops.Add(p1, p2)
-      c.SetUpAlias([], 0, [])
-      c = c.Build(out)
+      c.setup_alias([], 0, [])
+      c = c.build(out)
       if self.backend.platform != "tpu":
         with self.assertRaisesRegex(
             RuntimeError, "Buffer aliasing is not supported "
@@ -1940,21 +1937,22 @@ def TestFactory(xla_backend, cloud_tpu=False):
       } for dtype in dlpack_dtypes for shape in testcase_shapes)
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
-      dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+      buffer = self.backend.buffer_from_pyval(x)
+      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)
       del buffer  # Free "buffer" to make sure dlt retains ownership.
       self.assertEqual(type(dlt).__name__, "PyCapsule")
-      y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, self.backend.client)
+      y = xla_client._xla.dlpack_managed_tensor_to_buffer(
+          dlt, self.backend.client)
       np.testing.assert_array_equal(x, y.to_py())

     def testTensorsCanBeConsumedOnceOnly(self):
       x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
-      dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+      buffer = self.backend.buffer_from_pyval(x)
+      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)

       def ConsumeDLPackTensor():
-        _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt,
-                                                        self.backend.client)
+        _ = xla_client._xla.dlpack_managed_tensor_to_buffer(
+            dlt, self.backend.client)

       ConsumeDLPackTensor()
       self.assertRaisesRegex(
@@ -1981,7 +1979,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
       x_ptr = x.__array_interface__["data"][0]
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
+      buffer = self.backend.buffer_from_pyval(x)
       y = np.array(buffer, copy=False)
       y_ptr = y.__array_interface__["data"][0]
       np.testing.assert_array_equal(x, y)
@@ -1990,15 +1988,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
       self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
       self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())

-      buffer2 = xla_client.Buffer.from_pyval(
-          x, backend=self.backend, force_copy=True)
+      buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
       z = np.array(buffer2, copy=False)
       self.assertNotEqual(x.__array_interface__["data"][0],
                           z.__array_interface__["data"][0])

     def testDeleteWithActiveView(self):
       x = np.random.randn(20, 10)
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
+      buffer = self.backend.buffer_from_pyval(x)
       buffer_ptr = buffer.unsafe_buffer_pointer()
       y = np.array(buffer, copy=False)
       buffer.delete()