Add a remapper pattern for FusedBatchNorm[is_training=False].
tf_cnn_benchmarks --forward_only=true --model=resnet50: fp32: 1010 images/sec -> 1145 images/sec fp16: 1990 images/sec -> 3010 images/sec PiperOrigin-RevId: 253610151
This commit is contained in:
parent
ce59aa0be8
commit
d73778d205
@ -761,8 +761,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
|||||||
if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
|
if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
|
||||||
.ok())
|
.ok())
|
||||||
return false;
|
return false;
|
||||||
// TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
|
|
||||||
if (!is_training) return false;
|
|
||||||
|
|
||||||
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
|
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
|
||||||
// inputs and activation, and it has its own limitations. In inference mode
|
// inputs and activation, and it has its own limitations. In inference mode
|
||||||
@ -796,7 +794,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
|||||||
!HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
|
!HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check that only one node consumes the output of a FusedBatchNorm.
|
// Check that only one node consumes the 0-th output of a FusedBatchNorm.
|
||||||
if (HasControlFaninOrFanout(fused_batch_norm) ||
|
if (HasControlFaninOrFanout(fused_batch_norm) ||
|
||||||
!HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
|
!HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
|
||||||
IsInPreserveSet(ctx, fused_batch_norm_node_def))
|
IsInPreserveSet(ctx, fused_batch_norm_node_def))
|
||||||
@ -1216,11 +1214,14 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
|
|||||||
const NodeDef& activation = graph->node(matched.activation);
|
const NodeDef& activation = graph->node(matched.activation);
|
||||||
|
|
||||||
VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
|
VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
|
||||||
<< " side_input="
|
<< " activation=" << activation.name() << " side_input="
|
||||||
<< (matched.side_input != kMissingIndex
|
<< (matched.side_input != kMissingIndex
|
||||||
? graph->node(matched.side_input).name()
|
? graph->node(matched.side_input).name()
|
||||||
: "<none>")
|
: "<none>")
|
||||||
<< " activation=" << activation.name()
|
<< " invalidated="
|
||||||
|
<< (matched.invalidated != kMissingIndex
|
||||||
|
? graph->node(matched.invalidated).name()
|
||||||
|
: "<none>")
|
||||||
<< " fused_batch_norm=" << fused_batch_norm.name();
|
<< " fused_batch_norm=" << fused_batch_norm.name();
|
||||||
|
|
||||||
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
|
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
|
||||||
@ -1512,23 +1513,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
|||||||
|
|
||||||
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||||
GraphDef* optimized_graph) {
|
GraphDef* optimized_graph) {
|
||||||
// Supported graph patterns.
|
|
||||||
// clang-format off
|
|
||||||
FusedBatchNorm fused_batch_norm;
|
|
||||||
FusedBatchNormEx fused_batch_norm_ex;
|
|
||||||
ContractionWithBiasAdd contract_with_bias;
|
|
||||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
|
||||||
#ifndef INTEL_MKL
|
|
||||||
ContractionWithBatchNorm contract_with_batch_norm;
|
|
||||||
ContractionWithBatchNormAndActivation contract_with_batch_norm_and_activation;
|
|
||||||
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
|
||||||
#endif // !INTEL_MKL
|
|
||||||
#ifdef INTEL_MKL
|
|
||||||
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
|
|
||||||
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
|
|
||||||
#endif // INTEL_MKL
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
GrapplerItem mutable_item = item;
|
GrapplerItem mutable_item = item;
|
||||||
Status status;
|
Status status;
|
||||||
RemapperContext ctx(&mutable_item, &status);
|
RemapperContext ctx(&mutable_item, &status);
|
||||||
@ -1556,6 +1540,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
|
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
|
||||||
|
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
|
||||||
|
|
||||||
if (!item.optimization_options().is_eager_mode) {
|
if (!item.optimization_options().is_eager_mode) {
|
||||||
// Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
|
// Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
|
||||||
if (FindContractionWithBiasAndAddActivation(
|
if (FindContractionWithBiasAndAddActivation(
|
||||||
@ -1578,6 +1565,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
#endif //! INTEL_MKL
|
#endif //! INTEL_MKL
|
||||||
|
|
||||||
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
||||||
|
ContractionWithBiasAdd contract_with_bias;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
||||||
TF_RETURN_IF_ERROR(AddFusedContractionNode(
|
TF_RETURN_IF_ERROR(AddFusedContractionNode(
|
||||||
@ -1586,6 +1574,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
||||||
|
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindContractionWithBiasAndActivation(
|
FindContractionWithBiasAndActivation(
|
||||||
ctx, i, &contract_with_bias_and_activation)) {
|
ctx, i, &contract_with_bias_and_activation)) {
|
||||||
@ -1603,6 +1592,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
||||||
#ifndef INTEL_MKL
|
#ifndef INTEL_MKL
|
||||||
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
||||||
|
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
|
FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -1612,6 +1602,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
|
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
|
||||||
|
ContractionWithBatchNorm contract_with_batch_norm;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
|
FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
|
||||||
TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
|
TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
|
||||||
@ -1621,6 +1612,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
|
// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
|
||||||
|
ContractionWithBatchNormAndActivation
|
||||||
|
contract_with_batch_norm_and_activation;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindConv2DWithBatchNormAndActivation(
|
FindConv2DWithBatchNormAndActivation(
|
||||||
ctx, i, &contract_with_batch_norm_and_activation)) {
|
ctx, i, &contract_with_batch_norm_and_activation)) {
|
||||||
@ -1644,6 +1637,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
|
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
|
||||||
|
FusedBatchNormEx fused_batch_norm_ex;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
|
FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
|
||||||
TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
|
TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
|
||||||
@ -1653,6 +1647,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
|
|
||||||
// During inference, most of the inputs to FusedBatchNorm are constant, and
|
// During inference, most of the inputs to FusedBatchNorm are constant, and
|
||||||
// we can therefore replace the op with a much cheaper set of primitives.
|
// we can therefore replace the op with a much cheaper set of primitives.
|
||||||
|
FusedBatchNorm fused_batch_norm;
|
||||||
if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
|
if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
|
||||||
TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
|
TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
|
||||||
continue;
|
continue;
|
||||||
|
@ -112,42 +112,62 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
|
|||||||
ASSERT_EQ(tensors_expected.size(), 1);
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
ASSERT_EQ(tensors.size(), 1);
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-5);
|
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
||||||
using ::tensorflow::ops::Placeholder;
|
using ::tensorflow::ops::Placeholder;
|
||||||
|
|
||||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
for (bool is_training : {true, false}) {
|
||||||
LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
|
|
||||||
"CUDNN_VERSION >= 7402.";
|
|
||||||
#else
|
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||||
auto channels_shape = ops::Placeholder::Shape({24});
|
if (is_training) {
|
||||||
|
LOG(INFO) << "Skip FuseBatchNormWithRelu"
|
||||||
|
<< "[is_training=" << is_training << "] "
|
||||||
|
<< "test. It requires CUDNN_VERSION >= 7402.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
#if !defined(GOOGLE_CUDA)
|
||||||
|
if (!is_training) {
|
||||||
|
LOG(INFO) << "Skip FuseBatchNormWithRelu"
|
||||||
|
<< "[is_training=" << is_training << "]";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const int num_channels = 24;
|
||||||
|
|
||||||
|
TensorShape channel_shape({num_channels});
|
||||||
|
TensorShape empty_shape({0});
|
||||||
|
|
||||||
|
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape({2, 8, 8, num_channels}));
|
||||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
|
||||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
|
||||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
|
||||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
|
||||||
|
|
||||||
float epsilon = 0.1f;
|
float epsilon = 0.1f;
|
||||||
auto fbn = ops::FusedBatchNormV3(
|
auto fbn = ops::FusedBatchNormV3(
|
||||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||||
"NHWC"));
|
.Epsilon(epsilon)
|
||||||
|
.DataFormat("NHWC"));
|
||||||
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
||||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||||
|
|
||||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
|
||||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
: channel_shape);
|
||||||
|
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
|
: channel_shape);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch"};
|
item.fetch = {"fetch"};
|
||||||
@ -199,47 +219,67 @@ TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
|||||||
ASSERT_EQ(tensors.size(), 1);
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||||
}
|
}
|
||||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
||||||
using ::tensorflow::ops::Placeholder;
|
using ::tensorflow::ops::Placeholder;
|
||||||
|
|
||||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
for (bool is_training : {true, false}) {
|
||||||
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
|
|
||||||
"CUDNN_VERSION >= 7402.";
|
|
||||||
#else
|
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||||
auto channels_shape = ops::Placeholder::Shape({24});
|
if (is_training) {
|
||||||
|
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
|
||||||
|
<< "[is_training=" << is_training << "] "
|
||||||
|
<< "test. It requires CUDNN_VERSION >= 7402.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
#if !defined(GOOGLE_CUDA)
|
||||||
|
if (!is_training) {
|
||||||
|
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
|
||||||
|
<< "[is_training=" << is_training << "]";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const int num_channels = 24;
|
||||||
|
|
||||||
|
TensorShape input_shape({2, 8, 8, num_channels});
|
||||||
|
TensorShape channel_shape({num_channels});
|
||||||
|
TensorShape empty_shape({0});
|
||||||
|
|
||||||
|
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape(input_shape));
|
||||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
|
||||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
|
||||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
|
||||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
|
||||||
auto side_input =
|
auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT);
|
||||||
Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
|
|
||||||
auto side_input_cast =
|
auto side_input_cast =
|
||||||
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
||||||
|
|
||||||
float epsilon = 0.1f;
|
float epsilon = 0.1f;
|
||||||
auto fbn = ops::FusedBatchNormV3(
|
auto fbn = ops::FusedBatchNormV3(
|
||||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||||
"NHWC"));
|
.Epsilon(epsilon)
|
||||||
|
.DataFormat("NHWC"));
|
||||||
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
||||||
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
||||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||||
|
|
||||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
|
||||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
: channel_shape);
|
||||||
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
|
: channel_shape);
|
||||||
|
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch"};
|
item.fetch = {"fetch"};
|
||||||
@ -290,7 +330,7 @@ TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
|||||||
ASSERT_EQ(tensors.size(), 1);
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||||
}
|
}
|
||||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RemapperTest, FuseConv2DWithBias) {
|
TEST_F(RemapperTest, FuseConv2DWithBias) {
|
||||||
|
Loading…
Reference in New Issue
Block a user