Fix --config=mkl_threadpool build errors.
PiperOrigin-RevId: 356614919 Change-Id: I6ac9c0639725f5b4472b4141011989a4e99311b1
This commit is contained in:
parent
d7defaef6d
commit
1f5361389a
@ -576,10 +576,14 @@ tf_cc_test_mkl(
|
||||
],
|
||||
deps = [
|
||||
":core",
|
||||
":eager_op_rewrite_registry",
|
||||
":mkl_eager_op_rewrite",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:device_mgr",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -56,8 +56,9 @@ class EagerOpRewriteTest : public ::testing::Test {
|
||||
// Validates the result of MKL eager rewrite.
|
||||
void CheckRewrite(EagerOperation* orig_op, string expected_op_name) {
|
||||
std::unique_ptr<tensorflow::EagerOperation> out_op;
|
||||
EagerOpRewriteRegistry::Global()->RunRewrite(
|
||||
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op);
|
||||
EXPECT_EQ(Status::OK(),
|
||||
EagerOpRewriteRegistry::Global()->RunRewrite(
|
||||
EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op));
|
||||
|
||||
// actual_op_name is same as original op name if rewrite didn't happen.
|
||||
string actual_op_name = orig_op->Name();
|
||||
@ -159,11 +160,12 @@ REGISTER_TEST_ALL_TYPES(MostOps_Positive);
|
||||
}
|
||||
#define DATA_FORMAT "NCDHW"
|
||||
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_1);
|
||||
#undef DATA_FORMAT
|
||||
|
||||
#define DATA_FORMAT "NDHWC"
|
||||
REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_2);
|
||||
|
||||
#undef DATA_FORMAT
|
||||
|
||||
#undef REGISTER_TEST
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -108,6 +108,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/framework:tensor_testutil",
|
||||
"//tensorflow/core/graph:mkl_graph_util",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/clusters:single_machine",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
|
@ -654,6 +654,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/grappler/clusters:single_machine",
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||
"//tensorflow/core/grappler/utils:grappler_test",
|
||||
"//tensorflow/core/lib/random",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1201,11 +1201,9 @@ using FusedBiasAddDataTypes = ::testing::Types<float, double>;
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBiasOpTest,
|
||||
FusedBiasAddDataTypes);
|
||||
|
||||
#ifndef INTEL_MKL
|
||||
using FusedBatchNormDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedConv2DWithBatchNormOpTest,
|
||||
FusedBatchNormDataTypes);
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_USE_ROCM
|
||||
} // namespace tensorflow
|
||||
|
@ -221,7 +221,7 @@ tf_cc_test_mkl(
|
||||
size = "small",
|
||||
srcs = ["mkl_relu_op_test.cc"],
|
||||
linkstatic = 1, # Fixes dyld error on MacOS.
|
||||
deps = MKL_TEST_DEPS,
|
||||
deps = ["@com_google_absl//absl/strings"] + MKL_TEST_DEPS,
|
||||
)
|
||||
|
||||
tf_mkl_kernel_library(
|
||||
@ -330,6 +330,7 @@ tf_cc_test_mkl(
|
||||
":mkl_conv_op",
|
||||
":mkl_fused_batch_norm_op",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core/graph:mkl_graph_util",
|
||||
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
|
||||
] + MKL_TEST_DEPS,
|
||||
)
|
||||
@ -420,6 +421,7 @@ tf_cc_test_mkl(
|
||||
":mkl_tfconv_op",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core/graph:mkl_graph_util",
|
||||
"//tensorflow/core/kernels:bias_op",
|
||||
"//tensorflow/core/kernels:conv_ops",
|
||||
"//tensorflow/core/kernels:depthwise_conv_op",
|
||||
|
@ -28,7 +28,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
#if defined(INTEL_MKL_DNN_ONLY)
|
||||
#include "mkldnn.hpp"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#endif
|
||||
|
||||
|
@ -72,7 +72,8 @@ Tensor CreateMklInput() {
|
||||
mkl_shape.SetTfLayout(4, {1, 2, 2, 2}, MKL_TENSOR_FORMAT_NHWC);
|
||||
|
||||
DataType dtype = DataTypeToEnum<uint8>::v();
|
||||
Tensor mkl_tensor(dtype, {mkl_shape.GetSerializeBufferSize()});
|
||||
Tensor mkl_tensor(dtype,
|
||||
{static_cast<int64>(mkl_shape.GetSerializeBufferSize())});
|
||||
mkl_shape.SerializeMklDnnShape(
|
||||
mkl_tensor.flat<uint8>().data(),
|
||||
mkl_tensor.flat<uint8>().size() * sizeof(uint8));
|
||||
|
@ -320,7 +320,7 @@ class MklFusedConv2DOpTest : public OpsTestBase {
|
||||
if (!NativeFormatEnabled()) {
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
for (const Tensor& arg : args)
|
||||
for (int i = 0; i < num_args; ++i)
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
}
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -641,7 +641,7 @@ class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
|
||||
if (!NativeFormatEnabled()) {
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
for (const Tensor& arg : args)
|
||||
for (int i = 0; i < num_args; ++i)
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
}
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
@ -1038,7 +1038,7 @@ class MklFusedMatMulOpTest : public OpsTestBase {
|
||||
// Add MKL meta input for input, filter and bias.
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
for (const Tensor& arg : args)
|
||||
for (int i = 0; i < num_args; ++i)
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
|
@ -34,11 +34,10 @@ namespace tensorflow {
|
||||
|
||||
class MklRequantizatedOpsTest : public OpsTestBase {};
|
||||
|
||||
class MklRequantizatedOpsTestHelper : public OpsTestBase {
|
||||
class MklRequantizatedOpsTestHelper {
|
||||
public:
|
||||
void Setup(Tensor &input_tensor_qint32, float &range_weights_ch1,
|
||||
float &range_weights_ch2);
|
||||
void TestBody() {}
|
||||
};
|
||||
|
||||
void MklRequantizatedOpsTestHelper::Setup(Tensor &input_tensor_qint32,
|
||||
|
@ -73,6 +73,7 @@ cc_library(
|
||||
visibility = [
|
||||
"//tensorflow/c/eager:__pkg__",
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/core/grappler/optimizers:__pkg__",
|
||||
"//tensorflow/core/lib/core:__pkg__",
|
||||
"//tensorflow/core/lib/gtl:__pkg__",
|
||||
"//tensorflow/core/lib/io:__pkg__",
|
||||
|
Loading…
Reference in New Issue
Block a user