Add some __eq__ functions to some objects that we can compare.

__hash__ is automatically implemented by pybind11 to raise an error.

PiperOrigin-RevId: 358921081
Change-Id: Ica8de6ba733bd31d59d12c96b267ae4abf0dce6a
This commit is contained in:
Jean-Baptiste Lespiau 2021-02-22 15:45:29 -08:00 committed by TensorFlower Gardener
parent 776812c173
commit 264312fe30
2 changed files with 14 additions and 2 deletions

View File

@ -346,7 +346,10 @@ void BuildPmapSubmodule(pybind11::module& m) {
std::vector<MeshDimAssignment>>(),
py::arg("sharding"), py::arg("mesh_mapping"))
.def_property_readonly("sharding", &ShardingSpec::GetSharding)
.def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping);
.def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping)
.def("__eq__", [](const ShardingSpec& self, const ShardingSpec& other) {
return self == other;
});
py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray");
sda.def(py::init<pybind11::handle, ShardingSpec, pybind11::list>())

View File

@ -49,7 +49,10 @@ namespace jax {
// The 3 following structs define how to shard one dimension of an ndarry.
//
// `NoSharding` (`None` in Python) means no sharding.
struct NoSharding {};
struct NoSharding {
bool operator==(const NoSharding& other) const { return true; }
bool operator!=(const NoSharding& other) const { return false; }
};
// `Chunked` means that the dimension is split into np.prod(chunks) chunks
// and the split dimension itself is preserved inside the map.
@ -125,6 +128,12 @@ class ShardingSpec {
return mesh_mapping_;
}
bool operator==(const ShardingSpec& other) const {
return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_;
}
bool operator!=(const ShardingSpec& other) const { return !(*this == other); }
private:
// `sharding` specifies how the array is supposed to get partitioned into
// chunks. Its length matchs the rank of the array. See the docstring