diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 6fe3447cfc8..5da2faacc6e 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -464,7 +464,7 @@ REGISTER_OP("MulNoNan") .Input("x: T") .Input("y: T") .Output("z: T") - .Attr("T: {half, float, double, complex64, complex128}") + .Attr("T: {bfloat16, half, float, double, complex64, complex128}") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); // Note: This op is not commutative w.r.t. to all its inputs.