Check group_size in collectives is a positive integer.

PiperOrigin-RevId: 274866052
Change-Id: I99f130e27e8ff3e126aaf8d382519cef6bb0d41e
This commit is contained in:
Ayush Dubey 2019-10-15 12:46:20 -07:00 committed by TensorFlower Gardener
parent fc9315df4d
commit 0330d34590

View File

@ -74,6 +74,10 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
: CollectiveOpKernel(c) {
col_params_.instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
@ -150,6 +154,10 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
: CollectiveOpKernel(c) {
col_params_.instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
@ -267,6 +275,10 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
: CollectiveOpKernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
@ -339,6 +351,10 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
: CollectiveOpKernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));