From 221c2d5d5391606a42abee95b39b9814a8af802e Mon Sep 17 00:00:00 2001 From: Harry Zhang Date: Sun, 7 Feb 2021 16:46:52 -0800 Subject: [PATCH] Add int dtypes to TF XLA bridge for matmul ops PiperOrigin-RevId: 356165547 Change-Id: Ib58bf831d003a4376938b8e2333716778fdbcc7c --- tensorflow/compiler/tf2xla/kernels/matmul_op.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index bd6f58453df..ee4d3d1314e 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -21,12 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, + DT_INT32, DT_INT64, DT_INT16, DT_INT8}}; class MatMulOp : public XlaOpKernel { public: