Export fill mode for image_projective_transform
PiperOrigin-RevId: 300652382 Change-Id: I762197f7dfbd545445db7b3b330d463f5f66d856
This commit is contained in:
parent
3614f46e1f
commit
287f14529c
@ -38,6 +38,12 @@ END
|
||||
name: "interpolation"
|
||||
description: <<END
|
||||
Interpolation method, "NEAREST" or "BILINEAR".
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "fill_mode"
|
||||
description: <<END
|
||||
Fill mode, "REFLECT", "WRAP", or "CONSTANT".
|
||||
END
|
||||
}
|
||||
summary: "Applies the given transform to each of the images."
|
||||
|
@ -46,27 +46,39 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
using functor::FillProjectiveTransform;
|
||||
using generator::Interpolation;
|
||||
using generator::INTERPOLATION_BILINEAR;
|
||||
using generator::INTERPOLATION_NEAREST;
|
||||
using generator::ProjectiveGenerator;
|
||||
using generator::Mode;
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ImageProjectiveTransform : public OpKernel {
|
||||
class ImageProjectiveTransformV2 : public OpKernel {
|
||||
private:
|
||||
Interpolation interpolation_;
|
||||
Mode fill_mode_;
|
||||
|
||||
public:
|
||||
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {
|
||||
string interpolation_str;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
|
||||
if (interpolation_str == "NEAREST") {
|
||||
interpolation_ = INTERPOLATION_NEAREST;
|
||||
interpolation_ = Interpolation::NEAREST;
|
||||
} else if (interpolation_str == "BILINEAR") {
|
||||
interpolation_ = INTERPOLATION_BILINEAR;
|
||||
interpolation_ = Interpolation::BILINEAR;
|
||||
} else {
|
||||
LOG(ERROR) << "Invalid interpolation " << interpolation_str
|
||||
<< ". Supported types: NEAREST, BILINEAR";
|
||||
}
|
||||
string mode_str;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("fill_mode", &mode_str));
|
||||
if (mode_str == "REFLECT") {
|
||||
fill_mode_ = Mode::REFLECT;
|
||||
} else if (mode_str == "WRAP") {
|
||||
fill_mode_ = Mode::WRAP;
|
||||
} else if (mode_str == "CONSTANT") {
|
||||
fill_mode_ = Mode::CONSTANT;
|
||||
} else {
|
||||
LOG(ERROR) << "Invalid mode " << mode_str
|
||||
<< ". Supported types: REFLECT, WRAP, CONSTANT";
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
@ -78,8 +90,7 @@ class ImageProjectiveTransform : public OpKernel {
|
||||
(TensorShapeUtils::IsMatrix(transform_t.shape()) &&
|
||||
(transform_t.dim_size(0) == images_t.dim_size(0) ||
|
||||
transform_t.dim_size(0) == 1) &&
|
||||
transform_t.dim_size(1) ==
|
||||
ProjectiveGenerator<Device, T>::kNumParameters),
|
||||
transform_t.dim_size(1) == 8),
|
||||
errors::InvalidArgument(
|
||||
"Input transform should be num_images x 8 or 1 x 8"));
|
||||
|
||||
@ -116,15 +127,15 @@ class ImageProjectiveTransform : public OpKernel {
|
||||
auto transform = transform_t.matrix<float>();
|
||||
|
||||
(FillProjectiveTransform<Device, T>(interpolation_))(
|
||||
ctx->eigen_device<Device>(), &output, images, transform);
|
||||
ctx->eigen_device<Device>(), &output, images, transform, fill_mode_);
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
ImageProjectiveTransform<CPUDevice, TYPE>)
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
ImageProjectiveTransformV2<CPUDevice, TYPE>)
|
||||
|
||||
TF_CALL_uint8(REGISTER);
|
||||
TF_CALL_int32(REGISTER);
|
||||
@ -138,33 +149,48 @@ TF_CALL_double(REGISTER);
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
typedef generator::Mode Mode;
|
||||
|
||||
namespace functor {
|
||||
|
||||
// NOTE(ringwalt): We get an undefined symbol error if we don't explicitly
|
||||
// instantiate the operator() in GCC'd code.
|
||||
#define DECLARE_FUNCTOR(TYPE) \
|
||||
#define DECLARE_PROJECT_FUNCTOR(TYPE) \
|
||||
template <> \
|
||||
void FillProjectiveTransform<GPUDevice, TYPE>::operator()( \
|
||||
const GPUDevice& device, OutputType* output, const InputType& images, \
|
||||
const TransformsType& transform) const; \
|
||||
const TransformsType& transform, const Mode fill_mode) const; \
|
||||
extern template struct FillProjectiveTransform<GPUDevice, TYPE>
|
||||
|
||||
TF_CALL_uint8(DECLARE_FUNCTOR);
|
||||
TF_CALL_int32(DECLARE_FUNCTOR);
|
||||
TF_CALL_int64(DECLARE_FUNCTOR);
|
||||
TF_CALL_half(DECLARE_FUNCTOR);
|
||||
TF_CALL_float(DECLARE_FUNCTOR);
|
||||
TF_CALL_double(DECLARE_FUNCTOR);
|
||||
TF_CALL_uint8(DECLARE_PROJECT_FUNCTOR);
|
||||
TF_CALL_int32(DECLARE_PROJECT_FUNCTOR);
|
||||
TF_CALL_int64(DECLARE_PROJECT_FUNCTOR);
|
||||
TF_CALL_half(DECLARE_PROJECT_FUNCTOR);
|
||||
TF_CALL_float(DECLARE_PROJECT_FUNCTOR);
|
||||
TF_CALL_double(DECLARE_PROJECT_FUNCTOR);
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<TYPE>("dtype") \
|
||||
.HostMemory("output_shape"), \
|
||||
ImageProjectiveTransform<GPUDevice, TYPE>)
|
||||
namespace generator {
|
||||
|
||||
#define DECLARE_MAP_FUNCTOR(Mode) \
|
||||
template <> \
|
||||
float MapCoordinate<GPUDevice, Mode>::operator()(const float out_coord, \
|
||||
const DenseIndex len); \
|
||||
extern template struct MapCoordinate<GPUDevice, Mode>
|
||||
|
||||
DECLARE_MAP_FUNCTOR(Mode::REFLECT);
|
||||
DECLARE_MAP_FUNCTOR(Mode::WRAP);
|
||||
DECLARE_MAP_FUNCTOR(Mode::CONSTANT);
|
||||
|
||||
} // end namespace generator
|
||||
|
||||
#define REGISTER(TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<TYPE>("dtype") \
|
||||
.HostMemory("output_shape"), \
|
||||
ImageProjectiveTransformV2<GPUDevice, TYPE>)
|
||||
|
||||
TF_CALL_uint8(REGISTER);
|
||||
TF_CALL_int32(REGISTER);
|
||||
|
@ -29,12 +29,67 @@ namespace tensorflow {
|
||||
|
||||
namespace generator {
|
||||
|
||||
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
|
||||
enum Interpolation { NEAREST, BILINEAR };
|
||||
enum Mode { REFLECT, WRAP, CONSTANT };
|
||||
|
||||
using Eigen::array;
|
||||
using Eigen::DenseIndex;
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, Mode M>
|
||||
struct MapCoordinate {
|
||||
float operator()(const float out_coord, const DenseIndex len);
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct MapCoordinate<Device, Mode::REFLECT> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord,
|
||||
const DenseIndex len) {
|
||||
float in_coord = out_coord;
|
||||
// Reflect [abcd] to [dcba|abcd|dcba], periodically from [0, 2 * len)
|
||||
// over [abcddcba]
|
||||
const DenseIndex boundary = 2 * len;
|
||||
// Shift coordinate to (-boundary, boundary)
|
||||
in_coord -= boundary * static_cast<DenseIndex>(in_coord / boundary);
|
||||
// Convert negative coordinates from [-boundary, 0) to [0, boundary)
|
||||
if (in_coord < 0) {
|
||||
in_coord += boundary;
|
||||
}
|
||||
// Coordinate in_coord between [len, boundary) should reverse reflect
|
||||
// to coordinate to (bounary - 1 - in_coord) between [0, len)
|
||||
if (in_coord > len - 1) {
|
||||
in_coord = boundary - 1 - in_coord;
|
||||
}
|
||||
return in_coord;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct MapCoordinate<Device, Mode::WRAP> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord,
|
||||
const DenseIndex len) {
|
||||
float in_coord = out_coord;
|
||||
// Wrap [abcd] to [abcd|abcd|abcd], periodically from [0, len)
|
||||
// over [abcd]
|
||||
const DenseIndex boundary = len;
|
||||
// Shift coordinate to (-boundary, boundary)
|
||||
in_coord -= boundary * static_cast<DenseIndex>(in_coord / boundary);
|
||||
// Shift negative coordinate from [-boundary, 0) to [0, boundary)
|
||||
if (in_coord < 0) {
|
||||
in_coord += boundary;
|
||||
}
|
||||
return in_coord;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct MapCoordinate<Device, Mode::CONSTANT> {
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord,
|
||||
const DenseIndex len) {
|
||||
return out_coord;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, Mode M>
|
||||
class ProjectiveGenerator {
|
||||
private:
|
||||
typename TTypes<T, 4>::ConstTensor input_;
|
||||
@ -42,8 +97,6 @@ class ProjectiveGenerator {
|
||||
const Interpolation interpolation_;
|
||||
|
||||
public:
|
||||
static const int kNumParameters = 8;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<float>::ConstMatrix transforms,
|
||||
@ -52,6 +105,7 @@ class ProjectiveGenerator {
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
operator()(const array<DenseIndex, 4>& coords) const {
|
||||
const T fill_value = T(0);
|
||||
const int64 output_y = coords[1];
|
||||
const int64 output_x = coords[2];
|
||||
const float* transform =
|
||||
@ -62,7 +116,7 @@ class ProjectiveGenerator {
|
||||
if (projection == 0) {
|
||||
// Return the fill value (0) for infinite coordinates,
|
||||
// which are outside the input image
|
||||
return T(0);
|
||||
return fill_value;
|
||||
}
|
||||
const float input_x =
|
||||
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
||||
@ -71,22 +125,24 @@ class ProjectiveGenerator {
|
||||
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
|
||||
projection;
|
||||
|
||||
const T fill_value = T(0);
|
||||
// Map out-of-boundary input coordinates to in-boundary based on fill_mode.
|
||||
auto map_functor = MapCoordinate<Device, M>();
|
||||
const float x = map_functor(input_x, input_.dimension(2));
|
||||
const float y = map_functor(input_y, input_.dimension(1));
|
||||
|
||||
const DenseIndex batch = coords[0];
|
||||
const DenseIndex channels = coords[3];
|
||||
switch (interpolation_) {
|
||||
case INTERPOLATION_NEAREST:
|
||||
// Switch the order of x and y again for indexing into the image.
|
||||
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
|
||||
fill_value);
|
||||
case INTERPOLATION_BILINEAR:
|
||||
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
|
||||
fill_value);
|
||||
case NEAREST:
|
||||
return nearest_interpolation(batch, y, x, channels, fill_value);
|
||||
case BILINEAR:
|
||||
return bilinear_interpolation(batch, y, x, channels, fill_value);
|
||||
}
|
||||
// Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST
|
||||
// or INTERPOLATION_BILINEAR.
|
||||
return T(0);
|
||||
return fill_value;
|
||||
}
|
||||
|
||||
private:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
nearest_interpolation(const DenseIndex batch, const float y, const float x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
@ -138,12 +194,10 @@ class ProjectiveGenerator {
|
||||
|
||||
} // end namespace generator
|
||||
|
||||
// NOTE(ringwalt): We MUST wrap the generate() call in a functor and explicitly
|
||||
// instantiate the functor in image_ops_gpu.cu.cc. Otherwise, we will be missing
|
||||
// some Eigen device code.
|
||||
namespace functor {
|
||||
|
||||
using generator::Interpolation;
|
||||
using generator::Mode;
|
||||
using generator::ProjectiveGenerator;
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -151,17 +205,32 @@ struct FillProjectiveTransform {
|
||||
typedef typename TTypes<T, 4>::Tensor OutputType;
|
||||
typedef typename TTypes<T, 4>::ConstTensor InputType;
|
||||
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
|
||||
const Interpolation interpolation_;
|
||||
const Interpolation interpolation;
|
||||
|
||||
FillProjectiveTransform(Interpolation interpolation)
|
||||
: interpolation_(interpolation) {}
|
||||
explicit FillProjectiveTransform(Interpolation interpolation)
|
||||
: interpolation(interpolation) {}
|
||||
|
||||
EIGEN_ALWAYS_INLINE
|
||||
void operator()(const Device& device, OutputType* output,
|
||||
const InputType& images,
|
||||
const TransformsType& transform) const {
|
||||
output->device(device) = output->generate(
|
||||
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
|
||||
const InputType& images, const TransformsType& transform,
|
||||
const Mode fill_mode) const {
|
||||
switch (fill_mode) {
|
||||
case Mode::REFLECT:
|
||||
output->device(device) =
|
||||
output->generate(ProjectiveGenerator<Device, T, Mode::REFLECT>(
|
||||
images, transform, interpolation));
|
||||
break;
|
||||
case Mode::WRAP:
|
||||
output->device(device) =
|
||||
output->generate(ProjectiveGenerator<Device, T, Mode::WRAP>(
|
||||
images, transform, interpolation));
|
||||
break;
|
||||
case Mode::CONSTANT:
|
||||
output->device(device) =
|
||||
output->generate(ProjectiveGenerator<Device, T, Mode::CONSTANT>(
|
||||
images, transform, interpolation));
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1028,7 +1028,6 @@ REGISTER_OP("GenerateBoundingBoxProposals")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc.
|
||||
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
||||
// V2 op supports output_shape. V1 op is in contrib.
|
||||
REGISTER_OP("ImageProjectiveTransformV2")
|
||||
@ -1037,6 +1036,7 @@ REGISTER_OP("ImageProjectiveTransformV2")
|
||||
.Input("output_shape: int32")
|
||||
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
|
||||
.Attr("interpolation: string")
|
||||
.Attr("fill_mode: string = 'CONSTANT'")
|
||||
.Output("transformed_images: dtype")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle input;
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -418,9 +419,15 @@ class RandomTranslation(Layer):
|
||||
When represented as a single float, this value is used for both the upper
|
||||
and lower bound.
|
||||
fill_mode: Points outside the boundaries of the input are filled according
|
||||
to the given mode (one of `{'nearest', 'bilinear'}`).
|
||||
fill_value: Value used for points outside the boundaries of the input if
|
||||
`mode='constant'`.
|
||||
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
|
||||
- *reflect*: `(d c b a | a b c d | d c b a)`
|
||||
The input is extended by reflecting about the edge of the last pixel.
|
||||
- *constant*: `(k k k k | a b c d | k k k k)`
|
||||
The input is extended by filling all values beyond the edge with the
|
||||
same constant value k = 0.
|
||||
- *wrap*: `(a b c d | a b c d | a b c d)`
|
||||
The input is extended by wrapping around to the opposite edge.
|
||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||
seed: Integer. Used to create a random seed.
|
||||
name: A string, the name of the layer.
|
||||
|
||||
@ -440,8 +447,8 @@ class RandomTranslation(Layer):
|
||||
def __init__(self,
|
||||
height_factor,
|
||||
width_factor,
|
||||
fill_mode='nearest',
|
||||
fill_value=0.,
|
||||
fill_mode='reflect',
|
||||
interpolation='bilinear',
|
||||
seed=None,
|
||||
name=None,
|
||||
**kwargs):
|
||||
@ -471,11 +478,16 @@ class RandomTranslation(Layer):
|
||||
raise ValueError('`width_factor` must have values between [-1, 1], '
|
||||
'got {}'.format(width_factor))
|
||||
|
||||
if fill_mode not in {'nearest', 'bilinear'}:
|
||||
if fill_mode not in {'reflect', 'wrap', 'constant'}:
|
||||
raise NotImplementedError(
|
||||
'`fill_mode` {} is not supported yet.'.format(fill_mode))
|
||||
'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
|
||||
'`constant` are supported.'.format(fill_mode))
|
||||
if interpolation not in {'nearest', 'bilinear'}:
|
||||
raise NotImplementedError(
|
||||
'Unknown `interpolation` {}. Only `nearest` and '
|
||||
'`bilinear` are supported.'.format(interpolation))
|
||||
self.fill_mode = fill_mode
|
||||
self.fill_value = fill_value
|
||||
self.interpolation = interpolation
|
||||
self.seed = seed
|
||||
self._rng = make_generator(self.seed)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
@ -508,7 +520,8 @@ class RandomTranslation(Layer):
|
||||
return transform(
|
||||
inputs,
|
||||
get_translation_matrix(translations),
|
||||
interpolation=self.fill_mode)
|
||||
interpolation=self.interpolation,
|
||||
fill_mode=self.fill_mode)
|
||||
|
||||
output = tf_utils.smart_cond(training, random_translated_inputs,
|
||||
lambda: inputs)
|
||||
@ -523,7 +536,7 @@ class RandomTranslation(Layer):
|
||||
'height_factor': self.height_factor,
|
||||
'width_factor': self.width_factor,
|
||||
'fill_mode': self.fill_mode,
|
||||
'fill_value': self.fill_value,
|
||||
'interpolation': self.interpolation,
|
||||
'seed': self.seed,
|
||||
}
|
||||
base_config = super(RandomTranslation, self).get_config()
|
||||
@ -565,7 +578,8 @@ def get_translation_matrix(translations, name=None):
|
||||
|
||||
def transform(images,
|
||||
transforms,
|
||||
interpolation='nearest',
|
||||
fill_mode='reflect',
|
||||
interpolation='bilinear',
|
||||
output_shape=None,
|
||||
name=None):
|
||||
"""Applies the given transform(s) to the image(s).
|
||||
@ -582,11 +596,33 @@ def transform(images,
|
||||
`k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the
|
||||
transform mapping input points to output points. Note that gradients are
|
||||
not backpropagated into transformation parameters.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
fill_mode: Points outside the boundaries of the input are filled according
|
||||
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
|
||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||
output_shape: Output dimesion after the transform, [height, width]. If None,
|
||||
output is the same size as input image.
|
||||
name: The name of the op.
|
||||
|
||||
## Fill mode.
|
||||
Behavior for each valid value is as follows:
|
||||
|
||||
reflect (d c b a | a b c d | d c b a)
|
||||
The input is extended by reflecting about the edge of the last pixel.
|
||||
|
||||
constant (k k k k | a b c d | k k k k)
|
||||
The input is extended by filling all values beyond the edge with the same
|
||||
constant value k = 0.
|
||||
|
||||
wrap (a b c d | a b c d | a b c d)
|
||||
The input is extended by wrapping around to the opposite edge.
|
||||
|
||||
Input shape:
|
||||
4D tensor with shape: `(samples, height, width, channels)`,
|
||||
data_format='channels_last'.
|
||||
Output shape:
|
||||
4D tensor with shape: `(samples, height, width, channels)`,
|
||||
data_format='channels_last'.
|
||||
|
||||
Returns:
|
||||
Image(s) with the same type and shape as `images`, with the given
|
||||
transform(s) applied. Transformed coordinates outside of the input image
|
||||
@ -612,6 +648,13 @@ def transform(images,
|
||||
'new_height, new_width, instead got '
|
||||
'{}'.format(output_shape))
|
||||
|
||||
if compat.forward_compatible(2020, 3, 25):
|
||||
return image_ops.image_projective_transform_v2(
|
||||
images,
|
||||
output_shape=output_shape,
|
||||
transforms=transforms,
|
||||
fill_mode=fill_mode.upper(),
|
||||
interpolation=interpolation.upper())
|
||||
return image_ops.image_projective_transform_v2(
|
||||
images,
|
||||
output_shape=output_shape,
|
||||
@ -680,11 +723,24 @@ class RandomRotation(Layer):
|
||||
2 representing lower and upper bound for rotating clockwise and
|
||||
counter-clockwise. When represented as a single float, lower = upper.
|
||||
fill_mode: Points outside the boundaries of the input are filled according
|
||||
to the given mode (one of `{'constant', 'nearest', 'bilinear', 'reflect',
|
||||
'wrap'}`).
|
||||
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
|
||||
- *reflect*: `(d c b a | a b c d | d c b a)`
|
||||
The input is extended by reflecting about the edge of the last pixel.
|
||||
- *constant*: `(k k k k | a b c d | k k k k)`
|
||||
The input is extended by filling all values beyond the edge with the
|
||||
same constant value k = 0.
|
||||
- *wrap*: `(a b c d | a b c d | a b c d)`
|
||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||
seed: Integer. Used to create a random seed.
|
||||
name: A string, the name of the layer.
|
||||
|
||||
Input shape:
|
||||
4D tensor with shape: `(samples, height, width, channels)`,
|
||||
data_format='channels_last'.
|
||||
Output shape:
|
||||
4D tensor with shape: `(samples, height, width, channels)`,
|
||||
data_format='channels_last'.
|
||||
|
||||
Raise:
|
||||
ValueError: if lower bound is not between [0, 1], or upper bound is
|
||||
negative.
|
||||
@ -692,7 +748,8 @@ class RandomRotation(Layer):
|
||||
|
||||
def __init__(self,
|
||||
factor,
|
||||
fill_mode='nearest',
|
||||
fill_mode='reflect',
|
||||
interpolation='bilinear',
|
||||
seed=None,
|
||||
name=None,
|
||||
**kwargs):
|
||||
@ -705,10 +762,16 @@ class RandomRotation(Layer):
|
||||
if self.lower < 0. or self.upper < 0.:
|
||||
raise ValueError('Factor cannot have negative values, '
|
||||
'got {}'.format(factor))
|
||||
if fill_mode not in {'nearest', 'bilinear'}:
|
||||
if fill_mode not in {'reflect', 'wrap', 'constant'}:
|
||||
raise NotImplementedError(
|
||||
'`fill_mode` {} is not supported yet.'.format(fill_mode))
|
||||
'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
|
||||
'`constant` are supported.'.format(fill_mode))
|
||||
if interpolation not in {'nearest', 'bilinear'}:
|
||||
raise NotImplementedError(
|
||||
'Unknown `interpolation` {}. Only `nearest` and '
|
||||
'`bilinear` are supported.'.format(interpolation))
|
||||
self.fill_mode = fill_mode
|
||||
self.interpolation = interpolation
|
||||
self.seed = seed
|
||||
self._rng = make_generator(self.seed)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
@ -732,7 +795,8 @@ class RandomRotation(Layer):
|
||||
return transform(
|
||||
inputs,
|
||||
get_rotation_matrix(angles, img_hd, img_wd),
|
||||
interpolation=self.fill_mode)
|
||||
fill_mode=self.fill_mode,
|
||||
interpolation=self.interpolation)
|
||||
|
||||
output = tf_utils.smart_cond(training, random_rotated_inputs,
|
||||
lambda: inputs)
|
||||
@ -746,6 +810,7 @@ class RandomRotation(Layer):
|
||||
config = {
|
||||
'factor': self.factor,
|
||||
'fill_mode': self.fill_mode,
|
||||
'interpolation': self.interpolation,
|
||||
'seed': self.seed,
|
||||
}
|
||||
base_config = super(RandomRotation, self).get_config()
|
||||
@ -768,9 +833,14 @@ class RandomZoom(Layer):
|
||||
upper and lower bound. For instance, `width_factor=(0.2, 0.3)` result in
|
||||
an output zoom varying in the range `[original * 20%, original * 30%]`.
|
||||
fill_mode: Points outside the boundaries of the input are filled according
|
||||
to the given mode (one of `{'nearest', 'bilinear'}`).
|
||||
fill_value: Value used for points outside the boundaries of the input if
|
||||
`mode='constant'`.
|
||||
to the given mode (one of `{'constant', 'reflect', 'wrap'}`).
|
||||
- *reflect*: `(d c b a | a b c d | d c b a)`
|
||||
The input is extended by reflecting about the edge of the last pixel.
|
||||
- *constant*: `(k k k k | a b c d | k k k k)`
|
||||
The input is extended by filling all values beyond the edge with the
|
||||
same constant value k = 0.
|
||||
- *wrap*: `(a b c d | a b c d | a b c d)`
|
||||
interpolation: Interpolation mode. Supported values: "nearest", "bilinear".
|
||||
seed: Integer. Used to create a random seed.
|
||||
name: A string, the name of the layer.
|
||||
|
||||
@ -790,8 +860,8 @@ class RandomZoom(Layer):
|
||||
def __init__(self,
|
||||
height_factor,
|
||||
width_factor,
|
||||
fill_mode='nearest',
|
||||
fill_value=0.,
|
||||
fill_mode='reflect',
|
||||
interpolation='bilinear',
|
||||
seed=None,
|
||||
name=None,
|
||||
**kwargs):
|
||||
@ -821,11 +891,16 @@ class RandomZoom(Layer):
|
||||
raise ValueError('`width_factor` cannot have lower bound larger than '
|
||||
'upper bound, got {}.'.format(width_factor))
|
||||
|
||||
if fill_mode not in {'nearest', 'bilinear'}:
|
||||
if fill_mode not in {'reflect', 'wrap', 'constant'}:
|
||||
raise NotImplementedError(
|
||||
'`fill_mode` {} is not supported yet.'.format(fill_mode))
|
||||
'Unknown `fill_mode` {}. Only `reflect`, `wrap` and '
|
||||
'`constant` are supported.'.format(fill_mode))
|
||||
if interpolation not in {'nearest', 'bilinear'}:
|
||||
raise NotImplementedError(
|
||||
'Unknown `interpolation` {}. Only `nearest` and '
|
||||
'`bilinear` are supported.'.format(interpolation))
|
||||
self.fill_mode = fill_mode
|
||||
self.fill_value = fill_value
|
||||
self.interpolation = interpolation
|
||||
self.seed = seed
|
||||
self._rng = make_generator(self.seed)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
@ -857,7 +932,8 @@ class RandomZoom(Layer):
|
||||
dtype=inputs.dtype)
|
||||
return transform(
|
||||
inputs, get_zoom_matrix(zooms, img_hd, img_wd),
|
||||
interpolation=self.fill_mode)
|
||||
fill_mode=self.fill_mode,
|
||||
interpolation=self.interpolation)
|
||||
|
||||
output = tf_utils.smart_cond(training, random_zoomed_inputs,
|
||||
lambda: inputs)
|
||||
@ -872,7 +948,7 @@ class RandomZoom(Layer):
|
||||
'height_factor': self.height_factor,
|
||||
'width_factor': self.width_factor,
|
||||
'fill_mode': self.fill_mode,
|
||||
'fill_value': self.fill_value,
|
||||
'interpolation': self.interpolation,
|
||||
'seed': self.seed,
|
||||
}
|
||||
base_config = super(RandomZoom, self).get_config()
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
@ -509,6 +510,250 @@ class RandomTranslationTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer_1.name, layer.name)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class RandomTransformTest(keras_parameterized.TestCase):
|
||||
|
||||
def _run_random_transform_with_mock(self,
|
||||
transform_matrix,
|
||||
expected_output,
|
||||
mode,
|
||||
interpolation='bilinear'):
|
||||
inp = np.arange(15).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
with self.cached_session(use_gpu=True):
|
||||
output = image_preprocessing.transform(
|
||||
inp, transform_matrix, fill_mode=mode, interpolation=interpolation)
|
||||
self.assertAllClose(expected_output, output)
|
||||
|
||||
def test_random_translation_reflect(self):
|
||||
# reflected output is (dcba|abcd|dcba)
|
||||
|
||||
if compat.forward_compatible(2020, 3, 25):
|
||||
# Test down shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[0., 1., 2.],
|
||||
[0., 1., 2.],
|
||||
[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'reflect')
|
||||
|
||||
# Test up shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11.],
|
||||
[12., 13., 14.],
|
||||
[12., 13., 14.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'reflect')
|
||||
|
||||
# Test left shift by 1.
|
||||
# reflected output is (dcba|abcd|dcba)
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[1., 2., 2.],
|
||||
[4., 5., 5.],
|
||||
[7., 8., 8.],
|
||||
[10., 11., 11.],
|
||||
[13., 14., 14.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'reflect')
|
||||
|
||||
# Test right shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[0., 0., 1.],
|
||||
[3., 3., 4],
|
||||
[6., 6., 7.],
|
||||
[9., 9., 10.],
|
||||
[12., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'reflect')
|
||||
|
||||
def test_random_translation_wrap(self):
|
||||
# warpped output is (abcd|abcd|abcd)
|
||||
|
||||
if compat.forward_compatible(2020, 3, 25):
|
||||
# Test down shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[12., 13., 14.],
|
||||
[0., 1., 2.],
|
||||
[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'wrap')
|
||||
|
||||
# Test up shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11.],
|
||||
[12., 13., 14.],
|
||||
[0., 1., 2.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'wrap')
|
||||
|
||||
# Test left shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[1., 2., 0.],
|
||||
[4., 5., 3.],
|
||||
[7., 8., 6.],
|
||||
[10., 11., 9.],
|
||||
[13., 14., 12.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'wrap')
|
||||
|
||||
# Test right shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[2., 0., 1.],
|
||||
[5., 3., 4],
|
||||
[8., 6., 7.],
|
||||
[11., 9., 10.],
|
||||
[14., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'wrap')
|
||||
|
||||
def test_random_translation_constant(self):
|
||||
# constant output is (0000|abcd|0000)
|
||||
|
||||
if compat.forward_compatible(2020, 3, 25):
|
||||
# Test down shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[0., 0., 0.],
|
||||
[0., 1., 2.],
|
||||
[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'constant')
|
||||
|
||||
# Test up shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11.],
|
||||
[12., 13., 14.],
|
||||
[0., 0., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'constant')
|
||||
|
||||
# Test left shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[1., 2., 0.],
|
||||
[4., 5., 0.],
|
||||
[7., 8., 0.],
|
||||
[10., 11., 0.],
|
||||
[13., 14., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'constant')
|
||||
|
||||
# Test right shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[0., 0., 1.],
|
||||
[0., 3., 4],
|
||||
[0., 6., 7.],
|
||||
[0., 9., 10.],
|
||||
[0., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(transform_matrix, expected_output,
|
||||
'constant')
|
||||
|
||||
def test_random_translation_nearest_interpolation(self):
|
||||
# nearest output is (aaaa|abcd|dddd)
|
||||
|
||||
if compat.forward_compatible(2020, 3, 25):
|
||||
# Test down shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[0., 0., 0.],
|
||||
[0., 1., 2.],
|
||||
[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., -1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(
|
||||
transform_matrix, expected_output,
|
||||
mode='constant', interpolation='nearest')
|
||||
|
||||
# Test up shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[3., 4., 5.],
|
||||
[6., 7., 8],
|
||||
[9., 10., 11.],
|
||||
[12., 13., 14.],
|
||||
[0., 0., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 0., 0., 1., 1., 0., 0.]])
|
||||
self._run_random_transform_with_mock(
|
||||
transform_matrix, expected_output,
|
||||
mode='constant', interpolation='nearest')
|
||||
|
||||
# Test left shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[1., 2., 0.],
|
||||
[4., 5., 0.],
|
||||
[7., 8., 0.],
|
||||
[10., 11., 0.],
|
||||
[13., 14., 0.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., 1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(
|
||||
transform_matrix, expected_output,
|
||||
mode='constant', interpolation='nearest')
|
||||
|
||||
# Test right shift by 1.
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray(
|
||||
[[0., 0., 1.],
|
||||
[0., 3., 4],
|
||||
[0., 6., 7.],
|
||||
[0., 9., 10.],
|
||||
[0., 12., 13.]]).reshape((1, 5, 3, 1)).astype(np.float32)
|
||||
# pyformat: enable
|
||||
transform_matrix = np.asarray([[1., 0., -1., 0., 1., 0., 0., 0.]])
|
||||
self._run_random_transform_with_mock(
|
||||
transform_matrix, expected_output,
|
||||
mode='constant', interpolation='nearest')
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class RandomRotationTest(keras_parameterized.TestCase):
|
||||
|
||||
|
@ -245,6 +245,7 @@ def _image_projective_transform_grad(op, grad):
|
||||
images = op.inputs[0]
|
||||
transforms = op.inputs[1]
|
||||
interpolation = op.get_attr("interpolation")
|
||||
fill_mode = op.get_attr("fill_mode")
|
||||
|
||||
image_or_images = ops.convert_to_tensor(images, name="images")
|
||||
transform_or_transforms = ops.convert_to_tensor(
|
||||
@ -267,5 +268,6 @@ def _image_projective_transform_grad(op, grad):
|
||||
images=grad,
|
||||
transforms=transforms,
|
||||
output_shape=array_ops.shape(image_or_images)[1:3],
|
||||
interpolation=interpolation)
|
||||
interpolation=interpolation,
|
||||
fill_mode=fill_mode)
|
||||
return [output, None, None]
|
||||
|
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -1778,7 +1778,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ImageProjectiveTransformV2"
|
||||
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ImageSummary"
|
||||
|
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -113,7 +113,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'fill_value\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'nearest\', \'0.0\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'height_factor\', \'width_factor\', \'fill_mode\', \'interpolation\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'reflect\', \'bilinear\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -1778,7 +1778,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "ImageProjectiveTransformV2"
|
||||
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'images\', \'transforms\', \'output_shape\', \'interpolation\', \'fill_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ImageSummary"
|
||||
|
Loading…
Reference in New Issue
Block a user