Add half support to CrossReplicaSum
PiperOrigin-RevId: 325264977 Change-Id: I74d8c5667f8c89fc0c78641ab5be3576f5855c3f
This commit is contained in:
parent
916e5e54be
commit
b3caff096f
@ -78,7 +78,7 @@ REGISTER_OP("CrossReplicaSum")
|
||||
.Input("input: T")
|
||||
.Input("group_assignment: int32")
|
||||
.Output("output: T")
|
||||
.Attr("T: {bfloat16, float, int32, uint32}")
|
||||
.Attr("T: {half, bfloat16, float, int32, uint32}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("CollectivePermute")
|
||||
|
Loading…
Reference in New Issue
Block a user