diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc index fa506cb674c..c487aa9e281 100644 --- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc +++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc @@ -45,6 +45,10 @@ class MklEagerOpRewrite : public EagerOpRewrite { static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name, std::unique_ptr* new_mkl_op); + // Creates new MKL op for MatMul + static Status CreateMklMatMul(EagerOperation* orig_op, + std::unique_ptr* mkl_matmul_op); + // Creates new MKL op for Conv2D, Conv2DBackpropInput and // Conv2DBackpropFilter. static Status CreateMklConv2DOp( @@ -60,6 +64,10 @@ class MklEagerOpRewrite : public EagerOpRewrite { // Checks whether we can rewrite the op to MKL one or not. bool ShouldRewriteOp(EagerOperation* op, int* op_idx); + + // Default rewrite rule to be used when rewrite should happen without any + // restriction. + static bool AlwaysRewrite(EagerOperation* op) { return true; } }; REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite); @@ -72,6 +80,7 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line) {"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp}); mkl_eager_ops_.push_back( {"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp}); + mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul}); } Status MklEagerOpRewrite::Run( @@ -124,6 +133,13 @@ Status MklEagerOpRewrite::SetupNewOp( return Status::OK(); } +Status MklEagerOpRewrite::CreateMklMatMul( + EagerOperation* orig_op, std::unique_ptr* mkl_matmul_op) { + const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name()); + TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op)); + return Status::OK(); +} + Status MklEagerOpRewrite::CreateMklConv2DOp( EagerOperation* orig_op, std::unique_ptr* mkl_conv2d_op) { const string mkl_op_name =