Merge pull request #42240 from Intel-tensorflow:yimei/fix_mkl_conv2d_bf16_eager_crash
PiperOrigin-RevId: 326244105 Change-Id: Ic3f6777dc8750c7fc1787ec817da79c66d006e88
This commit is contained in:
commit
9f86089e45
@ -207,7 +207,7 @@ bool MklEagerOpRewrite::SlowCheckIfKernelRegistered(string op_name,
|
||||
DataType dt) {
|
||||
// Find if the eager op_name exists in mkl_eager_ops_ list.
|
||||
auto element = mkl_eager_ops_.find(op_name);
|
||||
if (element != mkl_eager_ops_.end() && dt == DT_FLOAT) {
|
||||
if (element != mkl_eager_ops_.end()) {
|
||||
// Eager Op exists. So verify registry and return registered or not.
|
||||
return (mkl_op_registry::IsMklNameChangeOp(
|
||||
mkl_op_registry::GetMklEagerOpName(op_name), dt) ||
|
||||
|
@ -19,10 +19,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class EagerOpRewriteTest {
|
||||
class EagerOpRewriteTest : public ::testing::Test {
|
||||
public:
|
||||
EagerOpRewriteTest() {}
|
||||
|
||||
@ -68,71 +69,63 @@ class EagerOpRewriteTest {
|
||||
}
|
||||
};
|
||||
|
||||
TEST(EagerOpRewriteTest, Conv2D) {
|
||||
const string orig_op_name = "Conv2D";
|
||||
std::unique_ptr<tensorflow::EagerOperation> orig_op =
|
||||
EagerOpRewriteTest::CreateOp(orig_op_name);
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
auto orig_op = CreateOp("Conv2D"); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
orig_op->MutableAttrs()->Set("padding", "VALID"); \
|
||||
CheckRewrite(orig_op.get(), "_MklEagerConv2D"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(Conv2D);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
|
||||
orig_op->MutableAttrs()->Set("padding", "VALID");
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
auto orig_op = CreateOp("Conv2D"); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
orig_op->MutableAttrs()->Set("padding", "EXPLICIT"); \
|
||||
CheckRewrite(orig_op.get(), "Conv2D"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(Conv2D_Explicit_Padding);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "_MklEagerConv2D");
|
||||
}
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
auto orig_op = CreateOp("Conv2DBackpropInput"); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
orig_op->MutableAttrs()->Set("padding", "VALID"); \
|
||||
CheckRewrite(orig_op.get(), "_MklEagerConv2DBackpropInput"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(Conv2DBackpropInput);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
TEST(EagerOpRewriteTest, Conv2D_Explicit_Padding) {
|
||||
const string orig_op_name = "Conv2D";
|
||||
std::unique_ptr<tensorflow::EagerOperation> orig_op =
|
||||
EagerOpRewriteTest::CreateOp(orig_op_name);
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
auto orig_op = CreateOp("Conv2DBackpropFilter"); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
orig_op->MutableAttrs()->Set("padding", "VALID"); \
|
||||
CheckRewrite(orig_op.get(), "_MklEagerConv2DBackpropFilter"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(Conv2DBackpropFilter);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
|
||||
orig_op->MutableAttrs()->Set("padding", "EXPLICIT");
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
auto orig_op = CreateOp("BatchMatMul"); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
CheckRewrite(orig_op.get(), "_MklBatchMatMul"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(BatchMatMul);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "Conv2D");
|
||||
}
|
||||
|
||||
TEST(EagerOpRewriteTest, Conv2DBackpropInput) {
|
||||
const string orig_op_name = "Conv2DBackpropInput";
|
||||
std::unique_ptr<tensorflow::EagerOperation> orig_op =
|
||||
EagerOpRewriteTest::CreateOp(orig_op_name);
|
||||
|
||||
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
|
||||
orig_op->MutableAttrs()->Set("padding", "VALID");
|
||||
|
||||
EagerOpRewriteTest::CheckRewrite(orig_op.get(),
|
||||
"_MklEagerConv2DBackpropInput");
|
||||
}
|
||||
|
||||
TEST(EagerOpRewriteTest, Conv2DBackpropFilter) {
|
||||
const string orig_op_name = "Conv2DBackpropFilter";
|
||||
std::unique_ptr<tensorflow::EagerOperation> orig_op =
|
||||
EagerOpRewriteTest::CreateOp(orig_op_name);
|
||||
|
||||
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
|
||||
orig_op->MutableAttrs()->Set("padding", "VALID");
|
||||
|
||||
EagerOpRewriteTest::CheckRewrite(orig_op.get(),
|
||||
"_MklEagerConv2DBackpropFilter");
|
||||
}
|
||||
|
||||
TEST(EagerOpRewriteTest, BatchMatMul) {
|
||||
const string orig_op_name = "BatchMatMul";
|
||||
std::unique_ptr<tensorflow::EagerOperation> orig_op =
|
||||
EagerOpRewriteTest::CreateOp(orig_op_name);
|
||||
|
||||
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
|
||||
|
||||
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "_MklBatchMatMul");
|
||||
}
|
||||
|
||||
TEST(EagerOpRewriteTest, MatMul) {
|
||||
const string orig_op_name = "MatMul";
|
||||
std::unique_ptr<tensorflow::EagerOperation> orig_op =
|
||||
EagerOpRewriteTest::CreateOp(orig_op_name);
|
||||
|
||||
orig_op->MutableAttrs()->Set("T", DT_FLOAT);
|
||||
|
||||
EagerOpRewriteTest::CheckRewrite(orig_op.get(), "_MklMatMul");
|
||||
}
|
||||
#define REGISTER_TEST(NAME, T, INPUT) \
|
||||
TEST_F(EagerOpRewriteTest, NAME##_##T) { \
|
||||
auto orig_op = CreateOp("MatMul"); \
|
||||
orig_op->MutableAttrs()->Set("T", T); \
|
||||
CheckRewrite(orig_op.get(), "_MklMatMul"); \
|
||||
}
|
||||
REGISTER_TEST_ALL_TYPES(MatMul);
|
||||
#undef REGISTER_TEST
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user