Fix --config=mkl_threadpool build errors.

PiperOrigin-RevId: 356614919
Change-Id: I6ac9c0639725f5b4472b4141011989a4e99311b1
This commit is contained in:
Penporn Koanantakool 2021-02-09 16:07:54 -08:00 committed by TensorFlower Gardener
parent d7defaef6d
commit 1f5361389a
12 changed files with 21 additions and 14 deletions

View File

@ -576,10 +576,14 @@ tf_cc_test_mkl(
],
deps = [
":core",
":eager_op_rewrite_registry",
":mkl_eager_op_rewrite",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:device_mgr",
],
)

View File

@ -56,8 +56,9 @@ class EagerOpRewriteTest : public ::testing::Test {
// Validates the result of MKL eager rewrite.
void CheckRewrite(EagerOperation* orig_op, string expected_op_name) {
std::unique_ptr<tensorflow::EagerOperation> out_op;
EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op);
EXPECT_EQ(Status::OK(),
EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op));
// actual_op_name is same as original op name if rewrite didn't happen.
string actual_op_name = orig_op->Name();
@ -159,11 +160,12 @@ REGISTER_TEST_ALL_TYPES(MostOps_Positive);
}
#define DATA_FORMAT "NCDHW"
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_1);
#undef DATA_FORMAT
#define DATA_FORMAT "NDHWC"
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_2);
#undef DATA_FORMAT
#undef REGISTER_TEST
} // namespace tensorflow

View File

@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/framework:tensor_testutil",
"//tensorflow/core/graph:mkl_graph_util",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",

View File

@ -654,6 +654,7 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/utils:grappler_test",
"//tensorflow/core/lib/random",
],
)

View File

@ -1201,11 +1201,9 @@ using FusedBiasAddDataTypes = ::testing::Types<float, double>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBiasOpTest,
FusedBiasAddDataTypes);
#ifndef INTEL_MKL
using FusedBatchNormDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBatchNormOpTest,
FusedBatchNormDataTypes);
#endif
#endif // TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -221,7 +221,7 @@ tf_cc_test_mkl(
size = "small",
srcs = ["mkl_relu_op_test.cc"],
linkstatic = 1, # Fixes dyld error on MacOS.
deps = MKL_TEST_DEPS,
deps = ["@com_google_absl//absl/strings"] + MKL_TEST_DEPS,
)
tf_mkl_kernel_library(
@ -330,6 +330,7 @@ tf_cc_test_mkl(
":mkl_conv_op",
":mkl_fused_batch_norm_op",
"//tensorflow/core:direct_session",
"//tensorflow/core/graph:mkl_graph_util",
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
] + MKL_TEST_DEPS,
)
@ -420,6 +421,7 @@ tf_cc_test_mkl(
":mkl_tfconv_op",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:direct_session",
"//tensorflow/core/graph:mkl_graph_util",
"//tensorflow/core/kernels:bias_op",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:depthwise_conv_op",

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/public/session.h"
#if defined(INTEL_MKL_DNN_ONLY)
#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"
#endif

View File

@ -72,7 +72,8 @@ Tensor CreateMklInput() {
mkl_shape.SetTfLayout(4, {1, 2, 2, 2}, MKL_TENSOR_FORMAT_NHWC);
DataType dtype = DataTypeToEnum<uint8>::v();
Tensor mkl_tensor(dtype, {mkl_shape.GetSerializeBufferSize()});
Tensor mkl_tensor(dtype,
{static_cast<int64>(mkl_shape.GetSerializeBufferSize())});
mkl_shape.SerializeMklDnnShape(
mkl_tensor.flat<uint8>().data(),
mkl_tensor.flat<uint8>().size() * sizeof(uint8));

View File

@ -320,7 +320,7 @@ class MklFusedConv2DOpTest : public OpsTestBase {
if (!NativeFormatEnabled()) {
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
for (int i = 0; i < num_args; ++i)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}
TF_ASSERT_OK(RunOpKernel());
@ -641,7 +641,7 @@ class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
if (!NativeFormatEnabled()) {
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
for (int i = 0; i < num_args; ++i)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}
TF_ASSERT_OK(RunOpKernel());
@ -1038,7 +1038,7 @@ class MklFusedMatMulOpTest : public OpsTestBase {
// Add MKL meta input for input, filter and bias.
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
for (int i = 0; i < num_args; ++i)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
}

View File

@ -15,7 +15,6 @@ limitations under the License.
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"

View File

@ -34,11 +34,10 @@ namespace tensorflow {
class MklRequantizatedOpsTest : public OpsTestBase {};
class MklRequantizatedOpsTestHelper : public OpsTestBase {
class MklRequantizatedOpsTestHelper {
public:
void Setup(Tensor &input_tensor_qint32, float &range_weights_ch1,
float &range_weights_ch2);
void TestBody() {}
};
void MklRequantizatedOpsTestHelper::Setup(Tensor &input_tensor_qint32,

View File

@ -73,6 +73,7 @@ cc_library(
visibility = [
"//tensorflow/c/eager:__pkg__",
"//tensorflow/core:__pkg__",
"//tensorflow/core/grappler/optimizers:__pkg__",
"//tensorflow/core/lib/core:__pkg__",
"//tensorflow/core/lib/gtl:__pkg__",
"//tensorflow/core/lib/io:__pkg__",