Roll forward of https://github.com/tensorflow/tensorflow/pull/42970 with fix for 5D tensors and unsupported axis that cannot use fused batch norm.
PiperOrigin-RevId: 335071264 Change-Id: Id9ab44bdba870336f03522e1ca76d88b1f305a10
This commit is contained in:
parent
eed3ab97e5
commit
27d26a8d86
@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
|
||||
}
|
||||
|
||||
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
|
||||
string data_format_str;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
|
||||
TensorFormat data_format;
|
||||
if (!FormatFromString(data_format_str, &data_format)) {
|
||||
return errors::InvalidArgument("Invalid data format string: ",
|
||||
data_format_str);
|
||||
}
|
||||
const int rank =
|
||||
(data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
|
||||
ShapeHandle x;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
|
||||
|
||||
bool is_training;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
|
||||
@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
|
||||
exponential_avg_factor = 1.0f; // default value
|
||||
}
|
||||
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
|
||||
string data_format_str;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
|
||||
TensorFormat data_format;
|
||||
if (!FormatFromString(data_format_str, &data_format)) {
|
||||
return errors::InvalidArgument("Invalid data format string: ",
|
||||
data_format_str);
|
||||
}
|
||||
int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
|
||||
|
||||
int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
|
||||
DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
|
||||
|
||||
// covers scale, offset, and if is_training is false, mean, variance
|
||||
@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
|
||||
}
|
||||
|
||||
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
|
||||
ShapeHandle y_backprop;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
|
||||
ShapeHandle x;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
|
||||
|
||||
bool is_training;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
|
||||
string data_format_str;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
|
||||
TensorFormat data_format;
|
||||
@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
|
||||
return errors::InvalidArgument("Invalid data format string: ",
|
||||
data_format_str);
|
||||
}
|
||||
int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
|
||||
const int rank =
|
||||
(data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
|
||||
ShapeHandle y_backprop;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
|
||||
ShapeHandle x;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
|
||||
|
||||
bool is_training;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
|
||||
|
||||
int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
|
||||
DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
|
||||
|
@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
|
||||
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
std::string src_format = context->src_format;
|
||||
std::string dst_format = context->dst_format;
|
||||
// Update the format from 4D to 5D layout if necessary.
|
||||
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
|
||||
if (allow_5d) {
|
||||
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||
dst_format_3d);
|
||||
}
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
|
||||
// Change back to the original layout due to early exit.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
|
||||
TF_RETURN_IF_ERROR(UpdateNode(context, node));
|
||||
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
// Change back the format from 5D to 4D layout.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
@ -881,8 +904,26 @@ bool FusedBatchNormGradTransposer::IsTraining(
|
||||
Status FusedBatchNormGradTransposer::TransposeNode(
|
||||
TransposeContext* context, utils::MutableNodeView* node) {
|
||||
DCHECK(IsFusedBatchNormGrad(*node->node()));
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
|
||||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
|
||||
const auto& shape = output_shape_attr->list().shape(0);
|
||||
const int rank = shape.dim_size();
|
||||
std::string src_format = context->src_format;
|
||||
std::string dst_format = context->dst_format;
|
||||
// Update the format from 4D to 5D layout if necessary.
|
||||
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
|
||||
if (allow_5d) {
|
||||
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
|
||||
dst_format_3d);
|
||||
}
|
||||
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
|
||||
!IsTraining(*node)) {
|
||||
// Change back to the original layout due to early exit.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
|
||||
@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
|
||||
TF_RETURN_IF_ERROR(
|
||||
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
|
||||
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
|
||||
// Change back the format from 5D to 4D layout.
|
||||
if (allow_5d) {
|
||||
context->AssignDeviceAndDataFormats(context->target_device, src_format,
|
||||
dst_format);
|
||||
}
|
||||
return context->graph_view->GetMutationBuilder()->Apply();
|
||||
}
|
||||
|
||||
|
@ -1438,29 +1438,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
|
||||
Status status;
|
||||
|
||||
if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
|
||||
string x_format = fused_node.attr().at(kDataFormat).s();
|
||||
if (x_format == "NCHW" or x_format == "NCDHW") {
|
||||
// Need to reshape the last 4 inputs
|
||||
NodeDef new_shape;
|
||||
const string new_shape_name =
|
||||
AddPrefixToNodeName("NCHWShape", fused_node.name());
|
||||
AddPrefixToNodeName(x_format + "Shape", fused_node.name());
|
||||
new_shape.set_name(new_shape_name);
|
||||
new_shape.set_op("Const");
|
||||
new_shape.set_device(fused_node.device());
|
||||
*new_shape.add_input() = AsControlDependency(scale);
|
||||
(*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
|
||||
Tensor t(DT_INT32, {4});
|
||||
t.flat<int32>()(0) = 1;
|
||||
t.flat<int32>()(1) = -1;
|
||||
t.flat<int32>()(2) = 1;
|
||||
t.flat<int32>()(3) = 1;
|
||||
t.AsProtoTensorContent(
|
||||
(*new_shape.mutable_attr())["value"].mutable_tensor());
|
||||
if (x_format == "NCHW") {
|
||||
Tensor t(DT_INT32, {4});
|
||||
t.flat<int32>()(0) = 1;
|
||||
t.flat<int32>()(1) = -1;
|
||||
t.flat<int32>()(2) = 1;
|
||||
t.flat<int32>()(3) = 1;
|
||||
t.AsProtoTensorContent(
|
||||
(*new_shape.mutable_attr())["value"].mutable_tensor());
|
||||
} else {
|
||||
Tensor t(DT_INT32, {5});
|
||||
t.flat<int32>()(0) = 1;
|
||||
t.flat<int32>()(1) = -1;
|
||||
t.flat<int32>()(2) = 1;
|
||||
t.flat<int32>()(3) = 1;
|
||||
t.flat<int32>()(4) = 1;
|
||||
t.AsProtoTensorContent(
|
||||
(*new_shape.mutable_attr())["value"].mutable_tensor());
|
||||
}
|
||||
mutation->AddNode(std::move(new_shape), &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
NodeDef reshaped_scale;
|
||||
reshaped_scale.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
|
||||
reshaped_scale.set_op("Reshape");
|
||||
reshaped_scale.set_device(fused_node.device());
|
||||
*reshaped_scale.add_input() = scale;
|
||||
@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
|
||||
NodeDef reshaped_offset;
|
||||
reshaped_offset.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
|
||||
reshaped_offset.set_op("Reshape");
|
||||
reshaped_offset.set_device(fused_node.device());
|
||||
*reshaped_offset.add_input() = offset;
|
||||
@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
|
||||
NodeDef reshaped_mean;
|
||||
reshaped_mean.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
|
||||
reshaped_mean.set_op("Reshape");
|
||||
reshaped_mean.set_device(fused_node.device());
|
||||
*reshaped_mean.add_input() = mean;
|
||||
@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
|
||||
|
||||
NodeDef reshaped_variance;
|
||||
reshaped_variance.set_name(
|
||||
AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
|
||||
AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
|
||||
reshaped_variance.set_op("Reshape");
|
||||
reshaped_variance.set_device(fused_node.device());
|
||||
*reshaped_variance.add_input() = variance;
|
||||
|
@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
// If use_reserved_space is false, we don't have 5th output.
|
||||
virtual void ComputeWithReservedSpace(OpKernelContext* context,
|
||||
bool use_reserved_space) {
|
||||
const Tensor& x = context->input(0);
|
||||
Tensor x = context->input(0);
|
||||
const Tensor& scale = context->input(1);
|
||||
const Tensor& offset = context->input(2);
|
||||
const Tensor& estimated_mean = context->input(3);
|
||||
const Tensor& estimated_variance = context->input(4);
|
||||
const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
|
||||
|
||||
OP_REQUIRES(context, x.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
|
||||
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||
x.shape().DebugString()));
|
||||
OP_REQUIRES(context, scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
context, estimated_variance.dims() == 1,
|
||||
errors::InvalidArgument("estimated_variance must be 1-dimensional",
|
||||
estimated_variance.shape().DebugString()));
|
||||
bool use_reshape = (x.dims() == 5);
|
||||
auto x_shape = x.shape();
|
||||
TensorShape dest_shape;
|
||||
if (use_reshape) {
|
||||
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
|
||||
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
|
||||
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
|
||||
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
|
||||
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
|
||||
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
|
||||
{{in_planes, in_rows * in_cols}}, in_depth);
|
||||
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
|
||||
if (has_side_input_) {
|
||||
OP_REQUIRES(context, side_input->shape() == x.shape(),
|
||||
errors::InvalidArgument(
|
||||
@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
}
|
||||
|
||||
Tensor* y = nullptr;
|
||||
auto alloc_shape = use_reshape ? dest_shape : x_shape;
|
||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||
{0}, 0, x.shape(), &y));
|
||||
{0}, 0, alloc_shape, &y));
|
||||
|
||||
Tensor* batch_mean = nullptr;
|
||||
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
|
||||
{3}, 1, scale.shape(), &batch_mean));
|
||||
@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
|
||||
tensor_format_, use_reserved_space);
|
||||
}
|
||||
if (use_reshape) {
|
||||
OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
|
||||
virtual void ComputeWithReservedSpace(OpKernelContext* context,
|
||||
bool use_reserved_space) {
|
||||
const Tensor& y_backprop = context->input(0);
|
||||
const Tensor& x = context->input(1);
|
||||
Tensor y_backprop = context->input(0);
|
||||
Tensor x = context->input(1);
|
||||
const Tensor& scale = context->input(2);
|
||||
// When is_training=True, batch mean and variance/inverted variance are
|
||||
// saved in the forward pass to be reused here. When is_training=False,
|
||||
@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
// saves inverted variance.
|
||||
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
|
||||
|
||||
OP_REQUIRES(context, y_backprop.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
|
||||
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||
y_backprop.shape().DebugString()));
|
||||
OP_REQUIRES(context, x.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
|
||||
errors::InvalidArgument("input must be 4 or 5-dimensional",
|
||||
x.shape().DebugString()));
|
||||
OP_REQUIRES(context, scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
errors::InvalidArgument(
|
||||
"saved variance must be 1-dimensional",
|
||||
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
|
||||
bool use_reshape = (x.dims() == 5);
|
||||
auto x_shape = x.shape();
|
||||
TensorShape dest_shape;
|
||||
if (use_reshape) {
|
||||
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
|
||||
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
|
||||
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
|
||||
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
|
||||
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
|
||||
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
|
||||
{{in_planes, in_rows * in_cols}}, in_depth);
|
||||
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
|
||||
Tensor* x_backprop = nullptr;
|
||||
auto alloc_shape = use_reshape ? dest_shape : x_shape;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, x.shape(), &x_backprop));
|
||||
context->allocate_output(0, alloc_shape, &x_backprop));
|
||||
|
||||
const TensorShape& scale_offset_shape = scale.shape();
|
||||
Tensor* scale_backprop = nullptr;
|
||||
@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
offset_backprop, use_reserved_space, tensor_format_);
|
||||
} else {
|
||||
// Necessary layout conversion is currently done in python.
|
||||
CHECK(tensor_format_ == FORMAT_NHWC)
|
||||
<< "The implementation of FusedBatchNormGrad with is_training=False "
|
||||
"only support "
|
||||
<< "NHWC tensor format for now.";
|
||||
OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
|
||||
errors::InvalidArgument(
|
||||
"The implementation of "
|
||||
"FusedBatchNormGrad with is_training=False only support "
|
||||
"NHWC tensor format for now."));
|
||||
functor::FusedBatchNormFreezeGrad<Device, T, U>()(
|
||||
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
||||
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
|
||||
offset_backprop);
|
||||
}
|
||||
if (use_reshape) {
|
||||
OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
|
||||
errors::InvalidArgument("Error during tensor copy."));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr("exponential_avg_factor: float = 1.0")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr(GetConvnetDataFormat2D3DAttrString())
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn(shape_inference::FusedBatchNormV3Shape);
|
||||
|
||||
@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
|
||||
.Attr("T: {half, bfloat16, float}")
|
||||
.Attr("U: {float}")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr(GetConvnetDataFormat2D3DAttrString())
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
|
||||
// --------------------------------------------------------------------------
|
||||
|
@ -1275,6 +1275,94 @@ class LayoutOptimizerTest(test.TestCase):
|
||||
self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
|
||||
self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
|
||||
self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testBatchNorm3D(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
|
||||
filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
|
||||
strides_val = [1, 1, 1, 1, 1]
|
||||
scale = constant_op.constant(0.1, shape=[3])
|
||||
offset = constant_op.constant(0.3, shape=[3])
|
||||
conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
|
||||
y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC')
|
||||
output = array_ops.identity(y)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
output_val = sess.run(output, run_metadata=metadata)
|
||||
|
||||
nodes = []
|
||||
num_transposes = 0
|
||||
for node in metadata.cost_graph.node:
|
||||
if _is_transpose(node.name):
|
||||
num_transposes += 1
|
||||
nodes.append(node.name)
|
||||
|
||||
expected_num_transposes = 2
|
||||
self.assertEqual(expected_num_transposes, num_transposes)
|
||||
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
|
||||
self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes)
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testBatchNormGrad3D(self):
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
random_seed.set_random_seed(0)
|
||||
x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
|
||||
filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
|
||||
strides_val = [1, 1, 1, 1, 1]
|
||||
scale = constant_op.constant(0.1, shape=[3])
|
||||
offset = constant_op.constant(0.3, shape=[3])
|
||||
mean = constant_op.constant(0.1, shape=[3])
|
||||
variance = constant_op.constant(0.3, shape=[3])
|
||||
conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
|
||||
y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3(
|
||||
conv3d,
|
||||
scale,
|
||||
offset,
|
||||
mean,
|
||||
variance,
|
||||
epsilon=1.001e-5,
|
||||
exponential_avg_factor=1.0,
|
||||
data_format='NDHWC',
|
||||
is_training=True,
|
||||
name='batch_norm')
|
||||
dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3(
|
||||
y,
|
||||
x_3d,
|
||||
scale,
|
||||
r0,
|
||||
r1,
|
||||
r2,
|
||||
epsilon=1.001e-5,
|
||||
data_format='NDHWC',
|
||||
is_training=True)
|
||||
output = array_ops.identity(dx)
|
||||
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
output_val_ref = sess.run(output)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
metadata = config_pb2.RunMetadata()
|
||||
output_val = sess.run(output, run_metadata=metadata)
|
||||
|
||||
nodes = []
|
||||
num_transposes = 0
|
||||
for node in metadata.cost_graph.node:
|
||||
if _is_transpose(node.name):
|
||||
num_transposes += 1
|
||||
nodes.append(node.name)
|
||||
|
||||
expected_num_transposes = 3
|
||||
self.assertEqual(expected_num_transposes, num_transposes)
|
||||
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
|
||||
self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes)
|
||||
self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes)
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
|
@ -330,13 +330,13 @@ class BatchNormalizationBase(Layer):
|
||||
# output back to its original shape accordingly.
|
||||
if self._USE_V2_BEHAVIOR:
|
||||
if self.fused is None:
|
||||
self.fused = (ndims == 4)
|
||||
elif self.fused and ndims != 4:
|
||||
self.fused = ndims in (4, 5)
|
||||
elif self.fused and ndims not in (4, 5):
|
||||
raise ValueError('Batch normalization layers with fused=True only '
|
||||
'support 4D input tensors.')
|
||||
'support 4D or 5D input tensors.')
|
||||
else:
|
||||
assert self.fused is not None
|
||||
self.fused = (ndims == 4 and self._fused_can_be_used())
|
||||
self.fused = (ndims in (4, 5) and self._fused_can_be_used())
|
||||
# TODO(chrisying): fused batch norm is currently not supported for
|
||||
# multi-axis batch norm and by extension virtual batches. In some cases,
|
||||
# it might be possible to use fused batch norm but would require reshaping
|
||||
@ -345,13 +345,22 @@ class BatchNormalizationBase(Layer):
|
||||
# common use case (turning 5D w/ virtual batch to NCHW)
|
||||
|
||||
if self.fused:
|
||||
if self.axis == [1]:
|
||||
if self.axis == [1] and ndims == 4:
|
||||
self._data_format = 'NCHW'
|
||||
elif self.axis == [3]:
|
||||
elif self.axis == [1] and ndims == 5:
|
||||
self._data_format = 'NCDHW'
|
||||
elif self.axis == [3] and ndims == 4:
|
||||
self._data_format = 'NHWC'
|
||||
elif self.axis == [4] and ndims == 5:
|
||||
self._data_format = 'NDHWC'
|
||||
elif ndims == 5:
|
||||
# 5D tensors that can be passed in but should not use fused batch norm
|
||||
# due to unsupported axis.
|
||||
self.fused = False
|
||||
else:
|
||||
raise ValueError('Unsupported axis, fused batch norm only supports '
|
||||
'axis == [1] or axis == [3]')
|
||||
'axis == [1] or axis == [3] for 4D input tensors or '
|
||||
'axis == [1] or axis == [4] for 5D input tensors')
|
||||
|
||||
axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
|
||||
for x in axis_to_dim:
|
||||
|
@ -66,6 +66,15 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
|
||||
kwargs={'scale': False,
|
||||
'center': False},
|
||||
input_shape=(3, 3))
|
||||
testing_utils.layer_test(
|
||||
keras.layers.BatchNormalization,
|
||||
kwargs={
|
||||
'gamma_initializer': 'ones',
|
||||
'beta_initializer': 'ones',
|
||||
'moving_mean_initializer': 'zeros',
|
||||
'moving_variance_initializer': 'ones'
|
||||
},
|
||||
input_shape=(3, 2, 4, 2))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_batchnorm_weights(self):
|
||||
@ -319,7 +328,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
|
||||
norm = normalization_v2.BatchNormalization(fused=True)
|
||||
self.assertEqual(norm.fused, True)
|
||||
inp = keras.layers.Input(shape=(4, 4))
|
||||
with self.assertRaisesRegex(ValueError, '4D input tensors'):
|
||||
with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
|
||||
norm(inp)
|
||||
|
||||
def test_updates_in_wrap_function(self):
|
||||
|
@ -43,14 +43,18 @@ class BatchNormalizationTest(test.TestCase):
|
||||
return math_ops.cast(y, x.dtype)
|
||||
|
||||
def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
|
||||
if data_format not in ['NHWC', 'NCHW']:
|
||||
raise ValueError('data_format must be NCHW or NHWC, '
|
||||
'got %s.' % data_format)
|
||||
if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
|
||||
raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
|
||||
'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
|
||||
if data_format == 'NCHW':
|
||||
x = array_ops.transpose(x, [0, 2, 3, 1])
|
||||
elif data_format == 'NCDHW':
|
||||
x = array_ops.transpose(x, [0, 2, 3, 4, 1])
|
||||
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
|
||||
if data_format == 'NCHW':
|
||||
y = array_ops.transpose(y, [0, 3, 1, 2])
|
||||
elif data_format == 'NCDHW':
|
||||
y = array_ops.transpose(y, [0, 4, 1, 2, 3])
|
||||
return self.evaluate(y)
|
||||
|
||||
def _test_inference(self,
|
||||
@ -102,17 +106,24 @@ class BatchNormalizationTest(test.TestCase):
|
||||
|
||||
def _training_ref(self, x, scale, offset, old_mean, old_var,
|
||||
exponential_avg_factor, epsilon, data_format):
|
||||
if data_format not in ['NHWC', 'NCHW']:
|
||||
raise ValueError('data_format must be NCHW or NHWC, '
|
||||
'got %s.' % data_format)
|
||||
if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
|
||||
raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
|
||||
'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
|
||||
use_4d_tensor = (x.shape.ndims == 4)
|
||||
if data_format == 'NCHW':
|
||||
x = array_ops.transpose(x, [0, 2, 3, 1])
|
||||
elif data_format == 'NCDHW':
|
||||
x = array_ops.transpose(x, [0, 2, 3, 4, 1])
|
||||
|
||||
mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
|
||||
batch_mean, batch_var = nn_impl.moments(
|
||||
math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
|
||||
math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False)
|
||||
|
||||
y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
|
||||
if data_format == 'NCHW':
|
||||
y = array_ops.transpose(y, [0, 3, 1, 2])
|
||||
elif data_format == 'NCDHW':
|
||||
y = array_ops.transpose(y, [0, 4, 1, 2, 3])
|
||||
|
||||
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
|
||||
# the denominator in the formula to calculate variance, while
|
||||
@ -377,14 +388,18 @@ class BatchNormalizationTest(test.TestCase):
|
||||
|
||||
def _runtests(self, x_shape, is_training, gradient_test=False,
|
||||
cpu_only=False):
|
||||
if len(x_shape) == 4:
|
||||
data_format_list = ['NHWC', 'NCHW']
|
||||
else:
|
||||
data_format_list = ['NCDHW', 'NDHWC']
|
||||
use_gpu_vals = [False]
|
||||
if test.is_gpu_available(cuda_only=True) and not cpu_only:
|
||||
use_gpu_vals += [True]
|
||||
factors = [1.0, 0.6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
for use_gpu in use_gpu_vals:
|
||||
for data_format in ['NHWC', 'NCHW']:
|
||||
if data_format == 'NHWC':
|
||||
for data_format in data_format_list:
|
||||
if data_format == 'NHWC' or data_format == 'NDHWC':
|
||||
scale_shape = x_shape[-1:]
|
||||
else:
|
||||
scale_shape = x_shape[1:2]
|
||||
@ -444,6 +459,10 @@ class BatchNormalizationTest(test.TestCase):
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, False, cpu_only=True)
|
||||
|
||||
def testInferenceShape7(self):
|
||||
x_shape = [1, 2, 6, 1, 3]
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
self._runtests(x_shape, True)
|
||||
@ -465,11 +484,16 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTrainingShape6(self):
|
||||
x_shape = [1, 1, 1, 1]
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, True, cpu_only=True)
|
||||
|
||||
def testTrainingShape7(self):
|
||||
x_shape = [1, 2, 6, 1, 3]
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradInferenceShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
@ -503,6 +527,11 @@ class BatchNormalizationTest(test.TestCase):
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True,
|
||||
cpu_only=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradInferenceShape7(self):
|
||||
x_shape = [1, 2, 6, 1, 3]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
@ -535,42 +564,54 @@ class BatchNormalizationTest(test.TestCase):
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape7(self):
|
||||
x_shape = [1, 2, 6, 1, 3]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
def _testBatchNormGradGrad(self, config):
|
||||
shape = config['shape']
|
||||
err_tolerance = config['err_tolerance']
|
||||
dtype = config['dtype']
|
||||
rank = len(shape)
|
||||
if rank == 4:
|
||||
data_format_nhwc, features_nhwc = 'NHWC', shape[3]
|
||||
data_format_nchw, features_nchw = 'NCHW', shape[1]
|
||||
else:
|
||||
data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
|
||||
data_format_nchw, features_nchw = 'NCDHW', shape[1]
|
||||
for is_training in [True, False]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_grad_grad(
|
||||
shape,
|
||||
dtype, [shape[3]],
|
||||
dtype, [features_nhwc],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
data_format=data_format_nhwc,
|
||||
is_training=is_training,
|
||||
err_tolerance=err_tolerance)
|
||||
self._test_grad_grad(
|
||||
shape,
|
||||
dtype, [shape[1]],
|
||||
dtype, [features_nchw],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
data_format=data_format_nchw,
|
||||
is_training=is_training,
|
||||
err_tolerance=err_tolerance)
|
||||
self._test_grad_grad(
|
||||
shape,
|
||||
dtype, [shape[3]],
|
||||
dtype, [features_nhwc],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
data_format=data_format_nhwc,
|
||||
is_training=is_training,
|
||||
err_tolerance=err_tolerance)
|
||||
self._test_grad_grad(
|
||||
shape,
|
||||
dtype, [shape[1]],
|
||||
dtype, [features_nchw],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW',
|
||||
data_format=data_format_nchw,
|
||||
is_training=is_training,
|
||||
err_tolerance=err_tolerance)
|
||||
|
||||
@ -610,6 +651,24 @@ class BatchNormalizationTest(test.TestCase):
|
||||
}
|
||||
self._testBatchNormGradGrad(config)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradGradConfig5(self):
|
||||
config = {
|
||||
'shape': [2, 3, 2, 2, 2],
|
||||
'err_tolerance': 2e-3,
|
||||
'dtype': np.float32,
|
||||
}
|
||||
self._testBatchNormGradGrad(config)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradGradConfig6(self):
|
||||
config = {
|
||||
'shape': [2, 3, 2, 2, 2],
|
||||
'err_tolerance': 3e-3,
|
||||
'dtype': np.float16,
|
||||
}
|
||||
self._testBatchNormGradGrad(config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -897,6 +897,11 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
|
||||
if data_format == b"NCHW":
|
||||
x = array_ops.transpose(x, [0, 2, 3, 1])
|
||||
grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
|
||||
elif data_format == b"NCDHW":
|
||||
x = array_ops.transpose(x, [0, 2, 3, 4, 1])
|
||||
grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
|
||||
target_data_format = ("NHWC" if data_format in (b"NCHW",
|
||||
b"NHWC") else "NDHWC")
|
||||
args = {
|
||||
"y_backprop": grad_y,
|
||||
"x": x,
|
||||
@ -904,7 +909,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
|
||||
"reserve_space_1": pop_mean,
|
||||
"reserve_space_2": pop_var,
|
||||
"epsilon": epsilon,
|
||||
"data_format": "NHWC",
|
||||
"data_format": target_data_format,
|
||||
"is_training": is_training
|
||||
}
|
||||
if version == 2:
|
||||
@ -912,6 +917,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
|
||||
dx, dscale, doffset, _, _ = grad_fun(**args)
|
||||
if data_format == b"NCHW":
|
||||
dx = array_ops.transpose(dx, [0, 3, 1, 2])
|
||||
elif data_format == b"NCDHW":
|
||||
dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
|
||||
return dx, dscale, doffset, None, None
|
||||
|
||||
|
||||
@ -941,8 +948,8 @@ def _BatchNormGrad(grad_y,
|
||||
"""Returns the gradients for the 3 inputs of BatchNorm.
|
||||
|
||||
Args:
|
||||
grad_y: A `Tensor` of 4 dimensions for gradient for y.
|
||||
x: A `Tensor` of 4 dimensions for x.
|
||||
grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
|
||||
x: A `Tensor` of 4 or 5 dimensions for x.
|
||||
scale: A `Tensor` of 1 dimension for scaling.
|
||||
pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
|
||||
is_training=False.
|
||||
@ -968,11 +975,19 @@ def _BatchNormGrad(grad_y,
|
||||
if data_format == b"NHWC":
|
||||
keepdims = False
|
||||
reduce_axis = [0, 1, 2]
|
||||
else:
|
||||
elif data_format == b"NDHWC":
|
||||
keepdims = False
|
||||
reduce_axis = [0, 1, 2, 3]
|
||||
elif data_format == b"NCHW":
|
||||
keepdims = True
|
||||
reduce_axis = [0, 2, 3]
|
||||
shape = [1, array_ops.size(scale), 1, 1]
|
||||
scale = array_ops.reshape(scale, shape)
|
||||
else:
|
||||
keepdims = True
|
||||
reduce_axis = [0, 2, 3, 4]
|
||||
shape = [1, array_ops.size(scale), 1, 1, 1]
|
||||
scale = array_ops.reshape(scale, shape)
|
||||
mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
|
||||
mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
|
||||
var_x = math_ops.reduce_mean(
|
||||
@ -987,19 +1002,27 @@ def _BatchNormGrad(grad_y,
|
||||
grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
|
||||
grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
|
||||
grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
|
||||
if data_format == b"NCHW":
|
||||
if data_format == b"NCHW" or data_format == b"NCDHW":
|
||||
grad_scale = array_ops.squeeze(grad_scale)
|
||||
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
|
||||
return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
|
||||
else:
|
||||
if data_format == b"NHWC":
|
||||
reduce_axis = [0, 1, 2]
|
||||
else:
|
||||
elif data_format == b"NDHWC":
|
||||
reduce_axis = [0, 1, 2, 3]
|
||||
elif data_format == b"NCHW":
|
||||
reduce_axis = [0, 2, 3]
|
||||
shape = [1, array_ops.size(pop_mean), 1, 1]
|
||||
pop_mean = array_ops.reshape(pop_mean, shape)
|
||||
pop_var = array_ops.reshape(pop_var, shape)
|
||||
scale = array_ops.reshape(scale, shape)
|
||||
else:
|
||||
reduce_axis = [0, 2, 3, 4]
|
||||
shape = [1, array_ops.size(pop_mean), 1, 1, 1]
|
||||
pop_mean = array_ops.reshape(pop_mean, shape)
|
||||
pop_var = array_ops.reshape(pop_var, shape)
|
||||
scale = array_ops.reshape(scale, shape)
|
||||
|
||||
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
|
||||
var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
|
||||
|
@ -1585,7 +1585,7 @@ def fused_batch_norm(
|
||||
(http://arxiv.org/abs/1502.03167).
|
||||
|
||||
Args:
|
||||
x: Input `Tensor` of 4 dimensions.
|
||||
x: Input `Tensor` of 4 or 5 dimensions.
|
||||
scale: A `Tensor` of 1 dimension for scaling.
|
||||
offset: A `Tensor` of 1 dimension for bias.
|
||||
mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
|
||||
@ -1611,7 +1611,8 @@ def fused_batch_norm(
|
||||
Variance must be a `Tensor` of the same shape as scale containing
|
||||
the exponential running variance.
|
||||
epsilon: A small float number added to the variance of x.
|
||||
data_format: The data format for x. Either "NHWC" (default) or "NCHW".
|
||||
data_format: The data format for x. Support "NHWC" (default) or "NCHW" for
|
||||
4D tenors and "NDHWC" or "NCDHW" for 5D tensors.
|
||||
is_training: A bool value to specify if the operation is used for
|
||||
training or inference.
|
||||
name: A name for this operation (optional).
|
||||
@ -1622,7 +1623,7 @@ def fused_batch_norm(
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
y: A 4D Tensor for the normalized, scaled, offsetted x.
|
||||
y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
|
||||
running_mean: A 1D Tensor for the exponential running mean of x.
|
||||
The output value is (1 - exponential_avg_factor) * mean +
|
||||
exponential_avg_factor * batch_mean), where batch_mean
|
||||
|
Loading…
Reference in New Issue
Block a user