diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index cdf277581f4..cdfa30dd9a7 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -49,17 +49,18 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { return Status::OK(); } -/* static */ StatusOr DeviceAssignment::Deserialize( - const DeviceAssignmentProto& proto) { +/* static */ StatusOr> +DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); - DeviceAssignment assignment(proto.replica_count(), proto.computation_count()); + auto assignment = MakeUnique(proto.replica_count(), + proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); ++computation) { const auto& computation_device = proto.computation_devices(computation); TF_RET_CHECK(computation_device.replica_device_ids_size() == proto.replica_count()); for (int replica = 0; replica < proto.replica_count(); ++replica) { - assignment(replica, computation) = + (*assignment)(replica, computation) = computation_device.replica_device_ids(replica); } } diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 4d26d6bb85f..7d9abcd100d 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -49,7 +49,11 @@ class DeviceAssignment : public Array2D { // Protocol buffer serialization and deserialization. Status Serialize(DeviceAssignmentProto* proto) const; - static StatusOr Deserialize( + + // Return a std::unique_ptr instead of a DeviceAssignment + // directly because one of the supported TF platforms (mac) does not compile + // due to a StatusOr of an incomplete type (DeviceAssignment). + static StatusOr> Deserialize( const DeviceAssignmentProto& proto); };