Enabling eager rewriting for MKL matmul
This commit is contained in:
parent
135639d075
commit
b505fde34f
@ -45,6 +45,10 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
||||
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
|
||||
std::unique_ptr<EagerOperation>* new_mkl_op);
|
||||
|
||||
// Creates new MKL op for MatMul
|
||||
static Status CreateMklMatMul(EagerOperation* orig_op,
|
||||
std::unique_ptr<EagerOperation>* mkl_matmul_op);
|
||||
|
||||
// Creates new MKL op for Conv2D, Conv2DBackpropInput and
|
||||
// Conv2DBackpropFilter.
|
||||
static Status CreateMklConv2DOp(
|
||||
@ -60,6 +64,10 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
||||
|
||||
// Checks whether we can rewrite the op to MKL one or not.
|
||||
bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
|
||||
|
||||
// Default rewrite rule to be used when rewrite should happen without any
|
||||
// restriction.
|
||||
static bool AlwaysRewrite(EagerOperation* op) { return true; }
|
||||
};
|
||||
|
||||
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
||||
@ -72,6 +80,7 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
|
||||
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
|
||||
mkl_eager_ops_.push_back(
|
||||
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
|
||||
mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul});
|
||||
}
|
||||
|
||||
Status MklEagerOpRewrite::Run(
|
||||
@ -124,6 +133,13 @@ Status MklEagerOpRewrite::SetupNewOp(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MklEagerOpRewrite::CreateMklMatMul(
|
||||
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_matmul_op) {
|
||||
const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
|
||||
TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MklEagerOpRewrite::CreateMklConv2DOp(
|
||||
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
|
||||
const string mkl_op_name =
|
||||
|
Loading…
Reference in New Issue
Block a user