Add complex number support for tf.extract_image_patches
This PR tries to address the issue raised in 35955 where there was no complex number support for tf.extract_image_patches. The op `tf.extract_image_patches` itself could be used in many ways than just image so it makes sense to add complex support. This fix fixes 35955. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> Register GPU types for tf.extract_image_patches Signed-off-by: Yong Tang <yong.tang.github@outlook.com> Fix build failure with GPU Signed-off-by: Yong Tang <yong.tang.github@outlook.com> Update array_ops.cc to only include TF_CALL_NUMBER_TYPES This fix updates array_ops.cc to only include TF_CALL_NUMBER_TYPES for extract_image_patches (realnumbertypes + complex64 + complex128). Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
b8e5a9d9c5
commit
6ab77f3be3
@ -126,7 +126,7 @@ class ExtractImagePatchesOp : public UnaryOp<T> {
|
||||
Name("ExtractImagePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
ExtractImagePatchesOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
@ -145,7 +145,7 @@ namespace functor {
|
||||
typename TTypes<T, 4>::Tensor output); \
|
||||
extern template struct ExtractImagePatchesForward<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||
TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);
|
||||
|
||||
#undef DECLARE_GPU_SPEC
|
||||
|
||||
@ -157,7 +157,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||
Name("ExtractImagePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
|
||||
ExtractImagePatchesOp<GPUDevice, T>);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
|
||||
TF_CALL_GPU_ALL_TYPES(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
|
@ -29,7 +29,7 @@ namespace functor {
|
||||
|
||||
#define REGISTER(T) template struct ExtractImagePatchesForward<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER);
|
||||
TF_CALL_GPU_ALL_TYPES(REGISTER);
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
|
@ -2525,7 +2525,9 @@ REGISTER_OP("ExtractImagePatches")
|
||||
.Attr("ksizes: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
.Attr("rates: list(int) >= 4")
|
||||
.Attr("T: realnumbertype")
|
||||
.Attr(
|
||||
"T: {bfloat16, half, float, double, int8, int16, int32, int64, "
|
||||
"uint8, uint16, uint32, uint64, complex64, complex128, bool}")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input_shape;
|
||||
|
Loading…
Reference in New Issue
Block a user