Add complex number support for tf.extract_image_patches

This PR tries to address the issue raised in 35955 where
there was no complex number support for tf.extract_image_patches.
The op `tf.extract_image_patches` itself could be used in many
ways than just image so it makes sense to add complex support.

This fix fixes 35955.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

Register GPU types for tf.extract_image_patches

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

Fix build failure with GPU

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

Update array_ops.cc to only include TF_CALL_NUMBER_TYPES

This fix updates array_ops.cc to only include TF_CALL_NUMBER_TYPES
for extract_image_patches (realnumbertypes + complex64 + complex128).

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-01-17 00:13:26 +00:00
parent b8e5a9d9c5
commit 6ab77f3be3
3 changed files with 7 additions and 5 deletions

View File

@ -126,7 +126,7 @@ class ExtractImagePatchesOp : public UnaryOp<T> {
Name("ExtractImagePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
ExtractImagePatchesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER);
TF_CALL_NUMBER_TYPES(REGISTER);
#undef REGISTER
@ -145,7 +145,7 @@ namespace functor {
typename TTypes<T, 4>::Tensor output); \
extern template struct ExtractImagePatchesForward<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
@ -157,7 +157,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
Name("ExtractImagePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
ExtractImagePatchesOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_GPU_ALL_TYPES(REGISTER);
#undef REGISTER

View File

@ -29,7 +29,7 @@ namespace functor {
#define REGISTER(T) template struct ExtractImagePatchesForward<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
TF_CALL_GPU_ALL_TYPES(REGISTER);
#undef REGISTER

View File

@ -2525,7 +2525,9 @@ REGISTER_OP("ExtractImagePatches")
.Attr("ksizes: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr("rates: list(int) >= 4")
.Attr("T: realnumbertype")
.Attr(
"T: {bfloat16, half, float, double, int8, int16, int32, int64, "
"uint8, uint16, uint32, uint64, complex64, complex128, bool}")
.Attr(GetPaddingAttrString())
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input_shape;