Add some __eq__ functions to some objects that we can compare.
__hash__ is automatically implemented by pybind11 to raise an error. PiperOrigin-RevId: 358921081 Change-Id: Ica8de6ba733bd31d59d12c96b267ae4abf0dce6a
This commit is contained in:
parent
776812c173
commit
264312fe30
@ -346,7 +346,10 @@ void BuildPmapSubmodule(pybind11::module& m) {
|
||||
std::vector<MeshDimAssignment>>(),
|
||||
py::arg("sharding"), py::arg("mesh_mapping"))
|
||||
.def_property_readonly("sharding", &ShardingSpec::GetSharding)
|
||||
.def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping);
|
||||
.def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping)
|
||||
.def("__eq__", [](const ShardingSpec& self, const ShardingSpec& other) {
|
||||
return self == other;
|
||||
});
|
||||
|
||||
py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray");
|
||||
sda.def(py::init<pybind11::handle, ShardingSpec, pybind11::list>())
|
||||
|
||||
@ -49,7 +49,10 @@ namespace jax {
|
||||
// The 3 following structs define how to shard one dimension of an ndarry.
|
||||
//
|
||||
// `NoSharding` (`None` in Python) means no sharding.
|
||||
struct NoSharding {};
|
||||
struct NoSharding {
|
||||
bool operator==(const NoSharding& other) const { return true; }
|
||||
bool operator!=(const NoSharding& other) const { return false; }
|
||||
};
|
||||
|
||||
// `Chunked` means that the dimension is split into np.prod(chunks) chunks
|
||||
// and the split dimension itself is preserved inside the map.
|
||||
@ -125,6 +128,12 @@ class ShardingSpec {
|
||||
return mesh_mapping_;
|
||||
}
|
||||
|
||||
bool operator==(const ShardingSpec& other) const {
|
||||
return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_;
|
||||
}
|
||||
|
||||
bool operator!=(const ShardingSpec& other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
// `sharding` specifies how the array is supposed to get partitioned into
|
||||
// chunks. Its length matchs the rank of the array. See the docstring
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user