Merge pull request #41476 from Intel-tensorflow:nhasabni/auto_mp_ut_fix
PiperOrigin-RevId: 333714322 Change-Id: Ieaccd7ba1829d21e1a83aaad3598c00acccdfd52
commit 40c3e9fc61
@@ -318,6 +318,7 @@ tf_cc_test_mkl(
     srcs = ["mkl_fused_batch_norm_op_test.cc"],
     linkstatic = 1,
     deps = [
+        ":mkl_conv_op",
         ":mkl_fused_batch_norm_op",
         "//tensorflow/core:direct_session",
         "//tensorflow/core/kernels:conv_ops_gpu_hdrs",
@@ -1327,6 +1327,35 @@ class MklFusedBatchNormGradOp : public OpKernel {
               ? dnn_shape_diff_dst.GetMklLayout()
               : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);
 
+      MklDnnData<T> reorder_src(&cpu_engine_);
+      MklDnnData<T> reorder_diff_dst(&cpu_engine_);
+      T* diff_dst_data =
+          static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
+      T* src_data =
+          static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
+
+#ifdef ENABLE_MKLDNN_V1
+      // MKL-DNN requires src and diff_dst to be in the same memory layout,
+      // either blocked or native format. If these inputs are in different
+      // formats, convert the one in native format to blocked format, as
+      // MKL-DNN gives better performance for the blocked format.
+      if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
+        reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+        reorder_diff_dst.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(src_md, cpu_engine_), context);
+        diff_dst_md = src_md;
+        diff_dst_data =
+            static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle());
+      } else if (!dnn_shape_src.IsMklTensor() &&
+                 dnn_shape_diff_dst.IsMklTensor()) {
+        reorder_src.SetUsrMem(src_md, &src_tensor);
+        reorder_src.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(diff_dst_md, cpu_engine_), context);
+        src_md = diff_dst_md;
+        src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle());
+      }
+#endif  // ENABLE_MKLDNN_V1
+
       // weights -- MKL-DNN packs scales/shifts as weights in the order
       // scale, ..., scale, shift, ..., shift
       weights.AllocateBuffer(2 * depth_ * sizeof(U));
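The new `#ifdef ENABLE_MKLDNN_V1` block enforces a simple invariant: before the batch-norm backward primitive runs, `src` and `diff_dst` must agree on one memory descriptor, and it is always the native-format operand that gets reordered to the blocked one. A minimal standalone sketch of that decision logic, using hypothetical `Operand`/`ReorderToBlocked` stand-ins for `MklDnnData` and `CheckReorderToOpMem`:

#include <iostream>
#include <string>

// Hypothetical stand-in for a tensor plus its memory layout; the real code
// tracks this with MklDnnShape/MklDnnData.
struct Operand {
  std::string name;
  bool blocked;  // true: MKL blocked layout, false: native (NHWC-style) layout
};

// Hypothetical reorder: in the real kernel this is SetUsrMem +
// CheckReorderToOpMem, which copies the data into the target layout.
void ReorderToBlocked(Operand& op) {
  std::cout << "reordering " << op.name << " native -> blocked\n";
  op.blocked = true;
}

// Make both operands agree on a layout, preferring the blocked format since
// MKL-DNN performs better on it (mirrors the #ifdef ENABLE_MKLDNN_V1 hunk).
void AgreeOnLayout(Operand& src, Operand& diff_dst) {
  if (src.blocked && !diff_dst.blocked) {
    ReorderToBlocked(diff_dst);  // diff_dst adopts src's blocked layout
  } else if (!src.blocked && diff_dst.blocked) {
    ReorderToBlocked(src);  // src adopts diff_dst's blocked layout
  }
  // If both are blocked or both are native, nothing to do.
}

int main() {
  Operand src{"src", true}, diff_dst{"diff_dst", false};
  AgreeOnLayout(src, diff_dst);  // prints: reordering diff_dst native -> blocked
}

The real kernel additionally rewrites `src_md`/`diff_dst_md` so that both descriptors match after the reorder.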
@@ -1350,20 +1379,24 @@ class MklFusedBatchNormGradOp : public OpKernel {
       MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
           MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);
 
-      const T* src_data = src_tensor.flat<T>().data();
-      const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
       // Check if diff_dst input needs to be reordered
       std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
       if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) {
-        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+        diff_dst.SetUsrMem(diff_dst_md, diff_dst_data);
         diff_dst.CheckReorderToOpMem(
             MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd),
                                    cpu_engine_),
             context);
         diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
       } else {
         diff_dst_data =
             static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
       }
 
       if (IS_SRC_REORDER_NEEDED(src_md, bn_bwd_pd, bn_bwd)) {
         src.SetUsrMem(src_md, src_data);
         src.CheckReorderToOpMem(
             MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(bn_bwd_pd),
                                    cpu_engine_),
             context);
         src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
       }
 
       // Indices of output tensors
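The `IS_SRC_REORDER_NEEDED`/`IS_DIFF_DST_REORDER_NEEDED` checks above implement lazy reordering: user memory is copied into the primitive's preferred layout only when the two descriptors differ; otherwise the primitive consumes the caller's buffer directly. A compilable sketch of that pattern with hypothetical types (oneDNN compares full `memory::desc` objects rather than an enum):

#include <iostream>
#include <vector>

// Hypothetical layout tag; in oneDNN the comparison is between full memory
// descriptors (dnnl::memory::desc), not a simple enum.
enum class Layout { kNative, kBlocked };

struct UserMemory {
  Layout layout;
  std::vector<float> data;
};

// Stand-in for IS_SRC_REORDER_NEEDED / IS_DIFF_DST_REORDER_NEEDED: reorder
// only when the primitive was created for a different layout than the one
// the caller supplied.
bool ReorderNeeded(Layout user, Layout primitive_expects) {
  return user != primitive_expects;
}

// Stand-in for SetUsrMem + CheckReorderToOpMem: materialize a copy of the
// data in the expected layout and hand back a pointer into that copy.
const float* PrepareInput(const UserMemory& user, Layout expected,
                          UserMemory* scratch) {
  if (ReorderNeeded(user.layout, expected)) {
    *scratch = {expected, user.data};  // a real reorder would permute bytes
    return scratch->data.data();
  }
  return user.data.data();  // already in the right layout: zero-copy
}

int main() {
  UserMemory diff_dst{Layout::kNative, {1.f, 2.f, 3.f}};
  UserMemory scratch{Layout::kNative, {}};
  const float* ptr = PrepareInput(diff_dst, Layout::kBlocked, &scratch);
  std::cout << (ptr == diff_dst.data.data() ? "forwarded" : "reordered")
            << "\n";  // prints "reordered"
}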
@@ -46,6 +46,12 @@ using GraphRunner = std::function<void(
     const float exponential_avg_factor, const bool is_training, Tensor* output,
     Tensor* batch_mean, Tensor* batch_var)>;
 
+using GraphRunnerGrad = std::function<void(
+    const Tensor& input, const Tensor& filter, const Tensor& y_backprop,
+    const Tensor& scale, const Tensor& mean, const Tensor& variance,
+    const Tensor& res_sp3, Tensor* output, Tensor* scale_backprop,
+    Tensor* offset_backprop, bool disable_grappler_opts)>;
+
 template <typename T>
 class CommonTestUtilities : public OpsTestBase {
  public:
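`GraphRunnerGrad` is a plain `std::function` alias, so the test can pass either the default-TF or the MKL implementation through one signature. A minimal conforming stub (hypothetical, for illustration only; the real runners are defined in `VerifyFusedBatchNormGradWithConv2D` below):

// Hypothetical stub conforming to GraphRunnerGrad. A real runner builds and
// executes a FusedBatchNormGradV3 graph; this one only forwards shapes so a
// comparison harness can run end to end.
const GraphRunnerGrad stub_runner =
    [](const Tensor& input, const Tensor& filter, const Tensor& y_backprop,
       const Tensor& scale, const Tensor& mean, const Tensor& variance,
       const Tensor& res_sp3, Tensor* output, Tensor* scale_backprop,
       Tensor* offset_backprop, bool disable_grappler_opts) {
      *output = y_backprop;  // placeholder results, correct shapes only
      *scale_backprop = scale;
      *offset_backprop = scale;
    };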
@@ -118,10 +124,99 @@ class CommonTestUtilities : public OpsTestBase {
     test::ExpectClose(batch_var, mkl_batch_var, 1e-5);
   }
 
+  static void VerifyTensorsCloseForGrad(const float epsilon,
+                                        const GraphRunnerGrad& run,
+                                        const GraphRunnerGrad& run_mkl) {
+    int batch = 2;
+    int height = 8;
+    int width = 8;
+    int depth = 1;
+    int filter_height = 3;
+    int filter_width = 3;
+    int in_channels = 1;
+    int out_channels = 6;
+    DataType dtype = DataTypeToEnum<T>::v();
+
+    Tensor input(dtype, {batch, height, width, depth});
+    input.flat<T>() = input.flat<T>().template setRandom<random_gen_>();
+    Tensor filter(dtype,
+                  {filter_height, filter_width, in_channels, out_channels});
+    filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();
+
+    Tensor y_backprop(dtype, {batch, height, width, out_channels});
+    y_backprop.flat<T>() =
+        y_backprop.flat<T>().template setRandom<random_gen_>();
+    Tensor scale(dtype, {out_channels});
+    scale.flat<T>() = scale.flat<T>().template setRandom<random_gen_>();
+    Tensor mean(dtype, {out_channels});
+    mean.flat<T>() = mean.flat<T>().template setRandom<random_gen_>();
+    Tensor variance(dtype, {out_channels});
+    variance.flat<T>() =
+        variance.flat<T>().template setRandom<random_gen_>().abs();
+    Tensor res_sp3(dtype, {out_channels});
+    res_sp3.flat<T>() =
+        res_sp3.flat<T>().template setRandom<random_gen_>().abs();
+
+    Tensor output;
+    Tensor scale_backprop;
+    Tensor offset_backprop;
+    Tensor mkl_output;
+    Tensor mkl_scale_backprop;
+    Tensor mkl_offset_backprop;
+
+    run(input, filter, y_backprop, scale, mean, variance, res_sp3, &output,
+        &scale_backprop, &offset_backprop, epsilon);
+
+    run_mkl(input, filter, y_backprop, scale, mean, variance, res_sp3,
+            &mkl_output, &mkl_scale_backprop, &mkl_offset_backprop, epsilon);
+
+    ASSERT_EQ(output.dtype(), mkl_output.dtype());
+    ASSERT_EQ(output.shape(), mkl_output.shape());
+    ASSERT_EQ(scale_backprop.dtype(), mkl_scale_backprop.dtype());
+    ASSERT_EQ(scale_backprop.shape(), mkl_scale_backprop.shape());
+    ASSERT_EQ(offset_backprop.dtype(), mkl_offset_backprop.dtype());
+    ASSERT_EQ(offset_backprop.shape(), mkl_offset_backprop.shape());
+
+    test::ExpectClose(output, mkl_output, 1e-5);
+    test::ExpectClose(scale_backprop, mkl_scale_backprop, 1e-5);
+    test::ExpectClose(offset_backprop, mkl_offset_backprop, 1e-5);
+  }
+
  private:
   using random_gen_ = Eigen::internal::NormalRandomGenerator<T>;
 };
 
+template <typename T>
+class Conv2DOpTest : public OpsTestBase {
+  void TestBody() {}
+
+ public:
+  void RunConv2D(const Tensor& input, const Tensor& filter, Tensor* output,
+                 Tensor* meta_output) {
+    DataType dtype = DataTypeToEnum<T>::v();
+
+    TF_EXPECT_OK(NodeDefBuilder("MklConv2D", "_MklConv2D")
+                     .Input(FakeInput(dtype))
+                     .Input(FakeInput(dtype))
+                     .Input(FakeInput(DT_UINT8))
+                     .Input(FakeInput(DT_UINT8))
+                     .Attr("strides", {1, 1, 1, 1})
+                     .Attr("padding", "SAME")
+                     .Attr("data_format", "NHWC")
+                     .Attr("_kernel", "MklLayoutDependentOp")
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+    AddInputFromArray<T>(input.shape(), input.flat<T>());
+    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
+    for (int i = 0; i < 2; ++i)
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    TF_ASSERT_OK(RunOpKernel());
+
+    *output = *GetOutput(0);
+    *meta_output = *GetOutput(2);
+  }
+};
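`_Mkl*` kernels registered as `MklLayoutDependentOp` take one extra `DT_UINT8` metadata input per data input (carrying the MKL layout descriptor) and emit matching metadata outputs, which is why `RunConv2D` appends two dummy `uint8` inputs and reads output index 2 as the metadata for output 0. A sketch of the fixtures this relies on; the exact definitions of `dummy_shape`/`dummy_tensor` are assumed here, mirroring their usage in this file:

// Assumed definitions: a "no MKL layout attached" metadata input is just a
// small dummy uint8 tensor that the kernel inspects and otherwise ignores.
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
static const TensorShape dummy_shape({8});

// Usage inside a test body: one metadata input per data input of _MklConv2D.
// AddInputFromArray<uint8>(dummy_shape, dummy_tensor);  // metadata for input
// AddInputFromArray<uint8>(dummy_shape, dummy_tensor);  // metadata for filter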
+
 template <typename T>
 class FusedBatchNormOpTest : public OpsTestBase {
  protected:
@@ -198,11 +293,11 @@ class FusedBatchNormOpTest : public OpsTestBase {
                      .Finalize(node_def()));
     TF_EXPECT_OK(InitOp());
 
-    AddInputFromArray<float>(input.shape(), input.flat<T>());
-    AddInputFromArray<float>(scale.shape(), scale.flat<T>());
-    AddInputFromArray<float>(offset.shape(), offset.flat<T>());
-    AddInputFromArray<float>(mean.shape(), mean.flat<T>());
-    AddInputFromArray<float>(variance.shape(), variance.flat<T>());
+    AddInputFromArray<T>(input.shape(), input.flat<T>());
+    AddInputFromArray<float>(scale.shape(), scale.flat<float>());
+    AddInputFromArray<float>(offset.shape(), offset.flat<float>());
+    AddInputFromArray<float>(mean.shape(), mean.flat<float>());
+    AddInputFromArray<float>(variance.shape(), variance.flat<float>());
     for (int i = 0; i < 5; ++i)
       AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
     TF_ASSERT_OK(RunOpKernel());
@@ -222,6 +317,133 @@ class FusedBatchNormOpTest : public OpsTestBase {
     CommonTestUtilities<T>::VerifyTensorsClose(exponential_avg_factor,
                                                is_training, run, run_mkl);
   }
 
+  void VerifyFusedBatchNormGradWithConv2D(const float epsilon) {
+    const GraphRunnerGrad run =
+        [this](const Tensor& input, const Tensor& filter,
+               const Tensor& y_backprop, const Tensor& scale,
+               const Tensor& mean, const Tensor& variance,
+               const Tensor& res_sp3, Tensor* x_backprop_tensor,
+               Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor,
+               const float epsilon) {
+          auto root = tensorflow::Scope::NewRootScope();
+
+          auto input_op =
+              ops::Const(root.WithOpName("input"), Input::Initializer(input));
+          auto filter_op =
+              ops::Const(root.WithOpName("filter"), Input::Initializer(filter));
+          ops::Conv2D::Attrs conv_attr;
+          conv_attr = conv_attr.DataFormat("NHWC");
+          auto conv = ops::Conv2D(root.WithOpName("Conv"), input_op, filter_op,
+                                  {1, 1, 1, 1}, "SAME", conv_attr);
+          // -------------------------------------------------------------
+          auto y_backprop_op = ops::Const(root.WithOpName("y_backprop"),
+                                          Input::Initializer(y_backprop));
+          auto scale_op =
+              ops::Const(root.WithOpName("scale"), Input::Initializer(scale));
+          auto mean_op =
+              ops::Const(root.WithOpName("mean"), Input::Initializer(mean));
+          auto var_op = ops::Const(root.WithOpName("variance"),
+                                   Input::Initializer(variance));
+          auto res_sp3_op = ops::Const(root.WithOpName("reserve_space_3"),
+                                       Input::Initializer(res_sp3));
+          ops::FusedBatchNormGradV3::Attrs bn_attr;
+          bn_attr = bn_attr.IsTraining(true);
+          bn_attr = bn_attr.Epsilon(epsilon);
+          bn_attr = bn_attr.DataFormat("NHWC");
+          auto bn = ops::FusedBatchNormGradV3(
+              root.WithOpName("FusedBatchNormGrad"), y_backprop_op, conv,
+              scale_op, mean_op, var_op, res_sp3_op, bn_attr);
+
+          auto x_backprop =
+              ops::Identity(root.WithOpName("x_backprop"), bn.x_backprop);
+          auto scale_backprop = ops::Identity(root.WithOpName("scale_backprop"),
+                                              bn.scale_backprop);
+          auto offset_backprop = ops::Identity(
+              root.WithOpName("offset_backprop"), bn.offset_backprop);
+
+          tensorflow::GraphDef graph;
+          TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+          tensorflow::SessionOptions session_options;
+          std::unique_ptr<tensorflow::Session> session(
+              tensorflow::NewSession(session_options));
+          TF_ASSERT_OK(session->Create(graph));
+
+          std::vector<Tensor> output_tensors;
+          TF_ASSERT_OK(session->Run(
+              {}, {"x_backprop", "scale_backprop", "offset_backprop"}, {},
+              &output_tensors));
+
+          *x_backprop_tensor = output_tensors[0];
+          *scale_backprop_tensor = output_tensors[1];
+          *offset_backprop_tensor = output_tensors[2];
+        };
+
+    const GraphRunnerGrad run_mkl =
+        [this](const Tensor& input, const Tensor& filter,
+               const Tensor& y_backprop, const Tensor& scale,
+               const Tensor& mean, const Tensor& variance,
+               const Tensor& res_sp3, Tensor* x_backprop_tensor,
+               Tensor* scale_backprop_tensor, Tensor* offset_backprop_tensor,
+               const float epsilon) {
+          Tensor conv2d_output, conv2d_meta_output;
+          Conv2DOpTest<T> conv2d_test;
+          conv2d_test.RunConv2D(input, filter, &conv2d_output,
+                                &conv2d_meta_output);
+
+          DataType dtype = DataTypeToEnum<T>::v();
+          TF_EXPECT_OK(
+              NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNormGradV3")
+                  .Input(FakeInput(dtype))
+                  .Input(FakeInput(dtype))
+                  .Input(FakeInput(DT_FLOAT))
+                  .Input(FakeInput(DT_FLOAT))
+                  .Input(FakeInput(DT_FLOAT))
+                  .Input(FakeInput(DT_FLOAT))
+                  .Input(FakeInput(DT_UINT8))
+                  .Input(FakeInput(DT_UINT8))
+                  .Input(FakeInput(DT_UINT8))
+                  .Input(FakeInput(DT_UINT8))
+                  .Input(FakeInput(DT_UINT8))
+                  .Input(FakeInput(DT_UINT8))
+                  .Attr("epsilon", epsilon)
+                  .Attr("is_training", true)
+                  .Attr("data_format", "NHWC")
+                  .Attr("_kernel", "MklLayoutDependentOp")
+                  .Finalize(node_def()));
+          TF_EXPECT_OK(InitOp());
+
+          AddInputFromArray<T>(y_backprop.shape(), y_backprop.flat<T>());
+          AddInputFromArray<T>(conv2d_output.shape(), conv2d_output.flat<T>());
+          AddInputFromArray<float>(scale.shape(), scale.flat<float>());
+          AddInputFromArray<float>(mean.shape(), mean.flat<float>());
+          AddInputFromArray<float>(variance.shape(), variance.flat<float>());
+          AddInputFromArray<float>(res_sp3.shape(), res_sp3.flat<float>());
+          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+          AddInputFromArray<uint8>(conv2d_meta_output.shape(),
+                                   conv2d_meta_output.flat<uint8>());
+          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+          AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+          TF_ASSERT_OK(RunOpKernel());
+
+          CommonTestUtilities<T> test_util;
+          test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5),
+                                      x_backprop_tensor);
+
+          CommonTestUtilities<T> test_util_mean;
+          test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6),
+                                           scale_backprop_tensor);
+
+          CommonTestUtilities<T> test_util_var;
+          test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7),
+                                          offset_backprop_tensor);
+        };
+
+    CommonTestUtilities<T>::VerifyTensorsCloseForGrad(epsilon, run, run_mkl);
+  }
 };
 
 TYPED_TEST_SUITE_P(FusedBatchNormOpTest);
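`PerformConversion` is defined earlier in this test file and is not part of the diff; from its call sites above, it pairs a data output with its `DT_UINT8` metadata output and converts blocked-layout results back to plain TF layout. A sketch of what such a helper plausibly looks like, assuming the `_MklToTf` conversion op; treat the exact builder calls as an approximation, not the file's verbatim code:

// Sketch (assumed) of a PerformConversion-style helper on OpsTestBase: run
// the data tensor and its MKL layout metadata through _MklToTf so that the
// result is a native-layout tensor that test::ExpectClose can compare.
void PerformConversionSketch(DataType dtype, const Tensor& tensor,
                             const Tensor& mkl_meta_tensor, Tensor* output) {
  TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
                   .Input(FakeInput(dtype))     // data output of the Mkl op
                   .Input(FakeInput(DT_UINT8))  // its layout metadata
                   .Attr("T", dtype)
                   .Attr("_kernel", "MklLayoutDependentOp")
                   .Finalize(node_def()));
  TF_EXPECT_OK(InitOp());
  AddInputFromArray<float>(tensor.shape(), tensor.flat<float>());
  AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
                           mkl_meta_tensor.flat<uint8>());
  TF_ASSERT_OK(RunOpKernel());
  *output = *GetOutput(0);  // now in native (TF) layout
}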
@@ -250,8 +472,14 @@ TYPED_TEST_P(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
   this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
 }
 
+TYPED_TEST_P(FusedBatchNormOpTest, FusedBatchNormGradV3) {
+  const float epsilon = 0.001;
+  this->VerifyFusedBatchNormGradWithConv2D(epsilon);
+}
+
 REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
-                            Inference, InferenceIgnoreAvgFactor);
+                            Inference, InferenceIgnoreAvgFactor,
+                            FusedBatchNormGradV3);
 
 using FusedBatchNormDataTypes = ::testing::Types<float>;
 INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormOpTest,
@@ -566,6 +566,7 @@ class MklReluOpBase : public OpKernel {
       std::shared_ptr<stream> fwd_cpu_stream;
       fwd_cpu_stream.reset(CreateStream(context, eltwise_fwd->GetEngine()));
       // Check if src needs to be reordered
+      bool is_src_reordered = false;
       const T* src_data = src_tensor.flat<T>().data();
       if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
         src.SetUsrMem(src_md, &src_tensor);
@@ -575,27 +576,48 @@ class MklReluOpBase : public OpKernel {
             context);
         src_data = const_cast<T*>(
             reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+        is_src_reordered = true;
       }
       // Allocate dst tensor, always set it as MKL-DNN layout
-      if (dnn_shape_src.IsMklTensor()) {
+
+      // If src is reordered, then the dst tensor would be in blocked layout,
+      // so we propagate the blocked layout on the output. We follow the same
+      // logic when src starts out in blocked (MKL) layout as well.
+      if (is_src_reordered || dnn_shape_src.IsMklTensor()) {
         dnn_shape_dst.SetMklTensor(true);
         auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST;
         dnn_shape_dst.SetMklLayout(&dst_pd);
         dnn_shape_dst.SetElemType(MklDnnType<T>());
-        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
-                                  dnn_shape_src.GetSizesAsMklDnnDims(),
-                                  dnn_shape_src.GetTfDataFormat());
+        if (dnn_shape_src.IsMklTensor()) {
+          dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
+                                    dnn_shape_src.GetSizesAsMklDnnDims(),
+                                    dnn_shape_src.GetTfDataFormat());
+        } else {
+          dnn_shape_dst.SetTfLayout(src_tensor.dims(),
+                                    TFShapeToMklDnnDims(src_tensor.shape()),
+                                    MKL_TENSOR_FORMAT_BLOCKED);
+        }
         tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
       } else {
+        // If src is not in blocked layout and was not reordered, then dst
+        // stays in native layout.
         dnn_shape_dst.SetMklTensor(false);
         tf_shape_dst = src_tensor.shape();
       }
-      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                  {static_cast<const int>(src_index)},
-                                  static_cast<const int>(dst_index),
-                                  tf_shape_dst, &dst_tensor));
-      AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
+
+      if (is_src_reordered) {
+        // If src was reordered, then src and dst would be in different
+        // layouts, so allocate a fresh output instead of forwarding.
+        AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
+                                  dnn_shape_dst);
+      } else {
+        // Forwarding the input to the output works only when the layouts of
+        // the src and dst tensors remain the same -- either both in native
+        // layout or both in blocked (MKL) layout.
+        OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+                                    {static_cast<const int>(src_index)},
+                                    static_cast<const int>(dst_index),
+                                    tf_shape_dst, &dst_tensor));
+        AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
+      }
       T* dst_data = dst_tensor->flat<T>().data();
 
       // execute eltwise
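The rewritten allocation logic encodes a buffer-reuse rule: the input buffer may be forwarded as the output only when src and dst share a layout; once a reorder has happened, the output needs its own allocation. A minimal sketch of the rule with hypothetical stand-ins for the TF allocator calls:

#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor buffer with an associated layout flag,
// mirroring dnn_shape_*.IsMklTensor() in the kernel above.
struct Buffer {
  bool blocked;
  std::vector<float> data;
};

// Mirrors the forward-or-allocate decision: reuse the input buffer in place
// only if the src and dst layouts agree; otherwise the output needs its own
// allocation in the dst layout.
std::shared_ptr<Buffer> GetOutputBuffer(std::shared_ptr<Buffer> src,
                                        bool src_was_reordered,
                                        bool dst_blocked) {
  if (!src_was_reordered && src->blocked == dst_blocked) {
    // Forward: in-place reuse, like forward_input_or_allocate_output.
    return src;
  }
  // Layouts differ: allocate fresh, like AllocateOutputSetMklShape with a
  // destination tensor argument.
  auto dst = std::make_shared<Buffer>();
  dst->blocked = dst_blocked;
  dst->data.resize(src->data.size());
  return dst;
}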
@@ -512,9 +512,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
     tol = 5e-2 if mode == 'mkl' else 1e-3
     self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
 
-  # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
-  # MKL
-  @parameterized.parameters(['cuda'])
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
   def test_conv_bn_dropout(self, mode):
@@ -545,6 +543,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
     # The default tolerance (1e-3) results in a tiny fraction (<1%) of
     # miscompares on ROCm platform, and hence the tolerance bump
     tol = 2e-3 if test.is_built_with_rocm else 1e-3
+    tol = 5e-2 if mode == 'mkl' else tol
     self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
 
   # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with