Enabling eager rewriting for MKL matmul
This commit is contained in:
parent
135639d075
commit
b505fde34f
@ -45,6 +45,10 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
|||||||
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
|
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
|
||||||
std::unique_ptr<EagerOperation>* new_mkl_op);
|
std::unique_ptr<EagerOperation>* new_mkl_op);
|
||||||
|
|
||||||
|
// Creates new MKL op for MatMul
|
||||||
|
static Status CreateMklMatMul(EagerOperation* orig_op,
|
||||||
|
std::unique_ptr<EagerOperation>* mkl_matmul_op);
|
||||||
|
|
||||||
// Creates new MKL op for Conv2D, Conv2DBackpropInput and
|
// Creates new MKL op for Conv2D, Conv2DBackpropInput and
|
||||||
// Conv2DBackpropFilter.
|
// Conv2DBackpropFilter.
|
||||||
static Status CreateMklConv2DOp(
|
static Status CreateMklConv2DOp(
|
||||||
@ -60,6 +64,10 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
|||||||
|
|
||||||
// Checks whether we can rewrite the op to MKL one or not.
|
// Checks whether we can rewrite the op to MKL one or not.
|
||||||
bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
|
bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
|
||||||
|
|
||||||
|
// Default rewrite rule to be used when rewrite should happen without any
|
||||||
|
// restriction.
|
||||||
|
static bool AlwaysRewrite(EagerOperation* op) { return true; }
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
||||||
@ -72,6 +80,7 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
|
|||||||
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
|
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
|
||||||
mkl_eager_ops_.push_back(
|
mkl_eager_ops_.push_back(
|
||||||
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
|
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
|
||||||
|
mkl_eager_ops_.push_back({"MatMul", AlwaysRewrite, CreateMklMatMul});
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MklEagerOpRewrite::Run(
|
Status MklEagerOpRewrite::Run(
|
||||||
@ -124,6 +133,13 @@ Status MklEagerOpRewrite::SetupNewOp(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MklEagerOpRewrite::CreateMklMatMul(
|
||||||
|
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_matmul_op) {
|
||||||
|
const string mkl_op_name = mkl_op_registry::GetMklOpName(orig_op->Name());
|
||||||
|
TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_matmul_op));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status MklEagerOpRewrite::CreateMklConv2DOp(
|
Status MklEagerOpRewrite::CreateMklConv2DOp(
|
||||||
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
|
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
|
||||||
const string mkl_op_name =
|
const string mkl_op_name =
|
||||||
|
Loading…
Reference in New Issue
Block a user