Merge pull request #33540 from Intel-tensorflow:niroop/eager
PiperOrigin-RevId: 276530953 Change-Id: I9f0fd0166e50266561152da036e581b253c3bc73
This commit is contained in:
commit
b160ffcecf
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifdef INTEL_MKL
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
#include "tensorflow/core/graph/mkl_graph_util.h"
|
||||
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
||||
@ -25,12 +28,18 @@ namespace tensorflow {
|
||||
class MklEagerOpRewrite : public EagerOpRewrite {
|
||||
public:
|
||||
MklEagerOpRewrite(string name, string file, string line);
|
||||
typedef struct {
|
||||
struct MklEagerOp {
|
||||
string op_name;
|
||||
std::function<bool(EagerOperation*)> RewriteRule;
|
||||
std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
|
||||
CreateMklOp;
|
||||
} MklEagerOp;
|
||||
|
||||
// Overload Operator== for std::find comparison
|
||||
// used by SlowCheckIfKernelRegistered.
|
||||
bool operator==(const MklEagerOp& rhs) const {
|
||||
return (op_name.compare(rhs.op_name) == 0);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
// TODO(intel-tf): refactor with unordered_map;
|
||||
@ -69,6 +78,16 @@ class MklEagerOpRewrite : public EagerOpRewrite {
|
||||
// Default rewrite rule to be used when rewrite should happen without any
|
||||
// restriction.
|
||||
static bool AlwaysRewrite(EagerOperation* op) { return true; }
|
||||
|
||||
// Checks if kernel is registered for a particular op.
|
||||
bool FastCheckIfKernelRegistered(string op_name, DataType dt);
|
||||
|
||||
// This is called by FastCheckIfKernelRegistered once per unique op name and
|
||||
// data type.
|
||||
bool SlowCheckIfKernelRegistered(string op_name, DataType dt);
|
||||
|
||||
// map used by FastCheckIfKernelRegistered.
|
||||
std::unordered_map<string, bool> registered_kernels_map;
|
||||
};
|
||||
|
||||
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
||||
@ -162,10 +181,8 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
|
||||
return false;
|
||||
}
|
||||
// Check if we have registered MKL kernel for this op.
|
||||
if (!mkl_op_registry::IsMklNameChangeOp(
|
||||
mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) &&
|
||||
!mkl_op_registry::IsMklNameChangeOp(
|
||||
mkl_op_registry::GetMklOpName(op->Name()), data_type)) {
|
||||
bool kernel_found = FastCheckIfKernelRegistered(op->Name(), data_type);
|
||||
if (!kernel_found) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -181,6 +198,44 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MklEagerOpRewrite::FastCheckIfKernelRegistered(string op_name,
|
||||
DataType dt) {
|
||||
// Check for kernel registration only once per op name and data type
|
||||
// for performance reasons.
|
||||
string registered_kernels_key = op_name + std::to_string(dt);
|
||||
auto kernel_element = registered_kernels_map.find(registered_kernels_key);
|
||||
bool kernel_registered = false;
|
||||
if (kernel_element == registered_kernels_map.end()) {
|
||||
// Kernel registration is not verified even once yet.
|
||||
// So verify and store registration.
|
||||
kernel_registered = SlowCheckIfKernelRegistered(op_name, dt);
|
||||
registered_kernels_map.insert(
|
||||
std::make_pair(registered_kernels_key, kernel_registered));
|
||||
} else {
|
||||
// Kernel is visited atleast once. return stored registration result.
|
||||
kernel_registered = kernel_element->second;
|
||||
}
|
||||
|
||||
return kernel_registered;
|
||||
}
|
||||
|
||||
bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
|
||||
DataType dt) {
|
||||
MklEagerOp op_key = {op_name, AlwaysRewrite, CreateGenericMklOp};
|
||||
// Find if the eager op_name exists in vector list mkl_eager_ops_.
|
||||
auto element =
|
||||
std::find(std::begin(mkl_eager_ops_), std::end(mkl_eager_ops_), op_key);
|
||||
if (element != std::end(mkl_eager_ops_) && dt == DT_FLOAT) {
|
||||
// Eager Op exists. So verify registry and return registered or not.
|
||||
return (mkl_op_registry::IsMklNameChangeOp(
|
||||
mkl_op_registry::GetMklEagerOpName(op_name), dt) ||
|
||||
mkl_op_registry::IsMklNameChangeOp(
|
||||
mkl_op_registry::GetMklOpName(op_name), dt));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Status MklEagerOpRewrite::RewriteToMklOp(
|
||||
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
|
||||
const int op_idx) {
|
||||
|
Loading…
Reference in New Issue
Block a user