Support fill_value for ImageProjectiveTransform
Fix conflict Fix typo Update number
This commit is contained in:
parent
9f86089e45
commit
13a9df90f4
tensorflow
core
api_def/base_api
kernels/image
ops
python
eager
keras/layers/preprocessing
ops
tools/api/golden
@ -0,0 +1,63 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "ImageProjectiveTransformV3"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "images"
|
||||||
|
description: <<END
|
||||||
|
4-D with shape `[batch, height, width, channels]`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "transforms"
|
||||||
|
description: <<END
|
||||||
|
2-D Tensor, `[batch, 8]` or `[1, 8]` matrix, where each row corresponds to a 3 x 3
|
||||||
|
projective transformation matrix, with the last entry assumed to be 1. If there
|
||||||
|
is one row, the same transformation will be applied to all images.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "output_shape"
|
||||||
|
description: <<END
|
||||||
|
1-D Tensor [new_height, new_width].
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "fill_value"
|
||||||
|
description: <<END
|
||||||
|
float, the value to be filled when fill_mode is constant".
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "transformed_images"
|
||||||
|
description: <<END
|
||||||
|
4-D with shape
|
||||||
|
`[batch, new_height, new_width, channels]`.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "dtype"
|
||||||
|
description: <<END
|
||||||
|
Input dtype.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "interpolation"
|
||||||
|
description: <<END
|
||||||
|
Interpolation method, "NEAREST" or "BILINEAR".
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "fill_mode"
|
||||||
|
description: <<END
|
||||||
|
Fill mode, "REFLECT", "WRAP", or "CONSTANT".
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Applies the given transform to each of the images."
|
||||||
|
description: <<END
|
||||||
|
If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps
|
||||||
|
the *output* point `(x, y)` to a transformed *input* point
|
||||||
|
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
|
||||||
|
`k = c0 x + c1 y + 1`. If the transformed point lays outside of the input
|
||||||
|
image, the output pixel is set to fill_value.
|
||||||
|
END
|
||||||
|
}
|
@ -48,6 +48,68 @@ using functor::FillProjectiveTransform;
|
|||||||
using generator::Interpolation;
|
using generator::Interpolation;
|
||||||
using generator::Mode;
|
using generator::Mode;
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
void DoImageProjectiveTransformOp(OpKernelContext* ctx,
|
||||||
|
const Interpolation& interpolation,
|
||||||
|
const Mode& fill_mode) {
|
||||||
|
const Tensor& images_t = ctx->input(0);
|
||||||
|
const Tensor& transform_t = ctx->input(1);
|
||||||
|
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
|
||||||
|
errors::InvalidArgument("Input images must have rank 4"));
|
||||||
|
OP_REQUIRES(ctx,
|
||||||
|
(TensorShapeUtils::IsMatrix(transform_t.shape()) &&
|
||||||
|
(transform_t.dim_size(0) == images_t.dim_size(0) ||
|
||||||
|
transform_t.dim_size(0) == 1) &&
|
||||||
|
transform_t.dim_size(1) == 8),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input transform should be num_images x 8 or 1 x 8"));
|
||||||
|
|
||||||
|
int32 out_height, out_width;
|
||||||
|
// Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
|
||||||
|
if (ctx->num_inputs() >= 3) {
|
||||||
|
const Tensor& shape_t = ctx->input(2);
|
||||||
|
OP_REQUIRES(ctx, shape_t.dims() == 1,
|
||||||
|
errors::InvalidArgument("output shape must be 1-dimensional",
|
||||||
|
shape_t.shape().DebugString()));
|
||||||
|
OP_REQUIRES(ctx, shape_t.NumElements() == 2,
|
||||||
|
errors::InvalidArgument("output shape must have two elements",
|
||||||
|
shape_t.shape().DebugString()));
|
||||||
|
auto shape_vec = shape_t.vec<int32>();
|
||||||
|
out_height = shape_vec(0);
|
||||||
|
out_width = shape_vec(1);
|
||||||
|
OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
|
||||||
|
errors::InvalidArgument("output dimensions must be positive"));
|
||||||
|
} else {
|
||||||
|
// Shape is N (batch size), H (height), W (width), C (channels).
|
||||||
|
out_height = images_t.shape().dim_size(1);
|
||||||
|
out_width = images_t.shape().dim_size(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
T fill_value(0);
|
||||||
|
// Kernel is shared by "ImageProjectiveTransformV2" with 3 args.
|
||||||
|
if (ctx->num_inputs() >= 4) {
|
||||||
|
const Tensor& fill_value_t = ctx->input(3);
|
||||||
|
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(fill_value_t.shape()),
|
||||||
|
errors::InvalidArgument("fill_value must be a scalar",
|
||||||
|
fill_value_t.shape().DebugString()));
|
||||||
|
fill_value = static_cast<T>(*(fill_value_t.scalar<float>().data()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor* output_t;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_output(0,
|
||||||
|
TensorShape({images_t.dim_size(0), out_height,
|
||||||
|
out_width, images_t.dim_size(3)}),
|
||||||
|
&output_t));
|
||||||
|
auto output = output_t->tensor<T, 4>();
|
||||||
|
auto images = images_t.tensor<T, 4>();
|
||||||
|
auto transform = transform_t.matrix<float>();
|
||||||
|
|
||||||
|
(FillProjectiveTransform<Device, T>(interpolation))(
|
||||||
|
ctx->eigen_device<Device>(), &output, images, transform, fill_mode,
|
||||||
|
fill_value);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class ImageProjectiveTransformV2 : public OpKernel {
|
class ImageProjectiveTransformV2 : public OpKernel {
|
||||||
private:
|
private:
|
||||||
@ -84,52 +146,7 @@ class ImageProjectiveTransformV2 : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
const Tensor& images_t = ctx->input(0);
|
DoImageProjectiveTransformOp<Device, T>(ctx, interpolation_, fill_mode_);
|
||||||
const Tensor& transform_t = ctx->input(1);
|
|
||||||
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
|
|
||||||
errors::InvalidArgument("Input images must have rank 4"));
|
|
||||||
OP_REQUIRES(ctx,
|
|
||||||
(TensorShapeUtils::IsMatrix(transform_t.shape()) &&
|
|
||||||
(transform_t.dim_size(0) == images_t.dim_size(0) ||
|
|
||||||
transform_t.dim_size(0) == 1) &&
|
|
||||||
transform_t.dim_size(1) == 8),
|
|
||||||
errors::InvalidArgument(
|
|
||||||
"Input transform should be num_images x 8 or 1 x 8"));
|
|
||||||
|
|
||||||
int32 out_height, out_width;
|
|
||||||
// Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
|
|
||||||
if (ctx->num_inputs() >= 3) {
|
|
||||||
const Tensor& shape_t = ctx->input(2);
|
|
||||||
OP_REQUIRES(ctx, shape_t.dims() == 1,
|
|
||||||
errors::InvalidArgument("output shape must be 1-dimensional",
|
|
||||||
shape_t.shape().DebugString()));
|
|
||||||
OP_REQUIRES(ctx, shape_t.NumElements() == 2,
|
|
||||||
errors::InvalidArgument("output shape must have two elements",
|
|
||||||
shape_t.shape().DebugString()));
|
|
||||||
auto shape_vec = shape_t.vec<int32>();
|
|
||||||
out_height = shape_vec(0);
|
|
||||||
out_width = shape_vec(1);
|
|
||||||
OP_REQUIRES(
|
|
||||||
ctx, out_height > 0 && out_width > 0,
|
|
||||||
errors::InvalidArgument("output dimensions must be positive"));
|
|
||||||
} else {
|
|
||||||
// Shape is N (batch size), H (height), W (width), C (channels).
|
|
||||||
out_height = images_t.shape().dim_size(1);
|
|
||||||
out_width = images_t.shape().dim_size(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
Tensor* output_t;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(
|
|
||||||
0,
|
|
||||||
TensorShape({images_t.dim_size(0), out_height,
|
|
||||||
out_width, images_t.dim_size(3)}),
|
|
||||||
&output_t));
|
|
||||||
auto output = output_t->tensor<T, 4>();
|
|
||||||
auto images = images_t.tensor<T, 4>();
|
|
||||||
auto transform = transform_t.matrix<float>();
|
|
||||||
|
|
||||||
(FillProjectiveTransform<Device, T>(interpolation_))(
|
|
||||||
ctx->eigen_device<Device>(), &output, images, transform, fill_mode_);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -148,6 +165,29 @@ TF_CALL_double(REGISTER);
|
|||||||
|
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class ImageProjectiveTransformV3
|
||||||
|
: public ImageProjectiveTransformV2<Device, T> {
|
||||||
|
public:
|
||||||
|
explicit ImageProjectiveTransformV3(OpKernelConstruction* ctx)
|
||||||
|
: ImageProjectiveTransformV2<Device, T>(ctx) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER(TYPE) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<TYPE>("dtype"), \
|
||||||
|
ImageProjectiveTransformV3<CPUDevice, TYPE>)
|
||||||
|
|
||||||
|
TF_CALL_uint8(REGISTER);
|
||||||
|
TF_CALL_int32(REGISTER);
|
||||||
|
TF_CALL_int64(REGISTER);
|
||||||
|
TF_CALL_half(REGISTER);
|
||||||
|
TF_CALL_float(REGISTER);
|
||||||
|
TF_CALL_double(REGISTER);
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
@ -161,7 +201,8 @@ namespace functor {
|
|||||||
template <> \
|
template <> \
|
||||||
void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \
|
void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \
|
||||||
const GPUDevice& device, OutputType* output, const InputType& images, \
|
const GPUDevice& device, OutputType* output, const InputType& images, \
|
||||||
const TransformsType& transform, const Mode fill_mode) const; \
|
const TransformsType& transform, const Mode fill_mode, \
|
||||||
|
const TYPE fill_value) const; \
|
||||||
extern template struct FillProjectiveTransform<GPUDevice, TYPE>
|
extern template struct FillProjectiveTransform<GPUDevice, TYPE>
|
||||||
|
|
||||||
TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
|
TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
|
||||||
@ -204,6 +245,23 @@ TF_CALL_double(REGISTER);
|
|||||||
|
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
|
#define REGISTER(TYPE) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV3") \
|
||||||
|
.Device(DEVICE_GPU) \
|
||||||
|
.TypeConstraint<TYPE>("dtype") \
|
||||||
|
.HostMemory("output_shape") \
|
||||||
|
.HostMemory("fill_value"), \
|
||||||
|
ImageProjectiveTransformV3<GPUDevice, TYPE>)
|
||||||
|
|
||||||
|
TF_CALL_uint8(REGISTER);
|
||||||
|
TF_CALL_int32(REGISTER);
|
||||||
|
TF_CALL_int64(REGISTER);
|
||||||
|
TF_CALL_half(REGISTER);
|
||||||
|
TF_CALL_float(REGISTER);
|
||||||
|
TF_CALL_double(REGISTER);
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -107,17 +107,20 @@ class ProjectiveGenerator {
|
|||||||
typename TTypes<T, 4>::ConstTensor input_;
|
typename TTypes<T, 4>::ConstTensor input_;
|
||||||
typename TTypes<float>::ConstMatrix transforms_;
|
typename TTypes<float>::ConstMatrix transforms_;
|
||||||
const Interpolation interpolation_;
|
const Interpolation interpolation_;
|
||||||
|
const T fill_value_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||||
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
||||||
typename TTypes<float>::ConstMatrix transforms,
|
typename TTypes<float>::ConstMatrix transforms,
|
||||||
const Interpolation interpolation)
|
const Interpolation interpolation, const T fill_value)
|
||||||
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
|
: input_(input),
|
||||||
|
transforms_(transforms),
|
||||||
|
interpolation_(interpolation),
|
||||||
|
fill_value_(fill_value) {}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||||
operator()(const array<DenseIndex, 4>& coords) const {
|
operator()(const array<DenseIndex, 4>& coords) const {
|
||||||
const T fill_value = T(0);
|
|
||||||
const int64 output_y = coords[1];
|
const int64 output_y = coords[1];
|
||||||
const int64 output_x = coords[2];
|
const int64 output_x = coords[2];
|
||||||
const float* transform =
|
const float* transform =
|
||||||
@ -126,9 +129,9 @@ class ProjectiveGenerator {
|
|||||||
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
|
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
|
||||||
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
|
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
|
||||||
if (projection == 0) {
|
if (projection == 0) {
|
||||||
// Return the fill value (0) for infinite coordinates,
|
// Return the fill value for infinite coordinates,
|
||||||
// which are outside the input image
|
// which are outside the input image
|
||||||
return fill_value;
|
return fill_value_;
|
||||||
}
|
}
|
||||||
const float input_x =
|
const float input_x =
|
||||||
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
||||||
@ -146,13 +149,13 @@ class ProjectiveGenerator {
|
|||||||
const DenseIndex channels = coords[3];
|
const DenseIndex channels = coords[3];
|
||||||
switch (interpolation_) {
|
switch (interpolation_) {
|
||||||
case NEAREST:
|
case NEAREST:
|
||||||
return nearest_interpolation(batch, y, x, channels, fill_value);
|
return nearest_interpolation(batch, y, x, channels, fill_value_);
|
||||||
case BILINEAR:
|
case BILINEAR:
|
||||||
return bilinear_interpolation(batch, y, x, channels, fill_value);
|
return bilinear_interpolation(batch, y, x, channels, fill_value_);
|
||||||
}
|
}
|
||||||
// Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
|
// Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
|
||||||
// or INTERPOLATION_BILINEAR.
|
// or INTERPOLATION_BILINEAR.
|
||||||
return fill_value;
|
return fill_value_;
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||||
@ -225,27 +228,27 @@ struct FillProjectiveTransform {
|
|||||||
EIGEN_ALWAYS_INLINE
|
EIGEN_ALWAYS_INLINE
|
||||||
void operator()(const Device& device, OutputType* output,
|
void operator()(const Device& device, OutputType* output,
|
||||||
const InputType& images, const TransformsType& transform,
|
const InputType& images, const TransformsType& transform,
|
||||||
const Mode fill_mode) const {
|
const Mode fill_mode, const T fill_value) const {
|
||||||
switch (fill_mode) {
|
switch (fill_mode) {
|
||||||
case Mode::FILL_REFLECT:
|
case Mode::FILL_REFLECT:
|
||||||
output->device(device) =
|
output->device(device) =
|
||||||
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_REFLECT>(
|
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_REFLECT>(
|
||||||
images, transform, interpolation));
|
images, transform, interpolation, fill_value));
|
||||||
break;
|
break;
|
||||||
case Mode::FILL_WRAP:
|
case Mode::FILL_WRAP:
|
||||||
output->device(device) =
|
output->device(device) =
|
||||||
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_WRAP>(
|
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_WRAP>(
|
||||||
images, transform, interpolation));
|
images, transform, interpolation, fill_value));
|
||||||
break;
|
break;
|
||||||
case Mode::FILL_CONSTANT:
|
case Mode::FILL_CONSTANT:
|
||||||
output->device(device) = output->generate(
|
output->device(device) = output->generate(
|
||||||
ProjectiveGenerator<Device, T, Mode::FILL_CONSTANT>(
|
ProjectiveGenerator<Device, T, Mode::FILL_CONSTANT>(
|
||||||
images, transform, interpolation));
|
images, transform, interpolation, fill_value));
|
||||||
break;
|
break;
|
||||||
case Mode::FILL_NEAREST:
|
case Mode::FILL_NEAREST:
|
||||||
output->device(device) =
|
output->device(device) =
|
||||||
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_NEAREST>(
|
output->generate(ProjectiveGenerator<Device, T, Mode::FILL_NEAREST>(
|
||||||
images, transform, interpolation));
|
images, transform, interpolation, fill_value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1146,8 +1146,9 @@ REGISTER_OP("GenerateBoundingBoxProposals")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
// V3 op supports fill_value.
|
||||||
// V2 op supports output_shape. V1 op is in contrib.
|
// V2 op supports output_shape.
|
||||||
|
// V1 op is in contrib.
|
||||||
REGISTER_OP("ImageProjectiveTransformV2")
|
REGISTER_OP("ImageProjectiveTransformV2")
|
||||||
.Input("images: dtype")
|
.Input("images: dtype")
|
||||||
.Input("transforms: float32")
|
.Input("transforms: float32")
|
||||||
@ -1163,4 +1164,20 @@ REGISTER_OP("ImageProjectiveTransformV2")
|
|||||||
c->Dim(input, 3));
|
c->Dim(input, 3));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("ImageProjectiveTransformV3")
|
||||||
|
.Input("images: dtype")
|
||||||
|
.Input("transforms: float32")
|
||||||
|
.Input("output_shape: int32")
|
||||||
|
.Input("fill_value: float32")
|
||||||
|
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
|
||||||
|
.Attr("interpolation: string")
|
||||||
|
.Attr("fill_mode: string = 'CONSTANT'")
|
||||||
|
.Output("transformed_images: dtype")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeHandle input;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
|
||||||
|
return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
|
||||||
|
c->Dim(input, 3));
|
||||||
|
});
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19102,6 +19102,54 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "ImageProjectiveTransformV3"
|
||||||
|
input_arg {
|
||||||
|
name: "images"
|
||||||
|
type_attr: "dtype"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "transforms"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "output_shape"
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "fill_value"
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "transformed_images"
|
||||||
|
type_attr: "dtype"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "dtype"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "interpolation"
|
||||||
|
type: "string"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "fill_mode"
|
||||||
|
type: "string"
|
||||||
|
default_value {
|
||||||
|
s: "CONSTANT"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "ImageSummary"
|
name: "ImageSummary"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
|
|||||||
|
|
||||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||||
const tensorflow::string &op_name) {
|
const tensorflow::string &op_name) {
|
||||||
static std::array<OpIndexInfo, 349> a = {{
|
static std::array<OpIndexInfo, 350> a = {{
|
||||||
{"Acosh"},
|
{"Acosh"},
|
||||||
{"AllToAll", 1, {0}},
|
{"AllToAll", 1, {0}},
|
||||||
{"ApproximateEqual"},
|
{"ApproximateEqual"},
|
||||||
@ -152,6 +152,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
|||||||
{"IdentityReader"},
|
{"IdentityReader"},
|
||||||
{"Imag"},
|
{"Imag"},
|
||||||
{"ImageProjectiveTransformV2", 1, {2}},
|
{"ImageProjectiveTransformV2", 1, {2}},
|
||||||
|
{"ImageProjectiveTransformV3", 2, {2, 3}},
|
||||||
{"ImageSummary"},
|
{"ImageSummary"},
|
||||||
{"InitializeTable"},
|
{"InitializeTable"},
|
||||||
{"InitializeTableFromTextFile"},
|
{"InitializeTableFromTextFile"},
|
||||||
@ -412,7 +413,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
|||||||
|
|
||||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||||
const tensorflow::string &op_name) {
|
const tensorflow::string &op_name) {
|
||||||
static std::array<OpIndexInfo, 465> a = {{
|
static std::array<OpIndexInfo, 466> a = {{
|
||||||
{"Abs"},
|
{"Abs"},
|
||||||
{"AccumulateNV2"},
|
{"AccumulateNV2"},
|
||||||
{"Acos"},
|
{"Acos"},
|
||||||
@ -568,6 +569,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
|||||||
{"Igammac"},
|
{"Igammac"},
|
||||||
{"Imag"},
|
{"Imag"},
|
||||||
{"ImageProjectiveTransformV2"},
|
{"ImageProjectiveTransformV2"},
|
||||||
|
{"ImageProjectiveTransformV3"},
|
||||||
{"ImageSummary"},
|
{"ImageSummary"},
|
||||||
{"InitializeTable"},
|
{"InitializeTable"},
|
||||||
{"InitializeTableFromTextFile"},
|
{"InitializeTableFromTextFile"},
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
@ -466,6 +467,8 @@ class RandomTranslation(PreprocessingLayer):
|
|||||||
The input is extended by wrapping around to the opposite edge.
|
The input is extended by wrapping around to the opposite edge.
|
||||||
- *nearest*: `(a a a a | a b c d | d d d d)`
|
- *nearest*: `(a a a a | a b c d | d d d d)`
|
||||||
The input is extended by the nearest pixel.
|
The input is extended by the nearest pixel.
|
||||||
|
fill_value: a float represents the value to be filled outside the
|
||||||
|
boundaries when `fill_mode` is "constant".
|
||||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||||
seed: Integer. Used to create a random seed.
|
seed: Integer. Used to create a random seed.
|
||||||
name: A string, the name of the layer.
|
name: A string, the name of the layer.
|
||||||
@ -487,6 +490,7 @@ class RandomTranslation(PreprocessingLayer):
|
|||||||
height_factor,
|
height_factor,
|
||||||
width_factor,
|
width_factor,
|
||||||
fill_mode='reflect',
|
fill_mode='reflect',
|
||||||
|
fill_value=0.0,
|
||||||
interpolation='bilinear',
|
interpolation='bilinear',
|
||||||
seed=None,
|
seed=None,
|
||||||
name=None,
|
name=None,
|
||||||
@ -522,6 +526,7 @@ class RandomTranslation(PreprocessingLayer):
|
|||||||
check_fill_mode_and_interpolation(fill_mode, interpolation)
|
check_fill_mode_and_interpolation(fill_mode, interpolation)
|
||||||
|
|
||||||
self.fill_mode = fill_mode
|
self.fill_mode = fill_mode
|
||||||
|
self.fill_value = fill_value
|
||||||
self.interpolation = interpolation
|
self.interpolation = interpolation
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self._rng = make_generator(self.seed)
|
self._rng = make_generator(self.seed)
|
||||||
@ -559,7 +564,8 @@ class RandomTranslation(PreprocessingLayer):
|
|||||||
inputs,
|
inputs,
|
||||||
get_translation_matrix(translations),
|
get_translation_matrix(translations),
|
||||||
interpolation=self.interpolation,
|
interpolation=self.interpolation,
|
||||||
fill_mode=self.fill_mode)
|
fill_mode=self.fill_mode,
|
||||||
|
fill_value=self.fill_value)
|
||||||
|
|
||||||
output = control_flow_util.smart_cond(training, random_translated_inputs,
|
output = control_flow_util.smart_cond(training, random_translated_inputs,
|
||||||
lambda: inputs)
|
lambda: inputs)
|
||||||
@ -574,6 +580,7 @@ class RandomTranslation(PreprocessingLayer):
|
|||||||
'height_factor': self.height_factor,
|
'height_factor': self.height_factor,
|
||||||
'width_factor': self.width_factor,
|
'width_factor': self.width_factor,
|
||||||
'fill_mode': self.fill_mode,
|
'fill_mode': self.fill_mode,
|
||||||
|
'fill_value': self.fill_value,
|
||||||
'interpolation': self.interpolation,
|
'interpolation': self.interpolation,
|
||||||
'seed': self.seed,
|
'seed': self.seed,
|
||||||
}
|
}
|
||||||
@ -617,6 +624,7 @@ def get_translation_matrix(translations, name=None):
|
|||||||
def transform(images,
|
def transform(images,
|
||||||
transforms,
|
transforms,
|
||||||
fill_mode='reflect',
|
fill_mode='reflect',
|
||||||
|
fill_value=0.0,
|
||||||
interpolation='bilinear',
|
interpolation='bilinear',
|
||||||
output_shape=None,
|
output_shape=None,
|
||||||
name=None):
|
name=None):
|
||||||
@ -636,6 +644,8 @@ def transform(images,
|
|||||||
not backpropagated into transformation parameters.
|
not backpropagated into transformation parameters.
|
||||||
fill_mode: Points outside the boundaries of the input are filled according
|
fill_mode: Points outside the boundaries of the input are filled according
|
||||||
to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
|
to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
|
||||||
|
fill_value: a float represents the value to be filled outside the
|
||||||
|
boundaries when `fill_mode` is "constant".
|
||||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||||
output_shape: Output dimesion after the transform, [height, width]. If None,
|
output_shape: Output dimesion after the transform, [height, width]. If None,
|
||||||
output is the same size as input image.
|
output is the same size as input image.
|
||||||
@ -689,6 +699,18 @@ def transform(images,
|
|||||||
'new_height, new_width, instead got '
|
'new_height, new_width, instead got '
|
||||||
'{}'.format(output_shape))
|
'{}'.format(output_shape))
|
||||||
|
|
||||||
|
fill_value = ops.convert_to_tensor_v2(
|
||||||
|
fill_value, dtypes.float32, name='fill_value')
|
||||||
|
|
||||||
|
if compat.forward_compatible(2020, 8, 5):
|
||||||
|
return gen_image_ops.ImageProjectiveTransformV3(
|
||||||
|
images=images,
|
||||||
|
output_shape=output_shape,
|
||||||
|
fill_value=fill_value,
|
||||||
|
transforms=transforms,
|
||||||
|
fill_mode=fill_mode.upper(),
|
||||||
|
interpolation=interpolation.upper())
|
||||||
|
|
||||||
return gen_image_ops.ImageProjectiveTransformV2(
|
return gen_image_ops.ImageProjectiveTransformV2(
|
||||||
images=images,
|
images=images,
|
||||||
output_shape=output_shape,
|
output_shape=output_shape,
|
||||||
@ -774,6 +796,8 @@ class RandomRotation(PreprocessingLayer):
|
|||||||
The input is extended by wrapping around to the opposite edge.
|
The input is extended by wrapping around to the opposite edge.
|
||||||
- *nearest*: `(a a a a | a b c d | d d d d)`
|
- *nearest*: `(a a a a | a b c d | d d d d)`
|
||||||
The input is extended by the nearest pixel.
|
The input is extended by the nearest pixel.
|
||||||
|
fill_value: a float represents the value to be filled outside the
|
||||||
|
boundaries when `fill_mode` is "constant".
|
||||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||||
seed: Integer. Used to create a random seed.
|
seed: Integer. Used to create a random seed.
|
||||||
name: A string, the name of the layer.
|
name: A string, the name of the layer.
|
||||||
@ -793,6 +817,7 @@ class RandomRotation(PreprocessingLayer):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
factor,
|
factor,
|
||||||
fill_mode='reflect',
|
fill_mode='reflect',
|
||||||
|
fill_value=0.0,
|
||||||
interpolation='bilinear',
|
interpolation='bilinear',
|
||||||
seed=None,
|
seed=None,
|
||||||
name=None,
|
name=None,
|
||||||
@ -809,6 +834,7 @@ class RandomRotation(PreprocessingLayer):
|
|||||||
'got {}'.format(factor))
|
'got {}'.format(factor))
|
||||||
check_fill_mode_and_interpolation(fill_mode, interpolation)
|
check_fill_mode_and_interpolation(fill_mode, interpolation)
|
||||||
self.fill_mode = fill_mode
|
self.fill_mode = fill_mode
|
||||||
|
self.fill_value = fill_value
|
||||||
self.interpolation = interpolation
|
self.interpolation = interpolation
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self._rng = make_generator(self.seed)
|
self._rng = make_generator(self.seed)
|
||||||
@ -834,6 +860,7 @@ class RandomRotation(PreprocessingLayer):
|
|||||||
inputs,
|
inputs,
|
||||||
get_rotation_matrix(angles, img_hd, img_wd),
|
get_rotation_matrix(angles, img_hd, img_wd),
|
||||||
fill_mode=self.fill_mode,
|
fill_mode=self.fill_mode,
|
||||||
|
fill_value=self.fill_value,
|
||||||
interpolation=self.interpolation)
|
interpolation=self.interpolation)
|
||||||
|
|
||||||
output = control_flow_util.smart_cond(training, random_rotated_inputs,
|
output = control_flow_util.smart_cond(training, random_rotated_inputs,
|
||||||
@ -848,6 +875,7 @@ class RandomRotation(PreprocessingLayer):
|
|||||||
config = {
|
config = {
|
||||||
'factor': self.factor,
|
'factor': self.factor,
|
||||||
'fill_mode': self.fill_mode,
|
'fill_mode': self.fill_mode,
|
||||||
|
'fill_value': self.fill_value,
|
||||||
'interpolation': self.interpolation,
|
'interpolation': self.interpolation,
|
||||||
'seed': self.seed,
|
'seed': self.seed,
|
||||||
}
|
}
|
||||||
@ -889,6 +917,8 @@ class RandomZoom(PreprocessingLayer):
|
|||||||
The input is extended by wrapping around to the opposite edge.
|
The input is extended by wrapping around to the opposite edge.
|
||||||
- *nearest*: `(a a a a | a b c d | d d d d)`
|
- *nearest*: `(a a a a | a b c d | d d d d)`
|
||||||
The input is extended by the nearest pixel.
|
The input is extended by the nearest pixel.
|
||||||
|
fill_value: a float represents the value to be filled outside the
|
||||||
|
boundaries when `fill_mode` is "constant".
|
||||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||||
seed: Integer. Used to create a random seed.
|
seed: Integer. Used to create a random seed.
|
||||||
name: A string, the name of the layer.
|
name: A string, the name of the layer.
|
||||||
@ -914,11 +944,11 @@ class RandomZoom(PreprocessingLayer):
|
|||||||
negative.
|
negative.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO(b/156526279): Add `fill_value` argument.
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
height_factor,
|
height_factor,
|
||||||
width_factor=None,
|
width_factor=None,
|
||||||
fill_mode='reflect',
|
fill_mode='reflect',
|
||||||
|
fill_value=0.0,
|
||||||
interpolation='bilinear',
|
interpolation='bilinear',
|
||||||
seed=None,
|
seed=None,
|
||||||
name=None,
|
name=None,
|
||||||
@ -951,6 +981,7 @@ class RandomZoom(PreprocessingLayer):
|
|||||||
check_fill_mode_and_interpolation(fill_mode, interpolation)
|
check_fill_mode_and_interpolation(fill_mode, interpolation)
|
||||||
|
|
||||||
self.fill_mode = fill_mode
|
self.fill_mode = fill_mode
|
||||||
|
self.fill_value = fill_value
|
||||||
self.interpolation = interpolation
|
self.interpolation = interpolation
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self._rng = make_generator(self.seed)
|
self._rng = make_generator(self.seed)
|
||||||
@ -985,6 +1016,7 @@ class RandomZoom(PreprocessingLayer):
|
|||||||
return transform(
|
return transform(
|
||||||
inputs, get_zoom_matrix(zooms, img_hd, img_wd),
|
inputs, get_zoom_matrix(zooms, img_hd, img_wd),
|
||||||
fill_mode=self.fill_mode,
|
fill_mode=self.fill_mode,
|
||||||
|
fill_value=self.fill_value,
|
||||||
interpolation=self.interpolation)
|
interpolation=self.interpolation)
|
||||||
|
|
||||||
output = control_flow_util.smart_cond(training, random_zoomed_inputs,
|
output = control_flow_util.smart_cond(training, random_zoomed_inputs,
|
||||||
@ -1000,6 +1032,7 @@ class RandomZoom(PreprocessingLayer):
|
|||||||
'height_factor': self.height_factor,
|
'height_factor': self.height_factor,
|
||||||
'width_factor': self.width_factor,
|
'width_factor': self.width_factor,
|
||||||
'fill_mode': self.fill_mode,
|
'fill_mode': self.fill_mode,
|
||||||
|
'fill_value': self.fill_value,
|
||||||
'interpolation': self.interpolation,
|
'interpolation': self.interpolation,
|
||||||
'seed': self.seed,
|
'seed': self.seed,
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
@ -698,11 +699,13 @@ class RandomTransformTest(keras_parameterized.TestCase):
|
|||||||
transform_matrix,
|
transform_matrix,
|
||||||
expected_output,
|
expected_output,
|
||||||
mode,
|
mode,
|
||||||
|
fill_value=0.0,
|
||||||
interpolation='bilinear'):
|
interpolation='bilinear'):
|
||||||
inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32)
|
inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||||
with self.cached_session(use_gpu=True):
|
with self.cached_session(use_gpu=True):
|
||||||
output = image_preprocessing.transform(
|
output = image_preprocessing.transform(
|
||||||
inp, transform_matrix, fill_mode=mode, interpolation=interpolation)
|
inp, transform_matrix, fill_mode=mode,
|
||||||
|
fill_value=fill_value, interpolation=interpolation)
|
||||||
self.assertAllClose(expected_output, output)
|
self.assertAllClose(expected_output, output)
|
||||||
|
|
||||||
def test_random_translation_reflect(self):
|
def test_random_translation_reflect(self):
|
||||||
@ -871,7 +874,7 @@ class RandomTransformTest(keras_parameterized.TestCase):
|
|||||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||||
'nearest')
|
'nearest')
|
||||||
|
|
||||||
def test_random_translation_constant(self):
|
def test_random_translation_constant_0(self):
|
||||||
# constant output is (0000|abcd|0000)
|
# constant output is (0000|abcd|0000)
|
||||||
|
|
||||||
# Test down shift by 1.
|
# Test down shift by 1.
|
||||||
@ -926,6 +929,62 @@ class RandomTransformTest(keras_parameterized.TestCase):
|
|||||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||||
'constant')
|
'constant')
|
||||||
|
|
||||||
|
def test_random_translation_constant_1(self):
|
||||||
|
with compat.forward_compatibility_horizon(2020, 8, 6):
|
||||||
|
# constant output is (1111|abcd|1111)
|
||||||
|
|
||||||
|
# Test down shift by 1.
|
||||||
|
# pyformat: disable
|
||||||
|
expected_output = np.asarray(
|
||||||
|
[[1., 1., 1.],
|
||||||
|
[0., 1., 2.],
|
||||||
|
[3., 4., 5.],
|
||||||
|
[6., 7., 8],
|
||||||
|
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||||
|
# pyformat: enable
|
||||||
|
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
|
||||||
|
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||||
|
'constant', fill_value=1.0)
|
||||||
|
|
||||||
|
# Test up shift by 1.
|
||||||
|
# pyformat: disable
|
||||||
|
expected_output = np.asarray(
|
||||||
|
[[3., 4., 5.],
|
||||||
|
[6., 7., 8],
|
||||||
|
[9., 10., 11.],
|
||||||
|
[12., 13., 14.],
|
||||||
|
[1., 1., 1.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||||
|
# pyformat: enable
|
||||||
|
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
|
||||||
|
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||||
|
'constant', fill_value=1.0)
|
||||||
|
|
||||||
|
# Test left shift by 1.
|
||||||
|
# pyformat: disable
|
||||||
|
expected_output = np.asarray(
|
||||||
|
[[1., 2., 1.],
|
||||||
|
[4., 5., 1.],
|
||||||
|
[7., 8., 1.],
|
||||||
|
[10., 11., 1.],
|
||||||
|
[13., 14., 1.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||||
|
# pyformat: enable
|
||||||
|
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
|
||||||
|
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||||
|
'constant', fill_value=1.0)
|
||||||
|
|
||||||
|
# Test right shift by 1.
|
||||||
|
# pyformat: disable
|
||||||
|
expected_output = np.asarray(
|
||||||
|
[[1., 0., 1.],
|
||||||
|
[1., 3., 4],
|
||||||
|
[1., 6., 7.],
|
||||||
|
[1., 9., 10.],
|
||||||
|
[1., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||||
|
# pyformat: enable
|
||||||
|
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
|
||||||
|
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||||
|
'constant', fill_value=1.0)
|
||||||
|
|
||||||
def test_random_translation_nearest_interpolation(self):
|
def test_random_translation_nearest_interpolation(self):
|
||||||
# nearest output is (aaaa|abcd|dddd)
|
# nearest output is (aaaa|abcd|dddd)
|
||||||
|
|
||||||
|
@ -271,3 +271,38 @@ def _image_projective_transform_grad(op, grad):
|
|||||||
interpolation=interpolation,
|
interpolation=interpolation,
|
||||||
fill_mode=fill_mode)
|
fill_mode=fill_mode)
|
||||||
return [output, None, None]
|
return [output, None, None]
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("ImageProjectiveTransformV3")
|
||||||
|
def _image_projective_transform_v3_grad(op, grad):
|
||||||
|
"""Computes the gradient for ImageProjectiveTransform."""
|
||||||
|
images = op.inputs[0]
|
||||||
|
transforms = op.inputs[1]
|
||||||
|
interpolation = op.get_attr("interpolation")
|
||||||
|
fill_mode = op.get_attr("fill_mode")
|
||||||
|
|
||||||
|
image_or_images = ops.convert_to_tensor(images, name="images")
|
||||||
|
transform_or_transforms = ops.convert_to_tensor(
|
||||||
|
transforms, name="transforms", dtype=dtypes.float32)
|
||||||
|
|
||||||
|
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
|
||||||
|
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
|
||||||
|
if len(transform_or_transforms.get_shape()) == 1:
|
||||||
|
transforms = transform_or_transforms[None]
|
||||||
|
elif len(transform_or_transforms.get_shape()) == 2:
|
||||||
|
transforms = transform_or_transforms
|
||||||
|
else:
|
||||||
|
raise TypeError("Transforms should have rank 1 or 2.")
|
||||||
|
|
||||||
|
# Invert transformations
|
||||||
|
transforms = flat_transforms_to_matrices(transforms=transforms)
|
||||||
|
inverse = linalg_ops.matrix_inverse(transforms)
|
||||||
|
transforms = matrices_to_flat_transforms(inverse)
|
||||||
|
output = gen_image_ops.image_projective_transform_v3(
|
||||||
|
images=grad,
|
||||||
|
transforms=transforms,
|
||||||
|
output_shape=array_ops.shape(image_or_images)[1:3],
|
||||||
|
interpolation=interpolation,
|
||||||
|
fill_mode=fill_mode,
|
||||||
|
fill_value=0.0)
|
||||||
|
return [output, None, None, None]
|
||||||
|
@ -118,7 +118,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adapt"
|
name: "adapt"
|
||||||
|
@ -118,7 +118,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adapt"
|
name: "adapt"
|
||||||
|
@ -118,7 +118,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'bilinear\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adapt"
|
name: "adapt"
|
||||||
|
@ -1880,6 +1880,10 @@ tf_module {
|
|||||||
name: "ImageProjectiveTransformV2"
|
name: "ImageProjectiveTransformV2"
|
||||||
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ImageProjectiveTransformV3"
|
||||||
|
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'fill_value\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ImageSummary"
|
name: "ImageSummary"
|
||||||
argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "
|
argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "
|
||||||
|
@ -118,7 +118,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adapt"
|
name: "adapt"
|
||||||
|
@ -118,7 +118,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adapt"
|
name: "adapt"
|
||||||
|
@ -118,7 +118,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'bilinear\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'reflect\', \'0.0\', \'bilinear\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adapt"
|
name: "adapt"
|
||||||
|
@ -1880,6 +1880,10 @@ tf_module {
|
|||||||
name: "ImageProjectiveTransformV2"
|
name: "ImageProjectiveTransformV2"
|
||||||
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "ImageProjectiveTransformV3"
|
||||||
|
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'fill_value\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "ImageSummary"
|
name: "ImageSummary"
|
||||||
argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "
|
argspec: "args=[\'tag\', \'tensor\', \'max_images\', \'bad_color\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'dtype: DT_UINT8\\ntensor_shape {\\n dim {\\n size: 4\\n }\\n}\\nint_val: 255\\nint_val: 0\\nint_val: 0\\nint_val: 255\\n\', \'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user