Add half support to CrossReplicaSum

PiperOrigin-RevId: 325264977
Change-Id: I74d8c5667f8c89fc0c78641ab5be3576f5855c3f
This commit is contained in:
David Majnemer 2020-08-06 10:55:33 -07:00 committed by TensorFlower Gardener
parent 916e5e54be
commit b3caff096f

View File

@ -78,7 +78,7 @@ REGISTER_OP("CrossReplicaSum")
.Input("input: T")
.Input("group_assignment: int32")
.Output("output: T")
.Attr("T: {bfloat16, float, int32, uint32}")
.Attr("T: {half, bfloat16, float, int32, uint32}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("CollectivePermute")