Enabling eager rewriting for MKL matmul

This commit is contained in:
AG Ramesh 2019-08-03 14:02:04 -07:00 committed by Penporn Koanantakool
parent 135639d075
commit b505fde34f

View File

@ -45,6 +45,10 @@ class MklEagerOpRewrite : public EagerOpRewrite {
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name, static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
std::unique_ptr<EagerOperation>* new_mkl_op); std::unique_ptr<EagerOperation>* new_mkl_op);
// Creates new MKL op for MatMul
static Status CreateMklMatMul(EagerOperation* orig_op,
std::unique_ptr<EagerOperation>* mkl_matmul_op);
// Creates new MKL op for Conv2D, Conv2DBackpropInput and // Creates new MKL op for Conv2D, Conv2DBackpropInput and
// Conv2DBackpropFilter. // Conv2DBackpropFilter.
static Status CreateMklConv2DOp( static Status CreateMklConv2DOp(
@ -60,6 +64,10 @@ class MklEagerOpRewrite : public EagerOpRewrite {
// Checks whether we can rewrite the op to MKL one or not. // Checks whether we can rewrite the op to MKL one or not.
bool ShouldRewriteOp(EagerOperation* op, int* op_idx); bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
// Default rewrite rule to be used when rewrite should happen without any
// restriction.
static bool AlwaysRewrite(EagerOperation* op) { return true; }
}; };
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite); REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
@ -72,6 +80,7 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp}); {"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back( mkl_eager_ops_.push_back(
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp}); {"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul});
} }
Status MklEagerOpRewrite::Run( Status MklEagerOpRewrite::Run(
@ -124,6 +133,13 @@ Status MklEagerOpRewrite::SetupNewOp(
return Status::OK(); return Status::OK();
} }
Status MklEagerOpRewrite::CreateMklMatMul(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_matmul_op) {
const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op));
return Status::OK();
}
Status MklEagerOpRewrite::CreateMklConv2DOp( Status MklEagerOpRewrite::CreateMklConv2DOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) { EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
const string mkl_op_name = const string mkl_op_name =