Fix tf_cnn_benchmarks runtime failures

PiperOrigin-RevId: 247771407
This commit is contained in:
Eugene Zhulenev 2019-05-11 12:49:47 -07:00 committed by TensorFlower Gardener
parent 4330405781
commit 2db22ecd6c

View File

@ -697,7 +697,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
// adds necessary annotations to the graph.
// TODO(ezhulenev): Convert in other direction for fp16?
const TensorFormat compute_data_format =
compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW;
(compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
: FORMAT_NCHW;
VLOG(3) << "Compute Conv2D with cuDNN:"
<< " data_format=" << ToString(data_format)
@ -853,7 +854,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
Tensor transformed_filter;
const auto transform_filter = [&](FilterTensorFormat dst_format) -> void {
const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
<< " to " << ToString(dst_format);
@ -864,18 +865,20 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
: TensorShape({filter.dim_size(3), filter.dim_size(0),
filter.dim_size(1), filter.dim_size(2)});
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
&transformed_filter));
TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
&transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 4>()(
ctx->eigen_device<GPUDevice>(), dst_format,
To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
return Status::OK();
};
if (compute_data_format == FORMAT_NCHW) {
transform_filter(FORMAT_OIHW);
OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
} else if (compute_data_format == FORMAT_NHWC) {
transform_filter(FORMAT_OHWI);
OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
} else {
ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
ToString(compute_data_format)));