Add int dtypes to TF XLA bridge for matmul ops
PiperOrigin-RevId: 356165547 Change-Id: Ib58bf831d003a4376938b8e2333716778fdbcc7c
This commit is contained in:
parent
b646e8951c
commit
221c2d5d53
@ -21,12 +21,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr std::array<DataType, 6> kMatmulTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}};
|
||||
constexpr std::array<DataType, 10> kMatmulTypes = {
|
||||
{DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
|
||||
DT_INT32, DT_INT64, DT_INT16, DT_INT8}};
|
||||
|
||||
class MatMulOp : public XlaOpKernel {
|
||||
public:
|
||||
|
Loading…
x
Reference in New Issue
Block a user