Add int dtypes to TF XLA bridge for matmul ops

PiperOrigin-RevId: 356165547
Change-Id: Ib58bf831d003a4376938b8e2333716778fdbcc7c
This commit is contained in:
Harry Zhang 2021-02-07 16:46:52 -08:00 committed by TensorFlower Gardener
parent b646e8951c
commit 221c2d5d53

View File

@ -21,12 +21,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
constexpr std::array<DataType, 6> kMatmulTypes = {
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}};
constexpr std::array<DataType, 10> kMatmulTypes = {
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
DT_INT32, DT_INT64, DT_INT16, DT_INT8}};
class MatMulOp : public XlaOpKernel {
public: