diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7e8d3d7002a..b461aa43153 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -186,7 +186,7 @@ class StatelessCategoricalOp : public CategoricalOp { REGISTER_XLA_OP(Name("StatelessMultinomial") .CompileTimeConstantInput("num_samples") - .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("T", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessCategoricalOp);