Merge pull request #33540 from Intel-tensorflow:niroop/eager

PiperOrigin-RevId: 276530953
Change-Id: I9f0fd0166e50266561152da036e581b253c3bc73
This commit is contained in:
TensorFlower Gardener 2019-10-24 12:06:25 -07:00
commit b160ffcecf

View File

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifdef INTEL_MKL #ifdef INTEL_MKL
#include <string>
#include <unordered_map>
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_layout_pass.h"
@ -25,12 +28,18 @@ namespace tensorflow {
class MklEagerOpRewrite : public EagerOpRewrite { class MklEagerOpRewrite : public EagerOpRewrite {
public: public:
MklEagerOpRewrite(string name, string file, string line); MklEagerOpRewrite(string name, string file, string line);
typedef struct { struct MklEagerOp {
string op_name; string op_name;
std::function<bool(EagerOperation*)> RewriteRule; std::function<bool(EagerOperation*)> RewriteRule;
std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)> std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
CreateMklOp; CreateMklOp;
} MklEagerOp;
// Overload Operator== for std::find comparison
// used by SlowCheckIfKernelRegistered.
bool operator==(const MklEagerOp& rhs) const {
return (op_name.compare(rhs.op_name) == 0);
}
};
private: private:
// TODO(intel-tf): refactor with unordered_map; // TODO(intel-tf): refactor with unordered_map;
@ -69,6 +78,16 @@ class MklEagerOpRewrite : public EagerOpRewrite {
// Default rewrite rule to be used when rewrite should happen without any // Default rewrite rule to be used when rewrite should happen without any
// restriction. // restriction.
static bool AlwaysRewrite(EagerOperation* op) { return true; } static bool AlwaysRewrite(EagerOperation* op) { return true; }
// Checks if kernel is registered for a particular op.
bool FastCheckIfKernelRegistered(string op_name, DataType dt);
// This is called by FastCheckIfKernelRegistered once per unique op name and
// data type.
bool SlowCheckIfKernelRegistered(string op_name, DataType dt);
// map used by FastCheckIfKernelRegistered.
std::unordered_map<string, bool> registered_kernels_map;
}; };
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite); REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
@ -162,10 +181,8 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
return false; return false;
} }
// Check if we have registered MKL kernel for this op. // Check if we have registered MKL kernel for this op.
if (!mkl_op_registry::IsMklNameChangeOp( bool kernel_found = FastCheckIfKernelRegistered(op->Name(), data_type);
mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) && if (!kernel_found) {
!mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklOpName(op->Name()), data_type)) {
return false; return false;
} }
@ -181,6 +198,44 @@ bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
return false; return false;
} }
bool MklEagerOpRewrite::FastCheckIfKernelRegistered(string op_name,
DataType dt) {
// Check for kernel registration only once per op name and data type
// for performance reasons.
string registered_kernels_key = op_name + std::to_string(dt);
auto kernel_element = registered_kernels_map.find(registered_kernels_key);
bool kernel_registered = false;
if (kernel_element == registered_kernels_map.end()) {
// Kernel registration is not verified even once yet.
// So verify and store registration.
kernel_registered = SlowCheckIfKernelRegistered(op_name, dt);
registered_kernels_map.insert(
std::make_pair(registered_kernels_key, kernel_registered));
} else {
// Kernel is visited atleast once. return stored registration result.
kernel_registered = kernel_element->second;
}
return kernel_registered;
}
bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
DataType dt) {
MklEagerOp op_key = {op_name, AlwaysRewrite, CreateGenericMklOp};
// Find if the eager op_name exists in vector list mkl_eager_ops_.
auto element =
std::find(std::begin(mkl_eager_ops_), std::end(mkl_eager_ops_), op_key);
if (element != std::end(mkl_eager_ops_) && dt == DT_FLOAT) {
// Eager Op exists. So verify registry and return registered or not.
return (mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklEagerOpName(op_name), dt) ||
mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklOpName(op_name), dt));
} else {
return false;
}
}
Status MklEagerOpRewrite::RewriteToMklOp( Status MklEagerOpRewrite::RewriteToMklOp(
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op, EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
const int op_idx) { const int op_idx) {