Allow tf.math.invert_permutation to broadcast.
PiperOrigin-RevId: 289222211
Change-Id: I3b28536354ae924e020faf1607265968845becec
parent 7402079323
commit 66fb5a1d93
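For context, the kernel comment in the diff below documents the op's contract: given a permutation p of 0, 1, ..., n - 1, the output inv satisfies inv[p[i]] == i. The following is a minimal NumPy sketch of that contract, reusing the input and expected output values from the tests that appear in this diff; it is an illustration only, not the TensorFlow kernel:

import numpy as np

def invert_permutation(p):
  # inv[p[i]] = i for each position i along the (last) axis.
  p = np.asarray(p)
  inv = np.empty_like(p)
  inv[p] = np.arange(p.shape[-1])
  return inv

print(invert_permutation([3, 4, 0, 2, 1]))   # [2 4 3 0 1], as in testInvertPermutation
# Per-row behaviour over the last axis, using the batched test data that
# appears in this diff:
batch = [[3, 4, 0, 2, 1], [2, 3, 4, 0, 1]]
print(np.stack([invert_permutation(row) for row in batch]))
# [[2 4 3 0 1]
#  [3 4 0 1 2]]
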
@@ -19,7 +19,6 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/transpose_op.h"
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -29,43 +28,15 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
 
-typedef Eigen::ThreadPoolDevice CPUDevice;
-
-namespace {
-
-template <typename T>
-struct InvertPermutations {
-  static void Run(OpKernelContext* context, const Tensor& input, Tensor* out,
-                  int start, int limit) {
-    auto input_tensor = input.matrix<T>();
-    const T N = static_cast<T>(
-        input_tensor.dimension(1));  // Safe: bounds already checked.
-    auto output_tensor = out->matrix<T>();
-    for (int64 i = start; i < limit; ++i) {
-      for (int j = 0; j < N; ++j) {
-        const T d = internal::SubtleMustCopy(input_tensor(i, j));
-        OP_REQUIRES(context, FastBoundsCheck(d, N),
-                    errors::InvalidArgument(d, " is not between 0 and ", N));
-        OP_REQUIRES(context, output_tensor(i, d) == -1,
-                    errors::InvalidArgument(d, " is duplicated in the input."));
-        output_tensor(i, d) = j;
-      }
-    }
-  }
-};
-
-}  // namespace
-
 // inv = InvertPermutationOp(T<int32/int64> p) takes a permutation of
 // integers 0, 1, ..., n - 1 and returns the inverted
 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
 //
+// REQUIRES: input is a vector of int32 or int64.
 // REQUIRES: input is a permutation of 0, 1, ..., n-1.
 //
-
 template <typename T>
 class InvertPermutationOp : public OpKernel {
@@ -75,46 +46,28 @@ class InvertPermutationOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
-    OP_REQUIRES(context, input.dims() > 0,
-                errors::InvalidArgument("Permutation must have at least rank 1 "
-                                        "but is rank ",
-                                        input.dims()));
-
-    const int64 perm_size = input.dim_size(input.dims() - 1);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsVector(input.shape()),
+        errors::InvalidArgument("invert_permutation expects a 1D vector."));
+    auto Tin = input.vec<T>();
     OP_REQUIRES(context,
-                FastBoundsCheck(perm_size, std::numeric_limits<int32>::max()),
+                FastBoundsCheck(Tin.size(), std::numeric_limits<int32>::max()),
                 errors::InvalidArgument("permutation of nonnegative int32s "
                                         "must have <= int32 max elements"));
-    Tensor input_reshaped;
-    int64 batch_size = 1;
-    // The last dimension is the permutation dimension.
-    for (int i = 0; i < input.dims() - 1; ++i) {
-      batch_size *= input.shape().dim_size(i);
-    }
-    TensorShape batch_vectors = TensorShape({batch_size, perm_size});
-    // Note that we always have a batch size, including the scalar case.
-    OP_REQUIRES(context, input_reshaped.CopyFrom(input, batch_vectors),
-                errors::Internal("Failed to reshape In[0] from ",
-                                 input.shape().DebugString()));
-
+    const T N = static_cast<T>(Tin.size());  // Safe: bounds-checked above.
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input.shape(), &output));
-    output->flat<T>() = output->flat<T>().constant(T(-1));
-    Tensor output_reshaped;
-    OP_REQUIRES(context, output_reshaped.CopyFrom(*output, batch_vectors),
-                errors::Internal("Failed to reshape Output[0] from ",
-                                 output->shape().DebugString()));
-
-    const int64 cost_per_unit = perm_size;
-    // Parallelize over outer dimensions
-    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
-    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
-          cost_per_unit,
-          [&context, &input_reshaped, &output_reshaped](int start, int limit) {
-            InvertPermutations<T>::Run(context, input_reshaped,
-                                       &output_reshaped, start, limit);
-          });
+    auto Tout = output->vec<T>();
+    std::fill_n(Tout.data(), N, -1);
+    for (int i = 0; i < N; ++i) {
+      const T d = internal::SubtleMustCopy(Tin(i));
+      OP_REQUIRES(context, FastBoundsCheck(d, N),
+                  errors::InvalidArgument(d, " is not between 0 and ", N));
+      OP_REQUIRES(context, Tout(d) == -1,
+                  errors::InvalidArgument(d, " is duplicated in the input."));
+      Tout(d) = i;
+    }
   }
 };
 
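
For readers skimming the hunk above, here is a hedged Python sketch of the per-vector inversion loop the kernel performs: initialize the output to -1, bounds-check each element, reject duplicates, then write out[p[i]] = i. In the batched variant on the old side of the hunk, this same loop runs once per row and is sharded across worker threads over the outer (batch) dimension. The sketch is illustrative only, not the C++ implementation:

def invert_permutation_checked(p):
  n = len(p)
  out = [-1] * n
  for i, d in enumerate(p):
    if not 0 <= d < n:
      raise ValueError("%d is not between 0 and %d" % (d, n))
    if out[d] != -1:
      raise ValueError("%d is duplicated in the input." % d)
    out[d] = i
  return out

print(invert_permutation_checked([3, 4, 0, 2, 1]))  # [2, 4, 3, 0, 1]
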
@@ -1391,7 +1391,7 @@ REGISTER_OP("InvertPermutation")
     .Attr("T: {int32, int64} = DT_INT32")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle x;
-      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &x));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
       c->set_output(0, x);
       return Status::OK();
     });
@@ -399,9 +399,9 @@ TEST(ArrayOpsTest, UniqueWithCounts_ShapeFn) {
 
 TEST(ArrayOpsTest, InvertPermutation_ShapeFn) {
   ShapeInferenceTestOp op("InvertPermutation");
+  INFER_OK(op, "?", "[?]");
   INFER_OK(op, "[1]", "in0");
-  INFER_OK(op, "[1,2,3]", "in0");
-  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
+  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[]");
 }
 
 TEST(ArrayOpsTest, PadD_ShapeFn) {
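
The shape-function change and its test above swap WithRankAtLeast(input, 1) for WithRank(input, 1). As a hedged, plain-Python sketch of what the two constraints accept (the helper names are hypothetical, not TensorFlow API; the error strings come from the INFER_ERROR expectations in this diff):

def check_rank_at_least_1(shape):
  # WithRankAtLeast(input, 1): any rank >= 1 passes, e.g. [1] or [1,2,3].
  if shape is not None and len(shape) < 1:
    raise ValueError("Shape must be at least rank 1 but is rank %d" % len(shape))
  return shape

def check_rank_exactly_1(shape):
  # WithRank(input, 1): only vectors pass; an unknown rank (None) is refined to [?].
  if shape is None:
    return (None,)
  if len(shape) != 1:
    raise ValueError("Shape must be rank 1 but is rank %d" % len(shape))
  return shape

print(check_rank_exactly_1((5,)))   # (5,)   -- matches INFER_OK(op, "[1]", "in0")
print(check_rank_exactly_1(None))   # (None,) -- matches INFER_OK(op, "?", "[?]")
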
@@ -44,7 +44,6 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import sort_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -1352,40 +1351,14 @@ class PadTest(test_util.TensorFlowTestCase):
 
 class InvertPermutationTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_deprecated_v1
   def testInvertPermutation(self):
     for dtype in [dtypes.int32, dtypes.int64]:
-      x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype)
-      y = array_ops.invert_permutation(x)
-      self.assertAllEqual(y.shape, [5])
-      self.assertAllEqual(self.evaluate(y), [2, 4, 3, 0, 1])
-
-  def testInvertPermutationCheckRank(self):
-    for dtype in [dtypes.int32, dtypes.int64]:
-      x = constant_op.constant(3, dtype=dtype)
-      with self.assertRaisesRegexp(Exception, "at least rank 1"):
-        self.evaluate(array_ops.invert_permutation(x))
-
-  def testInvertPermutationBatch(self):
-    for dtype in [dtypes.int32, dtypes.int64]:
-      x = constant_op.constant([[[3, 4, 0, 2, 1], [2, 3, 4, 0, 1]]],
-                               dtype=dtype)
-      y = array_ops.invert_permutation(x)
-      self.assertAllEqual(y.shape, [1, 2, 5])
-      self.assertAllEqual(
-          self.evaluate(y), [[[2, 4, 3, 0, 1], [3, 4, 0, 1, 2]]])
-
-  @test_util.run_deprecated_v1
-  def testInvertPermutationLargerBatch(self):
-    perm = np.array([np.random.permutation(20) for _ in range(10)],
-                    dtype=np.int32)
-
-    for dtype in [dtypes.int32, dtypes.int64]:
-      x = constant_op.constant(perm, dtype=dtype)
-      y = array_ops.invert_permutation(x)
-      # Argsort should be equivalent to invert permutation.
-      z = sort_ops.argsort(x, axis=-1)
-      self.assertAllEqual(y.shape, [10, 20])
-      self.assertAllEqual(self.evaluate(y), self.evaluate(z))
+      with self.cached_session(use_gpu=True):
+        x = constant_op.constant([3, 4, 0, 2, 1], dtype=dtype)
+        y = array_ops.invert_permutation(x)
+        self.assertAllEqual(y.get_shape(), [5])
+        self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1])
 
 
 class UnravelIndexTest(test_util.TensorFlowTestCase):
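
The removed testInvertPermutationLargerBatch relies on the identity that the argsort of a permutation is its inverse. A short NumPy check of that identity, for illustration only (not part of the change):

import numpy as np

perm = np.array([np.random.permutation(20) for _ in range(10)], dtype=np.int32)
inverse = np.argsort(perm, axis=-1)   # argsort of a permutation is its inverse
# Indexing each row of perm by its inverse yields the identity permutation:
identity = np.take_along_axis(perm, inverse, axis=-1)
assert np.array_equal(identity, np.tile(np.arange(20), (10, 1)))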