[XLA:Python] API cleanups.

Change method names on most classes to lower_case instead of CamelCase, matching the usual Python style, and rename a few methods to read better (e.g., GetHloText becomes as_hlo_text).

The old names aren't yet removed because callers (e.g., JAX) need to be updated.
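
As a rough sketch of the rename at call sites (method names are taken from the
diff below; "builder" and "executable" here are placeholder names for an
xla_client.XlaBuilder and a compiled executable, and the old spellings keep
working until callers migrate):

    # Old names (retained for now):
    computation = builder.Build()
    hlo_text = computation.GetHloText()
    outputs = executable.Execute(arguments)

    # New names:
    computation = builder.build()
    hlo_text = computation.as_hlo_text()
    outputs = executable.execute(arguments)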

PiperOrigin-RevId: 308677107
Change-Id: I46993f093676402853e6032a822b308f59677e6f
Author: Peter Hawkins, 2020-04-27 12:39:21 -07:00 (committed by TensorFlower Gardener)
parent 71964116c5
commit 40d89f69e1
5 changed files with 229 additions and 110 deletions


@@ -98,7 +98,7 @@ class TpuBackend(xla_client.Backend):
     options.debug_options.xla_cpu_fast_math_honor_division = True
     options.debug_options.xla_cpu_fast_math_honor_functions = True
     options.debug_options.xla_gpu_enable_fast_min_max = False
-    return _tpu_client.TpuExecutable.Compile(c_computation,
+    return _tpu_client.TpuExecutable.compile(c_computation,
                                              compile_options.argument_layouts,
                                              options, self.client,
                                              compile_options.device_assignment,
@@ -106,14 +106,8 @@ class TpuBackend(xla_client.Backend):
   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
-      return self.client.GetDefaultDeviceAssignment(num_replicas,
-                                                    num_partitions)
+      return self.client.get_default_device_assignment(num_replicas,
+                                                       num_partitions)
     else:
       # TODO(henrytan): delete this case after all callers can handle 2D output
-      return self.client.GetDefaultDeviceAssignment(num_replicas)
+      return self.client.get_default_device_assignment(num_replicas)
-
-  def serialize(self, executable):
-    return self.client.SerializeExecutable(executable)
-
-  def deserialize(self, serialized_executable):
-    return self.client.DeserializeExecutable(serialized_executable, self.client)


@@ -37,7 +37,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("devices", &PyTpuClient::devices)
       .def("local_devices", &PyTpuClient::local_devices)
       .def("host_id", &PyTpuClient::host_id)
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](PyTpuClient* client, int num_replicas, int num_partitions)
               -> StatusOr<std::vector<std::vector<std::shared_ptr<Device>>>> {
             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
@@ -57,7 +57,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             return result;
           })
       // TODO(skye): delete after all callers can handle 2D output
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](PyTpuClient* client, int num_replicas)
               -> StatusOr<std::vector<std::shared_ptr<Device>>> {
             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
@@ -72,14 +72,14 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             }
             return result;
           })
-      .def("TransferToInfeed",
+      .def("transfer_to_infeed",
           [](PyTpuClient* client, const LiteralSlice& literal,
              int device_ordinal) {
             GlobalPyRefManager()->CollectGarbage();
             py::gil_scoped_release gil_release;
             return client->TransferToInfeed(literal, device_ordinal);
           })
-      .def("TransferFromOutfeed",
+      .def("transfer_from_outfeed",
           [](PyTpuClient* client, const Shape& shape,
              int device_ordinal) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
@@ -159,9 +159,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       });

   py::class_<PyTpuExecutable>(m, "TpuExecutable")
-      .def_static("Compile", &PyTpuExecutable::Compile,
+      .def_static("compile", &PyTpuExecutable::Compile,
                   py::call_guard<py::gil_scoped_release>())
-      .def_static("Compile",
+      .def_static("compile",
                   [](const XlaComputation& computation,
                      absl::optional<std::vector<Shape>> argument_layouts,
                      const ExecutableBuildOptions* build_options,
@@ -184,12 +184,17 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("local_logical_device_ids",
            &PyTpuExecutable::local_logical_device_ids)
       .def("local_devices", &PyTpuExecutable::local_devices)
-      .def("SizeOfGeneratedCodeInBytes",
+      .def("size_of_generated_code_in_bytes",
            &PyTpuExecutable::SizeOfGeneratedCodeInBytes)
       .def("Delete", &PyTpuExecutable::Delete)
       .def("Execute", &PyTpuExecutable::Execute,
            py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
       .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices,
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
+      .def("delete", &PyTpuExecutable::Delete)
+      .def("execute", &PyTpuExecutable::Execute,
+           py::call_guard<py::gil_scoped_release>(), py::arg("arguments"))
+      .def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices,
            py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));

   py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")


@@ -676,7 +676,7 @@ void BuildProfilerSubmodule(py::module* m) {
   traceme_class.def(py::init<py::str, py::kwargs>())
       .def("__enter__", &TraceMeContextManager::Enter)
       .def("__exit__", &TraceMeContextManager::Exit)
-      .def_static("IsEnabled", &TraceMeContextManager::IsEnabled);
+      .def_static("is_enabled", &TraceMeContextManager::IsEnabled);
 }
 }  // namespace
@@ -880,6 +880,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def_property_readonly("platform", &Device::platform_name)
       .def_property_readonly("device_kind", &Device::device_kind)
       .def("__str__", &Device::DebugString)
+      // TODO(phawkins): remove capitalized names after updating callers.
       .def("TransferToInfeed",
           [](const Device& device, const LiteralSlice& literal) {
             GlobalPyRefManager()->CollectGarbage();
@@ -891,6 +892,33 @@ PYBIND11_MODULE(xla_extension, m) {
           })
       .def(
           "TransferFromOutfeed",
+          [](const Device& device, const Shape& shape) -> StatusOr<py::object> {
+            GlobalPyRefManager()->CollectGarbage();
+            std::shared_ptr<Literal> literal_shared;
+            {
+              py::gil_scoped_release gil_release;
+              TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                                  device.GetLocalDeviceState());
+              TF_ASSIGN_OR_RETURN(
+                  Literal literal,
+                  local_device->client()->TransferFromOutfeedLocal(
+                      shape, local_device->device_ordinal()));
+              literal_shared = std::make_shared<Literal>(std::move(literal));
+            }
+            return LiteralToPython(std::move(literal_shared));
+          })
+      .def("transfer_to_infeed",
+          [](const Device& device, const LiteralSlice& literal) {
+            GlobalPyRefManager()->CollectGarbage();
+            py::gil_scoped_release gil_release;
+            TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                                device.GetLocalDeviceState());
+            return local_device->client()->TransferToInfeedLocal(
+                literal, local_device->device_ordinal());
+          })
+      .def(
+          "transfer_from_outfeed",
           [](const Device& device, const Shape& shape) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
             std::shared_ptr<Literal> literal_shared;
@@ -921,7 +949,7 @@ PYBIND11_MODULE(xla_extension, m) {
   // Local XLA client methods.

   // Custom-call targets.
-  m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget);
+  m.def("register_custom_call_target", &PyRegisterCustomCallTarget);

   py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig");
   alloc_config.def(py::init<>())
@@ -955,7 +983,7 @@ PYBIND11_MODULE(xla_extension, m) {
             return devices;
           })
       .def("host_id", &PyLocalClient::host_id)
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](std::shared_ptr<PyLocalClient> client, int num_replicas,
              int num_partitions)
               -> StatusOr<std::vector<std::vector<ClientAndPtr<Device>>>> {
@@ -976,7 +1004,7 @@ PYBIND11_MODULE(xla_extension, m) {
             return result;
           })
       // TODO(skye): delete after all callers can handle 2D output
-      .def("GetDefaultDeviceAssignment",
+      .def("get_default_device_assignment",
           [](std::shared_ptr<PyLocalClient> client,
              int num_replicas) -> StatusOr<std::vector<ClientAndPtr<Device>>> {
             TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
@@ -991,15 +1019,15 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return result;
           })
-      .def("CreateChannelHandle",
+      .def("create_channel_handle",
           [](PyLocalClient* client) {
             return client->client()->CreateChannelHandle();
           })
-      .def("CreateDeviceToHostChannelHandle",
+      .def("create_device_to_host_channel_handle",
           [](PyLocalClient* client) {
             return client->client()->CreateDeviceToHostChannelHandle();
           })
-      .def("CreateHostToDeviceChannelHandle", [](PyLocalClient* client) {
+      .def("create_host_to_device_channel_handle", [](PyLocalClient* client) {
         return client->client()->CreateHostToDeviceChannelHandle();
       });
@@ -1119,7 +1147,7 @@ PYBIND11_MODULE(xla_extension, m) {
   py::class_<PyLocalExecutable, ClientAndUniquePtr<PyLocalExecutable>>
       executable(m, "LocalExecutable");
   executable
-      .def_static("Compile",
+      .def_static("compile",
                   [](const XlaComputation& computation,
                      absl::optional<std::vector<Shape>> argument_layouts,
                      const ExecutableBuildOptions* build_options,
@@ -1146,7 +1174,7 @@ PYBIND11_MODULE(xla_extension, m) {
                     return WrapWithClient(std::move(client),
                                           std::move(executable));
                   })
-      .def_static("Compile",
+      .def_static("compile",
                   [](const XlaComputation& computation,
                      absl::optional<std::vector<Shape>> argument_layouts,
                      const ExecutableBuildOptions* build_options,
@@ -1189,8 +1217,10 @@ PYBIND11_MODULE(xla_extension, m) {
         }
         return devices;
       })
-      .def("SizeOfGeneratedCodeInBytes",
+      .def("size_of_generated_code_in_bytes",
           &PyLocalExecutable::SizeOfGeneratedCodeInBytes)
+      .def("delete", &PyLocalExecutable::Delete)
+      // TODO(phawkins): delete capitalized methods after updating callers.
       .def("Delete", &PyLocalExecutable::Delete)
       .def(
           "Execute",
@@ -1212,6 +1242,27 @@ PYBIND11_MODULE(xla_extension, m) {
            return outputs;
          },
          py::arg("arguments"))
+      .def(
+          "execute",
+          [](const PyLocalExecutable& executable,
+             absl::Span<PyLocalBuffer* const> args)
+              -> StatusOr<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> {
+            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.untuple_result = true;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::unique_ptr<PyLocalBuffer>> output_buffers,
+                executable.Execute(args, options));
+            std::vector<ClientAndUniquePtr<PyLocalBuffer>> outputs;
+            outputs.reserve(output_buffers.size());
+            for (auto& buffer : output_buffers) {
+              outputs.push_back(WrapWithClient(
+                  executable.client()->shared_from_this(), std::move(buffer)));
+            }
+            return outputs;
+          },
+          py::arg("arguments"))
+      // TODO(phawkins): delete capitalized methods after updating callers.
       .def(
           "ExecuteOnLocalDevices",
           [](const PyLocalExecutable& executable,
@@ -1239,7 +1290,33 @@ PYBIND11_MODULE(xla_extension, m) {
          },
          py::arg("arguments"))
       .def(
-          "get_hlo_modules",
+          "execute_on_local_devices",
+          [](const PyLocalExecutable& executable,
+             absl::Span<const std::vector<PyLocalBuffer*>> args)
+              -> StatusOr<
+                  std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>>> {
+            py::gil_scoped_release gil_release;
+            ExecuteOptions options;
+            options.untuple_result = true;
+            TF_ASSIGN_OR_RETURN(
+                std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>
+                    output_buffers,
+                executable.ExecuteOnLocalDevices(args, options));
+            std::vector<std::vector<ClientAndUniquePtr<PyLocalBuffer>>> outputs;
+            outputs.resize(output_buffers.size());
+            for (int computation = 0; computation < output_buffers.size();
+                 ++computation) {
+              for (auto& buffer : output_buffers[computation]) {
+                outputs[computation].push_back(
+                    WrapWithClient(executable.client()->shared_from_this(),
+                                   std::move(buffer)));
+              }
+            }
+            return outputs;
+          },
+          py::arg("arguments"))
+      .def(
+          "hlo_modules",
           [](const PyLocalExecutable& executable)
               -> StatusOr<std::vector<std::shared_ptr<HloModule>>> {
             std::vector<std::shared_ptr<HloModule>> modules;
@@ -1298,12 +1375,19 @@ PYBIND11_MODULE(xla_extension, m) {
         proto.ParseFromString(serialized_hlo_module_proto);
         return absl::make_unique<XlaComputation>(proto);
       }))
+      // TODO(phawkins): delete capitalized names after updating callers.
       .def("GetProgramShape", &XlaComputation::GetProgramShape)
       .def("GetSerializedProto", &GetComputationSerializedProto)
       .def("GetHloText", &GetComputationHloText)
       .def("GetHloDotGraph", &GetComputationHloDotGraph)
       .def("Hash", &HashComputation)
-      .def("get_hlo_module", &GetHloModule);
+      .def("get_hlo_module", &GetHloModule)
+      .def("program_shape", &XlaComputation::GetProgramShape)
+      .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto)
+      .def("as_hlo_text", &GetComputationHloText)
+      .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
+      .def("hash", &HashComputation)
+      .def("as_hlo_module", &GetHloModule);

   py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
   hlo_print_options_class.def(py::init<>())
@@ -1381,6 +1465,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> {
         return absl::make_unique<XlaBuilder>(UniquifyName(name));
       }))
+      // TODO(phawkins): delete capitalized names after updating callers.
       .def(
           "Build",
           [](XlaBuilder& builder, absl::optional<XlaOp> root) {
@@ -1403,6 +1488,35 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("SetSharding", &XlaBuilder::SetSharding)
       .def("ClearSharding", &XlaBuilder::ClearSharding)
       .def("SetUpAlias",
+          [](XlaBuilder& builder, const std::vector<int64>& output_index,
+             int64 param_number, const std::vector<int64>& param_index) {
+            builder.SetUpAlias(
+                ShapeIndex(output_index.begin(), output_index.end()),
+                param_number,
+                ShapeIndex(param_index.begin(), param_index.end()));
+          })
+      .def(
+          "build",
+          [](XlaBuilder& builder, absl::optional<XlaOp> root) {
+            return root ? builder.Build(*root) : builder.Build();
+          },
+          "Builds a computation from the contents of the builder.",
+          py::arg("root") = absl::nullopt)
+      .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata)
+      .def("get_shape", &XlaBuilder::GetShape)
+      .def(
+          "get_program_shape",
+          [](const XlaBuilder& builder,
+             absl::optional<XlaOp> root) -> StatusOr<ProgramShape> {
+            return root ? builder.GetProgramShape(*root)
+                        : builder.GetProgramShape();
+          },
+          py::arg("root") = absl::nullopt)
+      .def("is_constant", &XlaBuilder::IsConstant)
+      .def("set_op_metadata", &XlaBuilder::SetOpMetadata)
+      .def("set_sharding", &XlaBuilder::SetSharding)
+      .def("clear_sharding", &XlaBuilder::ClearSharding)
+      .def("setup_alias",
           [](XlaBuilder& builder, const std::vector<int64>& output_index,
              int64 param_number, const std::vector<int64>& param_index) {
             builder.SetUpAlias(
@@ -1411,7 +1525,9 @@ PYBIND11_MODULE(xla_extension, m) {
                 ShapeIndex(param_index.begin(), param_index.end()));
           });

+  // TODO(phawkins): delete capitalized names after updating callers
   m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor);
+  m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor);
   m.def("DLPackManagedTensorToBuffer",
         [](const py::capsule& tensor, std::shared_ptr<PyLocalClient> client)
             -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
@@ -1420,6 +1536,14 @@ PYBIND11_MODULE(xla_extension, m) {
               DLPackManagedTensorToBuffer(tensor, client.get()));
           return WrapWithClient(std::move(client), std::move(buffer));
         });
+  m.def("dlpack_managed_tensor_to_buffer",
+        [](const py::capsule& tensor, std::shared_ptr<PyLocalClient> client)
+            -> StatusOr<ClientAndUniquePtr<PyLocalBuffer>> {
+          TF_ASSIGN_OR_RETURN(
+              std::unique_ptr<PyLocalBuffer> buffer,
+              DLPackManagedTensorToBuffer(tensor, client.get()));
+          return WrapWithClient(std::move(client), std::move(buffer));
+        });

   py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
       .value("DEFAULT", PrecisionConfig::DEFAULT)


@@ -147,7 +147,7 @@ class LocalBackend(Backend):
     options.debug_options.xla_cpu_fast_math_honor_division = True
     options.debug_options.xla_cpu_fast_math_honor_functions = True
     options.debug_options.xla_gpu_enable_fast_min_max = False
-    return _xla.LocalExecutable.Compile(c_computation,
+    return _xla.LocalExecutable.compile(c_computation,
                                         compile_options.argument_layouts,
                                         options, self.client,
                                         compile_options.device_assignment,
@@ -155,11 +155,11 @@ class LocalBackend(Backend):
   def get_default_device_assignment(self, num_replicas, num_partitions=None):
     if num_partitions is not None:
-      return self.client.GetDefaultDeviceAssignment(num_replicas,
-                                                    num_partitions)
+      return self.client.get_default_device_assignment(num_replicas,
+                                                       num_partitions)
     else:
       # TODO(skye): delete this case after all callers can handle 2D output
-      return self.client.GetDefaultDeviceAssignment(num_replicas)
+      return self.client.get_default_device_assignment(num_replicas)

 xla_platform_names = {
@@ -445,7 +445,7 @@ def transfer_to_infeed(value, device=None):
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
   device = device or backend.local_devices()[0]
-  device.TransferToInfeed(value)
+  device.transfer_to_infeed(value)

 def transfer_from_outfeed(shape, device=None):
@@ -462,7 +462,7 @@ def transfer_from_outfeed(shape, device=None):
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
   device = device or backend.local_devices()[0]
-  return device.TransferFromOutfeed(
+  return device.transfer_from_outfeed(
       shape.with_major_to_minor_layout_if_absent())
@@ -542,8 +542,7 @@ def execute_with_python_values(executable, arguments=(), backend=None):
   backend = backend or get_local_backend()

   def put(arg):
-    return Buffer.from_pyval(
-        arg, device=executable.local_devices()[0], backend=backend)
+    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

   arguments = [put(arg) for arg in arguments]
   outputs = executable.Execute(arguments)
@@ -629,7 +628,7 @@ def register_custom_call_target(name, fn, platform='cpu'):
     fn: a PyCapsule object containing the function pointer.
     platform: the target platform.
   """
-  _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])
+  _xla.register_custom_call_target(name, fn, xla_platform_names[platform])

 # Deprecated. Use register_custom_call_target instead.


@@ -82,7 +82,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       return xla_client.XlaBuilder(name)

     def _Execute(self, c, arguments):
-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       return xla_client.execute_with_python_values(
           compiled_c, arguments, backend=self.backend)
@@ -136,34 +136,34 @@ def TestFactory(xla_backend, cloud_tpu=False):
           builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
       x = ops.Mul(p0, p1)
       ops.Add(x, x)
-      return builder.Build()
+      return builder.build()

     def testComputationToHloText(self):
       computation = self.ExampleComputation()
-      hlo_text = computation.GetHloText()
+      hlo_text = computation.as_hlo_text()
       self.assertTrue(hlo_text.startswith("HloModule acomputation"))

     def testComputationToHloGraph(self):
       computation = self.ExampleComputation()
-      hlo_dot_graph = computation.GetHloDotGraph()
+      hlo_dot_graph = computation.as_hlo_dot_graph()
       self.assertTrue(hlo_dot_graph.startswith("digraph "))

     def testHloModuleToHloText(self):
       computation = self.ExampleComputation()
-      hlo_text = computation.get_hlo_module().to_string()
+      hlo_text = computation.as_hlo_module().to_string()
       self.assertTrue(hlo_text.startswith("HloModule acomputation"))

     def testHloModuleToHloGraph(self):
       computation = self.ExampleComputation()
       hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph(
-          computation.get_hlo_module())
+          computation.as_hlo_module())
       self.assertTrue(hlo_dot_graph.startswith("digraph "))

     @unittest.skipIf(cloud_tpu, "not implemented")
     def testCompiledHloModuleToHloText(self):
       computation = self.ExampleComputation()
       executable = self.backend.compile(computation)
-      hlo_modules = executable.get_hlo_modules()
+      hlo_modules = executable.hlo_modules()
       self.assertLen(hlo_modules, 1)
       hlo_text = hlo_modules[0].to_string()
       self.assertTrue(hlo_text.startswith("HloModule acomputation"))
@@ -180,7 +180,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       p1 = ops.Parameter(
           builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
       ops.Mul(p0, p1)
-      computation0 = builder0.Build()
+      computation0 = builder0.build()

       builder1 = xla_client.XlaBuilder("computation1")
       p0 = ops.Parameter(builder1, 0,
@@ -188,9 +188,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
       p1 = ops.Parameter(
           builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
       ops.Mul(p0, p1)
-      computation1 = builder1.Build()
+      computation1 = builder1.build()

-      self.assertEqual(computation0.Hash(), computation1.Hash())
+      self.assertEqual(computation0.hash(), computation1.hash())

   tests.append(ComputationHashTest)
@@ -396,7 +396,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       # Build the HLO proto
       b = xla_client.XlaBuilder("computation")
       ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
-      serialized_proto = b.Build().GetSerializedProto()
+      serialized_proto = b.build().as_serialized_hlo_module_proto()

       # Load and execute the proto
       c = xla_client.XlaComputation(serialized_proto)
@@ -478,22 +478,22 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
           ops.Constant(c, np.float32(3.14)))
       arg = NumpyArrayF32(1.11)
-      compiled_c = self.backend.compile(c.Build())
-      arg_buffer = xla_client.Buffer.from_pyval(arg, backend=self.backend)
+      compiled_c = self.backend.compile(c.build())
+      arg_buffer = self.backend.buffer_from_pyval(arg)
       arg_buffer.delete()
       with self.assertRaises(RuntimeError):
-        compiled_c.Execute([arg_buffer])
+        compiled_c.execute([arg_buffer])

     def testShape(self):
       pyval = np.array([[1., 2.]], np.float32)
-      local_buffer = xla_client.Buffer.from_pyval(pyval)
+      local_buffer = self.backend.buffer_from_pyval(pyval)
       xla_shape = local_buffer.shape()
       self.assertEqual(xla_shape.dimensions(), (1, 2))
       self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

     def testBlockHostUntilReadyWorks(self):
       arg = np.array([[1., 2.]], np.float32)
-      arg_buffer = xla_client.Buffer.from_pyval(arg)
+      arg_buffer = self.backend.buffer_from_pyval(arg)
       arg_buffer.block_host_until_ready()
       # This test merely checks that nothing goes awry when we call
       # block_host_until_ready(); it's difficult to test anything else.
@@ -501,8 +501,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testCopyToHost(self):
       arg0 = np.array([[1., 2.]], np.float32)
       arg1 = np.array([[3., 4.]], np.float32)
-      arg0_buffer = xla_client.Buffer.from_pyval(arg0)
-      arg1_buffer = xla_client.Buffer.from_pyval(arg1)
+      arg0_buffer = self.backend.buffer_from_pyval(arg0)
+      arg1_buffer = self.backend.buffer_from_pyval(arg1)
       # Prefetch two buffers using copy_to_host_async, and then retrieve their
       # values using to_py.
       arg0_buffer.copy_to_host_async()
@@ -517,8 +517,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testDevice(self):
       x = np.arange(8, dtype=np.int32)
       for device in self.backend.local_devices():
-        buf = xla_client.Buffer.from_pyval(
-            x, device=device, backend=self.backend)
+        buf = self.backend.buffer_from_pyval(x, device=device)
         self.assertEqual(buf.device(), device)
         np.testing.assert_equal(x, buf.to_py())
@@ -564,7 +563,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)

       self.assertLen(result, 1)
       expected = np.array(x, dtype=dst_dtype)
@@ -591,7 +590,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)

       self.assertLen(result, 1)
       expected = x.view(dst_dtype)
@@ -1127,7 +1126,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.Constant(c, NumpyArrayBool([True, False, False, True]))
       ])
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       self.assertLen(result, 3)
       np.testing.assert_equal(result[0], 42)
       np.testing.assert_allclose(result[1], [1.0, 2.0])
@@ -1166,7 +1165,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                              shape))
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       # since the result is random, we just check shape and uniqueness
       self.assertLen(result, 1)
       self.assertEqual(result[0].shape, shape)
@@ -1182,7 +1181,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                              shape))
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       # since the result is random, we just check shape, uniqueness, and range
       self.assertLen(result, 1)
       self.assertEqual(result[0].shape, shape)
@@ -1200,7 +1199,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
                                              shape))
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       # since the result is random, we just check shape, integrality, and range
       self.assertLen(result, 1)
       self.assertEqual(result[0].shape, shape)
@@ -1229,7 +1228,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       c = self._NewComputation()
       ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()), backend=self.backend)
+          self.backend.compile(c.build()), backend=self.backend)
       self.assertLen(result, 2)
       np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
       np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
@@ -1241,7 +1240,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
       q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
       ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
-      comparator = b.Build()
+      comparator = b.build()

       keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
       values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
@@ -1251,7 +1250,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
           dimension=1,
           comparator=comparator)
       result = xla_client.execute_with_python_values(
-          self.backend.compile(c.Build()))
+          self.backend.compile(c.build()))
       self.assertLen(result, 2)
       np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
       np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
@@ -1321,8 +1320,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
       x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
       const_expr = ops.Sub(b, a)
       non_const_expr = ops.Mul(const_expr, x)
-      self.assertTrue(c.IsConstant(const_expr))
-      self.assertFalse(c.IsConstant(non_const_expr))
+      self.assertTrue(c.is_constant(const_expr))
+      self.assertFalse(c.is_constant(non_const_expr))

     def testGather(self):
       a = np.arange(9).astype(np.int32).reshape((3, 3))
@@ -1412,7 +1411,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ops.Parameter(c, 0,
                     xla_client.shape_from_pyval(np.array(0, dtype=in_dtype)))
       ops.Constant(c, out_dtype(1))
-      return c.Build()
+      return c.build()

     def _CreateMulBy2Computation(self, dtype):
       """Computation (dtype) -> dtype that multiplies its parameter by 2."""
@@ -1423,7 +1422,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
              xla_client.shape_from_pyval(np.array(
                  0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
          ops.Constant(c, dtype(2.0)))
-      return c.Build()
+      return c.build()

     def _CreateMulF32ByParamComputation(self):
       """Computation (f32) -> f32 that multiplies one parameter by the other."""
@@ -1431,7 +1430,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ops.Mul(
           ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
           ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
-      return c.Build()
+      return c.build()

     def _CreateBinaryAddComputation(self, dtype):
       """Computation (dtype, dtype) -> dtype that adds its two parameters."""
@@ -1439,7 +1438,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
       shape = shape.with_major_to_minor_layout_if_absent()
       ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-      return c.Build()
+      return c.build()

     def _CreateBinaryGeComputation(self, dtype):
       """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
@@ -1447,7 +1446,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
       shape = shape.with_major_to_minor_layout_if_absent()
       ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-      return c.Build()
+      return c.build()

     def _MakeSample3DArray(self, dtype):
       return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
@@ -1516,7 +1515,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
         c = self._NewComputation("div_param0_by_param1")
         shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
         ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
-        return c.Build()
+        return c.build()

       c = self._NewComputation()
       ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
@@ -1539,7 +1538,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       window_strides = (1, 2)
       padding = xla_client.window_padding_type_to_pad_values(
           xla_client.PaddingType.VALID,
-          c.GetShape(operand).dimensions(), window_dimensions, window_strides)
+          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
       ops.SelectAndScatterWithGeneralPadding(
           operand,
           select=self._CreateBinaryGeComputation(dtype),
@@ -1686,7 +1685,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
         c = self._NewComputation("test_lt_10")
         shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
         ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
-        return c.Build()
+        return c.build()

       cond = LessThan10Cond()
       body = self._CreateMulBy2Computation(dtype)
@@ -1728,10 +1727,10 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.CreateToken(c),
           xla_client.shape_from_pyval(
               to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       device = self.backend.local_devices()[0]
       for item in to_infeed:
-        xla_client.transfer_to_infeed(item, device=device)
+        device.transfer_to_infeed(item)

       for item in to_infeed:
         result, = xla_client.execute_with_python_values(
@@ -1747,9 +1746,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
           ops.CreateToken(c),
           xla_client.shape_from_pyval(
               to_infeed).with_major_to_minor_layout_if_absent()), 0)
-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       device = self.backend.local_devices()[0]
-      xla_client.transfer_to_infeed(to_infeed, device=device)
+      device.transfer_to_infeed(to_infeed)

       result = xla_client.execute_with_python_values(
           compiled_c, backend=self.backend)
@@ -1771,14 +1770,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
           to_round_trip[0]).with_major_to_minor_layout_if_absent()
       ops.OutfeedWithToken(x, token, outfeed_shape)

-      compiled_c = self.backend.compile(c.Build())
+      compiled_c = self.backend.compile(c.build())
       device = self.backend.local_devices()[0]

       for want in to_round_trip:
-        execution = threading.Thread(target=lambda: compiled_c.Execute([]))
+        execution = threading.Thread(target=lambda: compiled_c.execute([]))
         execution.start()
-        xla_client.transfer_to_infeed(want, device=device)
-        got = xla_client.transfer_from_outfeed(outfeed_shape, device=device)
+        device.transfer_to_infeed(want)
+        got = device.transfer_from_outfeed(outfeed_shape)
         execution.join()
         self.assertEqual(want, got)
@@ -1811,9 +1810,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testCompileWithWrongElementTypeInLayout(self):
       c = self._NewComputation()
-      c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
+      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
       ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
-      c.ClearOpMetadata()
+      c.clear_op_metadata()

       options = xla_client.CompileOptions()
       options.argument_layouts = [
@@ -1821,7 +1820,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ]

       def TestFun():
-        return self.backend.compile(c.Build(), compile_options=options)
+        return self.backend.compile(c.build(), compile_options=options)

       self.assertRaisesRegex(
           RuntimeError, r".*Invalid argument shape.*"
@@ -1829,13 +1828,13 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testInvokeWithWrongElementType(self):
       c = self._NewComputation()
-      c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
+      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
       ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
-      c.ClearOpMetadata()
+      c.clear_op_metadata()

       def TestFun():
         return xla_client.execute_with_python_values(
-            self.backend.compile(c.Build()), [self.f32_scalar_2])
+            self.backend.compile(c.build()), [self.f32_scalar_2])

       self.assertRaisesRegex(
           RuntimeError, r"Invalid argument: Argument does not match.*"
@@ -1853,7 +1852,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
       ops.Add(result, ops.Constant(c, np.float32(1.618)))

       arg = NumpyArrayF32(1.0)
-      compiled_c = self.backend.compile(c.Build(result))
+      compiled_c = self.backend.compile(c.build(result))
       ans, = xla_client.execute_with_python_values(
           compiled_c, [arg], backend=self.backend)
       np.testing.assert_allclose(ans, 4.14)
@@ -1869,16 +1868,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
       sharding.type = sharding.type.REPLICATED
       sharding.tile_assignment_dimensions.extend([1])
       sharding.tile_assignment_devices.extend([0])
-      # Set Sharding.
-      c.SetSharding(sharding)
+      c.set_sharding(sharding)
       x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
-      # Clear Sharding.
-      c.ClearSharding()
+      c.clear_sharding()

       result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
       ops.Add(result, ops.Constant(c, np.float32(1.618)))
       arg = NumpyArrayF32(1.0)
-      compiled_c = self.backend.compile(c.Build(result))
+      compiled_c = self.backend.compile(c.build(result))
       ans, = xla_client.execute_with_python_values(
           compiled_c, [arg], backend=self.backend)
       np.testing.assert_allclose(ans, 4.14)
@@ -1898,8 +1895,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
           xla_client.shape_from_pyval(
               NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent())
       out = ops.Add(p1, p2)
-      c.SetUpAlias([], 0, [])
-      c = c.Build(out)
+      c.setup_alias([], 0, [])
+      c = c.build(out)
       if self.backend.platform != "tpu":
         with self.assertRaisesRegex(
             RuntimeError, "Buffer aliasing is not supported "
@@ -1940,21 +1937,22 @@ def TestFactory(xla_backend, cloud_tpu=False):
       } for dtype in dlpack_dtypes for shape in testcase_shapes)
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
-      dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+      buffer = self.backend.buffer_from_pyval(x)
+      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)
       del buffer  # Free "buffer" to make sure dlt retains ownership.
       self.assertEqual(type(dlt).__name__, "PyCapsule")
-      y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, self.backend.client)
+      y = xla_client._xla.dlpack_managed_tensor_to_buffer(
+          dlt, self.backend.client)
       np.testing.assert_array_equal(x, y.to_py())

     def testTensorsCanBeConsumedOnceOnly(self):
       x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
-      dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer)
+      buffer = self.backend.buffer_from_pyval(x)
+      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer)

       def ConsumeDLPackTensor():
-        _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt,
-                                                        self.backend.client)
+        _ = xla_client._xla.dlpack_managed_tensor_to_buffer(
+            dlt, self.backend.client)

       ConsumeDLPackTensor()
       self.assertRaisesRegex(
@@ -1981,7 +1979,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
       x_ptr = x.__array_interface__["data"][0]
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
+      buffer = self.backend.buffer_from_pyval(x)
       y = np.array(buffer, copy=False)
       y_ptr = y.__array_interface__["data"][0]
       np.testing.assert_array_equal(x, y)
@@ -1990,15 +1988,14 @@ def TestFactory(xla_backend, cloud_tpu=False):
       self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
       self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())

-      buffer2 = xla_client.Buffer.from_pyval(
-          x, backend=self.backend, force_copy=True)
+      buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
       z = np.array(buffer2, copy=False)
       self.assertNotEqual(x.__array_interface__["data"][0],
                           z.__array_interface__["data"][0])

     def testDeleteWithActiveView(self):
       x = np.random.randn(20, 10)
-      buffer = xla_client.Buffer.from_pyval(x, backend=self.backend)
+      buffer = self.backend.buffer_from_pyval(x)
       buffer_ptr = buffer.unsafe_buffer_pointer()
       y = np.array(buffer, copy=False)
       buffer.delete()