diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 8050320e441..cb879e7226a 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -697,7 +697,8 @@ void LaunchConv2DOp::operator()( // adds necessary annotations to the graph. // TODO(ezhulenev): Convert in other direction for fp16? const TensorFormat compute_data_format = - compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW; + (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC + : FORMAT_NCHW; VLOG(3) << "Compute Conv2D with cuDNN:" << " data_format=" << ToString(data_format) @@ -853,7 +854,7 @@ void LaunchConv2DOp::operator()( Tensor transformed_filter; - const auto transform_filter = [&](FilterTensorFormat dst_format) -> void { + const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status { VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO) << " to " << ToString(dst_format); @@ -864,18 +865,20 @@ void LaunchConv2DOp::operator()( : TensorShape({filter.dim_size(3), filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}); - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, dst_shape, - &transformed_filter)); + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::value, dst_shape, + &transformed_filter)); functor::TransformFilter()( ctx->eigen_device(), dst_format, To32Bit(filter.tensor()), To32Bit(transformed_filter.tensor())); + + return Status::OK(); }; if (compute_data_format == FORMAT_NCHW) { - transform_filter(FORMAT_OIHW); + OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW)); } else if (compute_data_format == FORMAT_NHWC) { - transform_filter(FORMAT_OHWI); + OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI)); } else { ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ", ToString(compute_data_format)));