Merge pull request #33540 from Intel-tensorflow:niroop/eager
PiperOrigin-RevId: 276530953 Change-Id: I9f0fd0166e50266561152da036e581b253c3bc73
This commit is contained in:
commit
b160ffcecf
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||||
#include "tensorflow/core/graph/mkl_graph_util.h"
|
#include "tensorflow/core/graph/mkl_graph_util.h"
|
||||||
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
||||||
@ -25,12 +28,18 @@ namespace tensorflow {
|
|||||||
class MklEagerOpRewrite : public EagerOpRewrite {
|
class MklEagerOpRewrite : public EagerOpRewrite {
|
||||||
public:
|
public:
|
||||||
MklEagerOpRewrite(string name, string file, string line);
|
MklEagerOpRewrite(string name, string file, string line);
|
||||||
typedef struct {
|
struct MklEagerOp {
|
||||||
string op_name;
|
string op_name;
|
||||||
std::function<bool(EagerOperation*)> RewriteRule;
|
std::function<bool(EagerOperation*)> RewriteRule;
|
||||||
std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
|
std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
|
||||||
CreateMklOp;
|
CreateMklOp;
|
||||||
} MklEagerOp;
|
|
||||||
|
// Overload Operator== for std::find comparison
|
||||||
|
// used by SlowCheckIfKernelRegistered.
|
||||||
|
bool operator==(const MklEagerOp& rhs) const {
|
||||||
|
return (op_name.compare(rhs.op_name) == 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// TODO(intel-tf): refactor with unordered_map;
|
// TODO(intel-tf): refactor with unordered_map;
|
||||||
@ -69,6 +78,16 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
|||||||
// Default rewrite rule to be used when rewrite should happen without any
|
// Default rewrite rule to be used when rewrite should happen without any
|
||||||
// restriction.
|
// restriction.
|
||||||
static bool AlwaysRewrite(EagerOperation* op) { return true; }
|
static bool AlwaysRewrite(EagerOperation* op) { return true; }
|
||||||
|
|
||||||
|
// Checks if kernel is registered for a particular op.
|
||||||
|
bool FastCheckIfKernelRegistered(string op_name, DataType dt);
|
||||||
|
|
||||||
|
// This is called by FastCheckIfKernelRegistered once per unique op name and
|
||||||
|
// data type.
|
||||||
|
bool SlowCheckIfKernelRegistered(string op_name, DataType dt);
|
||||||
|
|
||||||
|
// map used by FastCheckIfKernelRegistered.
|
||||||
|
std::unordered_map<string, bool> registered_kernels_map;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
||||||
@ -162,10 +181,8 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Check if we have registered MKL kernel for this op.
|
// Check if we have registered MKL kernel for this op.
|
||||||
if (!mkl_op_registry::IsMklNameChangeOp(
|
bool kernel_found = FastCheckIfKernelRegistered(op->Name(), data_type);
|
||||||
mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) &&
|
if (!kernel_found) {
|
||||||
!mkl_op_registry::IsMklNameChangeOp(
|
|
||||||
mkl_op_registry::GetMklOpName(op->Name()), data_type)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,6 +198,44 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool MklEagerOpRewrite::FastCheckIfKernelRegistered(string op_name,
|
||||||
|
DataType dt) {
|
||||||
|
// Check for kernel registration only once per op name and data type
|
||||||
|
// for performance reasons.
|
||||||
|
string registered_kernels_key = op_name + std::to_string(dt);
|
||||||
|
auto kernel_element = registered_kernels_map.find(registered_kernels_key);
|
||||||
|
bool kernel_registered = false;
|
||||||
|
if (kernel_element == registered_kernels_map.end()) {
|
||||||
|
// Kernel registration is not verified even once yet.
|
||||||
|
// So verify and store registration.
|
||||||
|
kernel_registered = SlowCheckIfKernelRegistered(op_name, dt);
|
||||||
|
registered_kernels_map.insert(
|
||||||
|
std::make_pair(registered_kernels_key, kernel_registered));
|
||||||
|
} else {
|
||||||
|
// Kernel is visited atleast once. return stored registration result.
|
||||||
|
kernel_registered = kernel_element->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kernel_registered;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
|
||||||
|
DataType dt) {
|
||||||
|
MklEagerOp op_key = {op_name, AlwaysRewrite, CreateGenericMklOp};
|
||||||
|
// Find if the eager op_name exists in vector list mkl_eager_ops_.
|
||||||
|
auto element =
|
||||||
|
std::find(std::begin(mkl_eager_ops_), std::end(mkl_eager_ops_), op_key);
|
||||||
|
if (element != std::end(mkl_eager_ops_) && dt == DT_FLOAT) {
|
||||||
|
// Eager Op exists. So verify registry and return registered or not.
|
||||||
|
return (mkl_op_registry::IsMklNameChangeOp(
|
||||||
|
mkl_op_registry::GetMklEagerOpName(op_name), dt) ||
|
||||||
|
mkl_op_registry::IsMklNameChangeOp(
|
||||||
|
mkl_op_registry::GetMklOpName(op_name), dt));
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Status MklEagerOpRewrite::RewriteToMklOp(
|
Status MklEagerOpRewrite::RewriteToMklOp(
|
||||||
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
|
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
|
||||||
const int op_idx) {
|
const int op_idx) {
|
||||||
|
Loading…
Reference in New Issue
Block a user