Fix tf_cnn_benchmarks runtime failures
PiperOrigin-RevId: 247771407
This commit is contained in:
parent
4330405781
commit
2db22ecd6c
@ -697,7 +697,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
// adds necessary annotations to the graph.
|
||||
// TODO(ezhulenev): Convert in other direction for fp16?
|
||||
const TensorFormat compute_data_format =
|
||||
compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW;
|
||||
(compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
|
||||
: FORMAT_NCHW;
|
||||
|
||||
VLOG(3) << "Compute Conv2D with cuDNN:"
|
||||
<< " data_format=" << ToString(data_format)
|
||||
@ -853,7 +854,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
|
||||
Tensor transformed_filter;
|
||||
|
||||
const auto transform_filter = [&](FilterTensorFormat dst_format) -> void {
|
||||
const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
|
||||
VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
|
||||
<< " to " << ToString(dst_format);
|
||||
|
||||
@ -864,18 +865,20 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
: TensorShape({filter.dim_size(3), filter.dim_size(0),
|
||||
filter.dim_size(1), filter.dim_size(2)});
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
|
||||
TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
|
||||
&transformed_filter));
|
||||
functor::TransformFilter<GPUDevice, T, int, 4>()(
|
||||
ctx->eigen_device<GPUDevice>(), dst_format,
|
||||
To32Bit(filter.tensor<T, 4>()),
|
||||
To32Bit(transformed_filter.tensor<T, 4>()));
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
if (compute_data_format == FORMAT_NCHW) {
|
||||
transform_filter(FORMAT_OIHW);
|
||||
OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
|
||||
} else if (compute_data_format == FORMAT_NHWC) {
|
||||
transform_filter(FORMAT_OHWI);
|
||||
OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
|
||||
} else {
|
||||
ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
|
||||
ToString(compute_data_format)));
|
||||
|
Loading…
Reference in New Issue
Block a user