Minor follow-up fixes for PR #37176.
PiperOrigin-RevId: 298389754 Change-Id: I7a32fea4fd8eb6d10e3c5dd32e341b1e9ad2f808
This commit is contained in:
parent
edcafbcefc
commit
c877732daf
@ -881,6 +881,9 @@ class MklFusedBatchNormOp : public OpKernel {
|
||||
auto batch_variance_data = batch_variance_tensor->flat<U>().data();
|
||||
auto est_mean_data = est_mean_tensor.flat<U>().data();
|
||||
auto est_variance_data = est_variance_tensor.flat<U>().data();
|
||||
|
||||
// TODO(intel-tf): Merge the `is_training && exponential_avg_factor == 1`
|
||||
// case with the `else` (`!is_training`) case if possible.
|
||||
if (is_training_) {
|
||||
if (exponential_avg_factor_ == U(1.0)) {
|
||||
for (int k = 0; k < depth_; k++) {
|
||||
|
@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
@ -222,7 +224,7 @@ class FusedBatchNormOpTest : public OpsTestBase {
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE_P(FusedBatchNormOpTest);
|
||||
TYPED_TEST_SUITE_P(FusedBatchNormOpTest);
|
||||
|
||||
TYPED_TEST_P(FusedBatchNormOpTest, Training) {
|
||||
const float exponential_avg_factor = 1.0;
|
||||
@ -248,12 +250,13 @@ TYPED_TEST_P(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
|
||||
this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_CASE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
|
||||
Inference, InferenceIgnoreAvgFactor);
|
||||
REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
|
||||
Inference, InferenceIgnoreAvgFactor);
|
||||
|
||||
using FusedBatchNormDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(Test, FusedBatchNormOpTest,
|
||||
FusedBatchNormDataTypes);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormOpTest,
|
||||
FusedBatchNormDataTypes);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
||||
|
Loading…
x
Reference in New Issue
Block a user