Merge pull request #42240 from Intel-tensorflow:yimei/fix_mkl_conv2d_bf16_eager_crash

PiperOrigin-RevId: 326244105
Change-Id: Ic3f6777dc8750c7fc1787ec817da79c66d006e88
This commit is contained in:
TensorFlower Gardener 2020-08-12 09:12:48 -07:00
commit 9f86089e45
2 changed files with 55 additions and 62 deletions

View File

@ -207,7 +207,7 @@ bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
DataType dt) {
// Find if the eager op_name exists in mkl_eager_ops_ list.
auto element = mkl_eager_ops_.find(op_name);
if (element != mkl_eager_ops_.end() && dt == DT_FLOAT) {
if (element != mkl_eager_ops_.end()) {
// Eager Op exists. So verify registry and return registered or not.
return (mkl_op_registry::IsMklNameChangeOp(
mkl_op_registry::GetMklEagerOpName(op_name), dt) ||

View File

@ -19,10 +19,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
class EagerOpRewriteTest {
class EagerOpRewriteTest : public ::testing::Test {
public:
EagerOpRewriteTest() {}
@ -68,71 +69,63 @@ class EagerOpRewriteTest {
}
};
TEST(EagerOpRewriteTest, Conv2D) {
const string orig_op_name = "Conv2D";
std::unique_ptr<tensorflow::EagerOperation> orig_op =
EagerOpRewriteTest::CreateOp(orig_op_name);
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
auto orig_op = CreateOp("Conv2D"); \
orig_op->MutableAttrs()->Set("T", T); \
orig_op->MutableAttrs()->Set("padding", "VALID"); \
CheckRewrite(orig_op.get(), "_MklEagerConv2D"); \
}
REGISTER_TEST_ALL_TYPES(Conv2D);
#undef REGISTER_TEST
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
orig_op->MutableAttrs()->Set("padding", "VALID");
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
auto orig_op = CreateOp("Conv2D"); \
orig_op->MutableAttrs()->Set("T", T); \
orig_op->MutableAttrs()->Set("padding", "EXPLICIT"); \
CheckRewrite(orig_op.get(), "Conv2D"); \
}
REGISTER_TEST_ALL_TYPES(Conv2D_Explicit_Padding);
#undef REGISTER_TEST
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "_MklEagerConv2D");
}
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
auto orig_op = CreateOp("Conv2DBackpropInput"); \
orig_op->MutableAttrs()->Set("T", T); \
orig_op->MutableAttrs()->Set("padding", "VALID"); \
CheckRewrite(orig_op.get(), "_MklEagerConv2DBackpropInput"); \
}
REGISTER_TEST_ALL_TYPES(Conv2DBackpropInput);
#undef REGISTER_TEST
TEST(EagerOpRewriteTest, Conv2D_Explicit_Padding) {
const string orig_op_name = "Conv2D";
std::unique_ptr<tensorflow::EagerOperation> orig_op =
EagerOpRewriteTest::CreateOp(orig_op_name);
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
auto orig_op = CreateOp("Conv2DBackpropFilter"); \
orig_op->MutableAttrs()->Set("T", T); \
orig_op->MutableAttrs()->Set("padding", "VALID"); \
CheckRewrite(orig_op.get(), "_MklEagerConv2DBackpropFilter"); \
}
REGISTER_TEST_ALL_TYPES(Conv2DBackpropFilter);
#undef REGISTER_TEST
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
orig_op->MutableAttrs()->Set("padding", "EXPLICIT");
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
auto orig_op = CreateOp("BatchMatMul"); \
orig_op->MutableAttrs()->Set("T", T); \
CheckRewrite(orig_op.get(), "_MklBatchMatMul"); \
}
REGISTER_TEST_ALL_TYPES(BatchMatMul);
#undef REGISTER_TEST
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "Conv2D");
}
TEST(EagerOpRewriteTest, Conv2DBackpropInput) {
const string orig_op_name = "Conv2DBackpropInput";
std::unique_ptr<tensorflow::EagerOperation> orig_op =
EagerOpRewriteTest::CreateOp(orig_op_name);
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
orig_op->MutableAttrs()->Set("padding", "VALID");
EagerOpRewriteTest::CheckRewrite(orig_op.get(),
"_MklEagerConv2DBackpropInput");
}
TEST(EagerOpRewriteTest, Conv2DBackpropFilter) {
const string orig_op_name = "Conv2DBackpropFilter";
std::unique_ptr<tensorflow::EagerOperation> orig_op =
EagerOpRewriteTest::CreateOp(orig_op_name);
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
orig_op->MutableAttrs()->Set("padding", "VALID");
EagerOpRewriteTest::CheckRewrite(orig_op.get(),
"_MklEagerConv2DBackpropFilter");
}
TEST(EagerOpRewriteTest, BatchMatMul) {
const string orig_op_name = "BatchMatMul";
std::unique_ptr<tensorflow::EagerOperation> orig_op =
EagerOpRewriteTest::CreateOp(orig_op_name);
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "_MklBatchMatMul");
}
TEST(EagerOpRewriteTest, MatMul) {
const string orig_op_name = "MatMul";
std::unique_ptr<tensorflow::EagerOperation> orig_op =
EagerOpRewriteTest::CreateOp(orig_op_name);
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "_MklMatMul");
}
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
auto orig_op = CreateOp("MatMul"); \
orig_op->MutableAttrs()->Set("T", T); \
CheckRewrite(orig_op.get(), "_MklMatMul"); \
}
REGISTER_TEST_ALL_TYPES(MatMul);
#undef REGISTER_TEST
} // namespace tensorflow