diff --git a/tensorflow/compiler/xla/python/pmap_lib.cc b/tensorflow/compiler/xla/python/pmap_lib.cc
index e96ddc4f4e4..ee8cafd951a 100644
--- a/tensorflow/compiler/xla/python/pmap_lib.cc
+++ b/tensorflow/compiler/xla/python/pmap_lib.cc
@@ -346,7 +346,10 @@ void BuildPmapSubmodule(pybind11::module& m) {
                     std::vector<MeshDimAssignment>>(),
            py::arg("sharding"), py::arg("mesh_mapping"))
       .def_property_readonly("sharding", &ShardingSpec::GetSharding)
-      .def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping);
+      .def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping)
+      .def("__eq__", [](const ShardingSpec& self, const ShardingSpec& other) {
+        return self == other;
+      });
 
   py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray");
   sda.def(py::init<pybind11::handle, ShardingSpec, pybind11::list>())
diff --git a/tensorflow/compiler/xla/python/pmap_lib.h b/tensorflow/compiler/xla/python/pmap_lib.h
index a3f8fff16b4..1b5988c8f81 100644
--- a/tensorflow/compiler/xla/python/pmap_lib.h
+++ b/tensorflow/compiler/xla/python/pmap_lib.h
@@ -49,7 +49,10 @@ namespace jax {
 // The 3 following structs define how to shard one dimension of an ndarry.
 //
 // `NoSharding` (`None` in Python) means no sharding.
-struct NoSharding {};
+struct NoSharding {
+  bool operator==(const NoSharding& other) const { return true; }
+  bool operator!=(const NoSharding& other) const { return false; }
+};
 
 // `Chunked` means that the dimension is split into np.prod(chunks) chunks
 // and the split dimension itself is preserved inside the map.
@@ -125,6 +128,12 @@ class ShardingSpec {
     return mesh_mapping_;
   }
 
+  bool operator==(const ShardingSpec& other) const {
+    return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_;
+  }
+
+  bool operator!=(const ShardingSpec& other) const { return !(*this == other); }
+
  private:
   // `sharding` specifies how the array is supposed to get partitioned into
   // chunks. Its length matchs the rank of the array. See the docstring