Differentiable isotonic regression in TensorFlow.
PiperOrigin-RevId: 327257991 Change-Id: Ided6e7f1d295bcd74c87e3a5601dd2bebde173a0
parent 58747588d2
commit 2c5e31114c
@@ -1593,6 +1593,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric,
     color palette of the frame. This has been fixed now
   * image.resize now considers proper pixel centers and has new kernels
     (incl. anti-aliasing).
+  * Added an isotonic regression solver (tf.nn.isotonic_regression).
 * Performance
   * Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
     dispatches the best kernel implementation based on CPU vector
@@ -1011,6 +1011,7 @@ cc_library(
         "//tensorflow/core/kernels:grappler",
         "//tensorflow/core/kernels:histogram_op",
         "//tensorflow/core/kernels:io",
+        "//tensorflow/core/kernels:isotonic_regression_op",
         "//tensorflow/core/kernels:lookup",
         "//tensorflow/core/kernels:logging",
         "//tensorflow/core/kernels:manip",
@@ -0,0 +1,24 @@
op {
  graph_op_name: "IsotonicRegression"
  visibility: HIDDEN
  in_arg {
    name: "input"
    description: <<END
A (batch_size, dim)-tensor holding a batch of inputs.
END
  }
  out_arg {
    name: "output"
    description: <<END
A (batch_size, dim)-tensor holding the per-batch element solutions.
END
  }
  out_arg {
    name: "segments"
    description: <<END
An int32 (batch_size, dim)-tensor with the segments.
END
  }
  attr { name: "output_dtype" description: "Dtype of output." }
  summary: "Solves a batch of isotonic regression problems."
}
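For orientation, a minimal sketch of what this contract looks like through the public wrapper tf.nn.isotonic_regression (introduced later in this commit); the shapes follow the (batch_size, dim) convention in the api_def above:

    import tensorflow as tf

    # A batch of two 3-dimensional problems.
    x = tf.constant([[3.0, 1.0, 2.0],
                     [1.0, 3.0, 4.0]])
    solution, segments = tf.nn.isotonic_regression(x, axis=1)
    print(solution.shape)  # (2, 3) -- same shape as the input.
    print(segments.dtype)  # int32 -- per-element segment ids.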
@@ -7533,3 +7533,32 @@ tf_kernel_library(
    name = "einsum_op",
    deps = ["//tensorflow/core/kernels/linalg:einsum_op"],
)

tf_kernel_library(
    name = "isotonic_regression_op",
    srcs = [
        "isotonic_regression_op.cc",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
    ],
)

tf_cc_test(
    name = "isotonic_regression_op_test",
    size = "small",
    srcs = ["isotonic_regression_op_test.cc"],
    deps = [
        ":isotonic_regression_op",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
226 tensorflow/core/kernels/isotonic_regression_op.cc Normal file
@@ -0,0 +1,226 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/platform/threadpool.h"

namespace {

using tensorflow::int32;
using tensorflow::int64;

// The # of ops estimated for the isotonic regression solver is the size of the
// array multiplied by this constant. This is used by the thread pool executor
// when deciding how many threads to use.
constexpr int kCostMultiplier = 100;

// In separable chain-constrained problems, i.e., those of the form
//
//   min_{y_1 >= y_2 >= ... >= y_n} \sum_{i=1}^n h_i(y_i)
//
// for any set of convex functions h_i, of particular importance are contiguous
// segments of coordinates, which this class represents. The interval is
// assumed to be half-closed and equal to [col_start(), col_limit()).
class Segment {
 public:
  // Creates the segment [col_index, col_index+1).
  explicit Segment(int col_index)
      : col_start_(col_index), col_limit_(col_index + 1) {}

  // Returns the number of points in the segment.
  int num_points() const { return col_limit_ - col_start_; }

  // Merge another segment into this one.
  void merge_with(const Segment& other) {
    col_start_ = std::min(col_start_, other.col_start());
    col_limit_ = std::max(col_limit_, other.col_limit());
  }

  int col_start() const { return col_start_; }

  int col_limit() const { return col_limit_; }

 private:
  int col_start_;
  int col_limit_;
};

// If we can solve for each segment {j, j+1, ..., j+m} the interval problem
//
//   argmin_y \sum_{i=j}^{j+m} h_i(y),
//
// we can use such an oracle to solve the general problem. The following class
// implements such an oracle for the case when h_i is the squared (l2) loss,
// or formally h_i(y) = (y - x_i)^2, where x_i is the i-th input.
//
// TODO(josipd): We know how and can extend this to other functions if needed.
template <typename T>
class L2PavaSegment : public Segment {
 public:
  L2PavaSegment(T y, int col_index)
      : Segment(col_index), y_sum_(y), minimum_(y) {}

  void merge_with(const L2PavaSegment& other) {
    Segment::merge_with(other);
    y_sum_ += other.y_sum_;
    minimum_ = y_sum_ / static_cast<T>(num_points());
  }

  T minimum() const { return minimum_; }

 private:
  T y_sum_;    // The sum of the inputs within the segment.
  T minimum_;  // The minimum, cached to avoid expensive divisions.
};

// Solve one of the problems in the batch (the row_index'th one) using the
// pool-adjacent violators algorithm (PAVA).
//
// The PAVA algorithm goes back to
//
//   Nonmetric Multidimensional Scaling: A numerical method
//   Kruskal, J. B. (1964), Psychometrika (1964)
//
// For a more recent analysis, please refer to
//
//   Active set algorithms for isotonic regression; a unifying framework
//   Best, Michael J., and Nilotpal Chakravarti
//   Mathematical Programming 47.1-3 (1990)
//
// Intuitively, the algorithm splits the inputs into blocks (starting from
// singleton ones), and then whenever there are two consecutive blocks whose
// minima violate the inequality constraint, they are merged. The solution is
// then block-wise constant, each block equal to the corresponding minimum.
//
// The tensors should be two dimensional, and the segment objects should
// support the minimum() and merge_with() methods.
template <typename SegmentType, typename FloatTensor, typename IntTensor>
void solve_pava(const std::function<SegmentType(int, int)>& make_segment,
                FloatTensor* solution, IntTensor* segments, int row_index) {
  const size_t n = solution->dimensions()[1];
  std::vector<SegmentType> pools;
  pools.reserve(n);

  for (size_t col_index = 0; col_index < n; ++col_index) {
    pools.push_back(make_segment(row_index, col_index));

    // While the last two pools are decreasing, merge them.
    while (pools.size() > 1 &&
           pools.rbegin()->minimum() > (pools.rbegin() + 1)->minimum()) {
      (pools.rbegin() + 1)->merge_with(*pools.rbegin());
      pools.pop_back();
    }
  }

  int segment_id = 0;
  for (const auto& pool : pools) {
    const auto pool_minimum = pool.minimum();
    // The matrices are row major, so we can scan the memory linearly.
    auto* solution_ptr = &(*solution)(row_index, pool.col_start());
    auto* segments_ptr = &(*segments)(row_index, pool.col_start());
    for (int i = pool.col_start(); i < pool.col_limit(); ++i) {
      *solution_ptr++ = pool_minimum;
      *segments_ptr++ = segment_id;
    }
    ++segment_id;
  }
}

// Solve a batch of problems using the pool-adjacent violators algorithm.
// The problems are solved in parallel using tensorflow's thread pool.
template <typename SegmentType, typename FloatTensor, typename IntTensor>
void solve_pava_batch(const std::function<SegmentType(int, int)>& make_segment,
                      FloatTensor* solution, IntTensor* segments,
                      tensorflow::OpKernelContext* context) {
  const int batch_size = solution->dimensions()[0];
  const int problem_size = solution->dimensions()[1];

  auto thread_pool =
      context->device()->tensorflow_cpu_worker_threads()->workers;

  thread_pool->ParallelFor(
      batch_size, kCostMultiplier * problem_size,
      [&make_segment, &solution, &segments](int64 row_start, int64 row_limit) {
        // Casting to int is safe, as we do boundary checks in `Compute`.
        for (int row_index = static_cast<int>(row_start);
             row_index < static_cast<int>(row_limit); ++row_index) {
          solve_pava(make_segment, solution, segments, row_index);
        }
      });
}

}  // namespace

template <typename Tin, typename Tout>
class IsotonicRegressionOp : public tensorflow::OpKernel {
 public:
  explicit IsotonicRegressionOp(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    // Grab the input tensor.
    const tensorflow::Tensor& input_tensor = context->input(0);
    const auto input = input_tensor.flat_inner_dims<Tin, 2>();
    int int_max = std::numeric_limits<int32>::max();
    OP_REQUIRES(context,
                tensorflow::FastBoundsCheck(input.dimensions()[0], int_max) &&
                    tensorflow::FastBoundsCheck(input.dimensions()[1], int_max),
                tensorflow::errors::InvalidArgument("Tensor too large"));

    // Create the output tensor holding the minimizers.
    const auto shape = input_tensor.shape();
    tensorflow::Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, shape, &output_tensor));
    auto output = output_tensor->flat_inner_dims<Tout, 2>();

    // Create the output tensor holding the segment memberships.
    tensorflow::Tensor* segments_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, shape, &segments_tensor));
    auto segments = segments_tensor->flat_inner_dims<int>();

    auto make_l2_segment = [&input](int row_index, int col_index) {
      return L2PavaSegment<Tout>(input(row_index, col_index), col_index);
    };
    solve_pava_batch<L2PavaSegment<Tout>>(make_l2_segment, &output, &segments,
                                          context);
  }
};

#define REGISTER_CPU_KERNEL(Tin, Tout)                               \
  REGISTER_KERNEL_BUILDER(Name("IsotonicRegression")                 \
                              .Device(tensorflow::DEVICE_CPU)        \
                              .TypeConstraint<Tin>("T")              \
                              .TypeConstraint<Tout>("output_dtype"), \
                          IsotonicRegressionOp<Tin, Tout>);

// Float types have the same input and output.
#define REGISTER_CPU_SAME_KERNEL(T) REGISTER_CPU_KERNEL(T, T)
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SAME_KERNEL);

// 8 and 16 bit integers get converted to 32 bit floats.
#define REGISTER_CPU_KERNEL_FLOAT(Tin) REGISTER_CPU_KERNEL(Tin, float)
TF_CALL_int16(REGISTER_CPU_KERNEL_FLOAT);
TF_CALL_int8(REGISTER_CPU_KERNEL_FLOAT);

// 32 and 64 bit integers get converted to 64 bit floats.
#define REGISTER_CPU_KERNEL_DOUBLE(Tin) REGISTER_CPU_KERNEL(Tin, double)
TF_CALL_int64(REGISTER_CPU_KERNEL_DOUBLE);
TF_CALL_int32(REGISTER_CPU_KERNEL_DOUBLE);
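To make the merging logic above concrete, here is a minimal NumPy sketch of the same pool-adjacent-violators loop for the squared loss (a hand-rolled illustration, not the kernel itself):

    import numpy as np

    def pava_l2(x):
        # Each pool is [y_sum, num_points]; its minimizer is the pool mean.
        pools = []
        for xi in x:
            pools.append([float(xi), 1])
            # While the last two pools violate y_prev >= y_last, merge them,
            # mirroring the merge loop in solve_pava above.
            while len(pools) > 1 and (pools[-1][0] / pools[-1][1]
                                      > pools[-2][0] / pools[-2][1]):
                y_sum, num_points = pools.pop()
                pools[-1][0] += y_sum
                pools[-1][1] += num_points
        solution, segments = [], []
        for segment_id, (y_sum, num_points) in enumerate(pools):
            solution.extend([y_sum / num_points] * num_points)
            segments.extend([segment_id] * num_points)
        return np.array(solution), np.array(segments)

    # [3, 1, 2] -> blocks [3] and [1.5, 1.5], matching the docstring example
    # added later in this commit.
    print(pava_l2(np.array([3.0, 1.0, 2.0])))  # [3. 1.5 1.5], [0 1 1]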
139 tensorflow/core/kernels/isotonic_regression_op_test.cc Normal file
@@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

class IsotonicRegressionOpTest : public OpsTestBase {
 public:
  void MakeOp(DataType type) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "IsotonicRegression")
                     .Input(FakeInput(type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

class BenchmarkHelper : public IsotonicRegressionOpTest {
 public:
  void TestBody() override {}

  void AddIncreasingInput(int batch_size, int input_size) {
    std::vector<float> input_data(input_size * batch_size, 0);
    for (int i = 0; i < input_data.size(); i++) {
      input_data[i] = i;
    }
    AddInputFromArray<float>(TensorShape({batch_size, input_size}), input_data);
  }
};

TEST_F(IsotonicRegressionOpTest, Constant) {
  MakeOp(DT_FLOAT_REF);

  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  TF_ASSERT_OK(RunOpKernel());
  Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
  test::FillValues<float>(&expected,
                          {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  test::ExpectClose(expected, *GetOutput(0));
}

TEST_F(IsotonicRegressionOpTest, IncreasingInput) {
  MakeOp(DT_FLOAT_REF);

  AddInputFromArray<float>(TensorShape({5, 3}),
                           {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
  test::FillValues<float>(&expected,
                          {2, 2, 2, 5, 5, 5, 8, 8, 8, 11, 11, 11, 14, 14, 14});
  test::ExpectClose(expected, *GetOutput(0));

  Tensor expected_ord(allocator(), DT_INT32, TensorShape({5, 3}));
  test::FillValues<int>(&expected_ord,
                        {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  test::ExpectTensorEqual<int>(expected_ord, *GetOutput(1));
}

TEST_F(IsotonicRegressionOpTest, Decreasing) {
  MakeOp(DT_FLOAT_REF);

  AddInputFromArray<float>(TensorShape({5, 3}),
                           {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
  test::FillValues<float>(&expected,
                          {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1});
  test::ExpectClose(expected, *GetOutput(0));

  Tensor expected_ord(allocator(), DT_INT32, TensorShape({5, 3}));
  test::FillValues<int>(&expected_ord,
                        {0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2});
  test::ExpectTensorEqual<int>(expected_ord, *GetOutput(1));
}

static void BM_IncreasingSequence(benchmark::State& state) {
  int batch_size = state.range(0);
  int input_size = state.range(1);

  for (auto _ : state) {
    state.PauseTiming();
    BenchmarkHelper helper;
    helper.MakeOp(DT_FLOAT_REF);
    helper.AddIncreasingInput(batch_size, input_size);
    state.ResumeTiming();
    Status stat = helper.RunOpKernel();
  }
  state.SetItemsProcessed(
      static_cast<int64>(batch_size * input_size * state.iterations()));
}

BENCHMARK(BM_IncreasingSequence)
    ->Args({1, 1 << 0})
    ->Args({1, 1 << 5})
    ->Args({1, 1 << 8})
    ->Args({1, 1 << 10})
    ->Args({1, 1 << 20})
    ->Args({1, 2 << 20})
    ->Args({1 << 0, 1 << 10})
    ->Args({1 << 1, 1 << 10})
    ->Args({1 << 4, 1 << 10})
    ->Args({1 << 6, 1 << 10})
    ->Args({1 << 9, 1 << 10})
    ->Args({1 << 10, 1 << 10});

}  // namespace
}  // namespace tensorflow
@@ -3406,4 +3406,16 @@ REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);

REGISTER_OP("IsotonicRegression")
    .Input("input: T")
    .Output("output: output_dtype")
    .Output("segments: int32")
    .Attr("T: realnumbertype")
    .Attr("output_dtype: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* context) {
      context->set_output(0, context->input(0));
      context->set_output(1, context->input(0));
      return tensorflow::Status::OK();
    });

}  // namespace tensorflow
@@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {

 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 350> a = {{
+  static std::array<OpIndexInfo, 351> a = {{
       {"Acosh"},
       {"AllToAll", 1, {0}},
       {"ApproximateEqual"},
@@ -160,6 +160,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
       {"Inv"},
       {"Invert"},
       {"InvertPermutation"},
+      {"IsotonicRegression"},
       {"LMDBReader"},
       {"LeakyReluGrad", 1, {0}},
       {"LeftShift"},
@@ -413,7 +414,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(

 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 466> a = {{
+  static std::array<OpIndexInfo, 467> a = {{
       {"Abs"},
       {"AccumulateNV2"},
       {"Acos"},
@@ -577,6 +578,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
       {"InvGrad"},
       {"Invert"},
       {"InvertPermutation"},
+      {"IsotonicRegression", 1, {0}},
       {"L2Loss"},
       {"LMDBReader"},
       {"LeakyRelu"},
@@ -1142,3 +1142,48 @@ def _NthElementGrad(op, grad):
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.divide(indicators, num_selected) * grad, None]


def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Specifically, each value in the `inputs` tensor gets replaced by the mean
  value computed from the values that belong to the same segment.

  Args:
    inputs: A 2-tensor. Aggregation is done over dimension 1.
    segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """
  result = []
  for inputs_i, segments_i in zip(
      array_ops.split(inputs, inputs.shape[0]),
      array_ops.split(segments, segments.shape[0])):
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    means_i = math_ops.unsorted_segment_mean(
        inputs_i, segments_i, num_segments=math_ops.reduce_max(segments_i) + 1)
    result.append(
        array_ops.reshape(array_ops.gather(means_i, segments_i), [-1]))
  return array_ops.stack(result, axis=0)


# We have to register the gradients for these ops so that tensorflow will know
# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Gradient for the isotonic regression function.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments.

  Returns:
    A tensor, same size as `grad_output` with the gradient with respect to
    the input.
  """
  del grad_segments  # Discrete, non-differentiable.
  segments = op.outputs[1]
  return _MeanAggregator(grad_output, segments)
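Since the solution is blockwise constant (each output is the mean of its block of inputs), the op's vector-Jacobian product averages incoming gradients within each segment, which is exactly what _MeanAggregator computes. A quick numeric check, as a sketch using the public API:

    import tensorflow as tf

    x = tf.constant([[3.0, 1.0, 2.0]])
    with tf.GradientTape() as tape:
        tape.watch(x)
        y, segments = tf.nn.isotonic_regression(x)  # y = [[3., 1.5, 1.5]]
        loss = tf.tensordot(y, tf.constant([[1.0, 2.0, 4.0]]), axes=2)

    # segments == [[0, 1, 1]]: the incoming gradient 1.0 stays on its singleton
    # block, while 2.0 and 4.0 are averaged to 3.0 within the merged block.
    print(tape.gradient(loss, x))  # [[1., 3., 3.]]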
@@ -3566,46 +3566,49 @@ def _flatten_outer_dims(logits):
   return output


-def _softmax(logits, compute_op, dim=-1, name=None):
-  """Helper function for softmax and log_softmax.
+def _wrap_2d_function(inputs, compute_op, dim=-1, name=None):
+  """Helper function for ops that accept and return 2d inputs of same shape.

-  It reshapes and transposes the input logits into a 2-D Tensor and then invokes
-  the tf.nn._softmax or tf.nn._log_softmax function. The output would be
-  transposed and reshaped back.
+  It reshapes and transposes the inputs into a 2-D Tensor and then invokes
+  the given function. The output would be transposed and reshaped back.
+  If the given function returns a tuple of tensors, each of them will be
+  transposed and reshaped.

   Args:
-    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+    inputs: A non-empty `Tensor`. Must be one of the following types: `half`,
       `float32`, `float64`.
-    compute_op: Either gen_nn_ops.softmax or gen_nn_ops.log_softmax
+    compute_op: The function to wrap. Must accept the input tensor as its first
+      argument, and a second keyword argument `name`.
     dim: The dimension softmax would be performed on. The default is -1 which
       indicates the last dimension.
     name: A name for the operation (optional).

   Returns:
-    A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+    A `Tensor`. Has the same shape as inputs. If compute_op returns multiple
+    tensors, each of them has the same shape as the input.
   Raises:
-    InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
-      dimension of `logits`.
+    InvalidArgumentError: if `inputs` is empty or `dim` is beyond the last
+      dimension of `inputs`.
   """

-  def _swap_axis(logits, dim_index, last_index, name=None):
+  def _swap_axis(input_tensor, dim_index, last_index, name=None):
     """Swaps logits's dim_index and last_index."""
     return array_ops.transpose(
-        logits,
+        input_tensor,
         array_ops.concat([
             math_ops.range(dim_index), [last_index],
             math_ops.range(dim_index + 1, last_index), [dim_index]
         ], 0),
         name=name)

-  logits = ops.convert_to_tensor(logits)
+  inputs = ops.convert_to_tensor(inputs)

   # We need its original shape for shape inference.
-  shape = logits.get_shape()
+  shape = inputs.get_shape()
   is_last_dim = (dim == -1) or (dim == shape.ndims - 1)

   if is_last_dim:
-    return compute_op(logits, name=name)
+    return compute_op(inputs, name=name)

   dim_val = dim
   if isinstance(dim, ops.Tensor):
@@ -3618,10 +3621,10 @@ def _softmax(logits, compute_op, dim=-1, name=None):
                      shape.ndims))

   # If dim is not the last dimension, we have to do a transpose so that we can
-  # still perform softmax on its last dimension.
+  # still perform the op on its last dimension.

   # In case dim is negative (and is not last dimension -1), add shape.ndims
-  ndims = array_ops.rank(logits)
+  ndims = array_ops.rank(inputs)
   if not isinstance(dim, ops.Tensor):
     if dim < 0:
       dim += ndims
@@ -3629,20 +3632,24 @@ def _softmax(logits, compute_op, dim=-1, name=None):
     dim = array_ops.where(math_ops.less(dim, 0), dim + ndims, dim)

   # Swap logits' dimension of dim and its last dimension.
-  input_rank = array_ops.rank(logits)
+  input_rank = array_ops.rank(inputs)
   dim_axis = dim % shape.ndims
-  logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1))
+  inputs = _swap_axis(inputs, dim_axis, math_ops.subtract(input_rank, 1))

-  # Do the actual softmax on its last dimension.
-  output = compute_op(logits)
+  # Do the actual call on its last dimension.
+  def fix_output(output):
+    output = _swap_axis(
+        output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
+    # Make shape inference work since transpose may erase its static shape.
+    output.set_shape(shape)
+    return output

-  output = _swap_axis(
-      output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
-
-  # Make shape inference work since transpose may erase its static shape.
-  output.set_shape(shape)
-
-  return output
+  outputs = compute_op(inputs)
+  if isinstance(outputs, tuple):
+    return tuple(fix_output(output) for output in outputs)
+  else:
+    return fix_output(outputs)


 @tf_export(v1=["nn.softmax", "math.softmax"])
@@ -3687,7 +3694,7 @@ def softmax(logits, axis=None, name=None, dim=None):
   axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
   if axis is None:
     axis = -1
-  return _softmax(logits, gen_nn_ops.softmax, axis, name)
+  return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)


 @tf_export("nn.softmax", "math.softmax", v1=[])
@@ -3715,7 +3722,7 @@ def softmax_v2(logits, axis=None, name=None):
   """
   if axis is None:
     axis = -1
-  return _softmax(logits, gen_nn_ops.softmax, axis, name)
+  return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)


 @tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@@ -3746,7 +3753,7 @@ def log_softmax(logits, axis=None, name=None, dim=None):
   axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
   if axis is None:
     axis = -1
-  return _softmax(logits, gen_nn_ops.log_softmax, axis, name)
+  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)


 @tf_export("nn.log_softmax", "math.log_softmax", v1=[])
@@ -3774,7 +3781,7 @@ def log_softmax_v2(logits, axis=None, name=None):
   """
   if axis is None:
     axis = -1
-  return _softmax(logits, gen_nn_ops.log_softmax, axis, name)
+  return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)


 def _ensure_xent_args(name, sentinel, labels, logits):
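The rewired softmax and log_softmax calls behave exactly as before; only the helper's name and generality changed. A quick equivalence check of what _wrap_2d_function does for a non-last axis (a sketch):

    import tensorflow as tf

    x = tf.random.normal([2, 3, 4])
    a = tf.nn.softmax(x, axis=1)
    # The helper moves axis 1 to the end, applies the op, and moves it back.
    b = tf.transpose(tf.nn.softmax(tf.transpose(x, [0, 2, 1]), axis=-1),
                     [0, 2, 1])
    tf.debugging.assert_near(a, b)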
@@ -5674,3 +5681,78 @@ tf_export(v1=["nn.quantized_relu_x"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
    dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))


@tf_export("nn.isotonic_regression", v1=[])
@dispatch.add_dispatch_support
def isotonic_regression(inputs, decreasing=True, axis=-1):
  r"""Solves isotonic regression problems along the given axis.

  For each vector x, the problem solved is

  $$\argmin_{y_1 >= y_2 >= ... >= y_n} \sum_i (x_i - y_i)^2.$$

  As the solution is component-wise constant, a second tensor is returned that
  encodes the segments. The problems are solved over the given axis.

  Consider the following example, where we solve a batch of two problems. The
  first input is [3, 1, 2], while the second is [1, 3, 4] (as the axis is 1).
  >>> x = tf.constant([[3, 1, 2], [1, 3, 4]], dtype=tf.float32)
  >>> y, segments = tf.nn.isotonic_regression(x, axis=1)
  >>> y  # The solution.
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[3.       , 1.5      , 1.5      ],
         [2.6666667, 2.6666667, 2.6666667]], dtype=float32)>

  Note that the first solution has two blocks [3] and [1.5, 1.5]. The second
  solution is constant, and thus has a single segment. These segments are
  exactly what the second returned tensor encodes:

  >>> segments
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[0, 1, 1],
         [0, 0, 0]], dtype=int32)>


  Args:
    inputs: A tensor holding the inputs.
    decreasing: If set to False, the inequalities in the optimization
      constraints are flipped.
    axis: The axis along which the problems should be solved.

  Returns:
    output: The solutions, same shape and type as the input.
    segments: An int32 tensor, same shape as the input, indicating the segments
      that have the same value. Specifically, those positions that have the
      same value correspond to the same segment. These values start at zero,
      and are monotonically increasing for each solution.
  """
  type_promotions = {
      # Float types get mapped to themselves, int8/16 to float32, rest to double
      dtypes.float32: dtypes.float32,
      dtypes.half: dtypes.half,
      dtypes.bfloat16: dtypes.bfloat16,
      dtypes.int8: dtypes.float32,
      dtypes.int16: dtypes.float32,
  }
  inputs = ops.convert_to_tensor(inputs)
  try:
    output_dtype = type_promotions[inputs.dtype]
  except KeyError:
    output_dtype = dtypes.float64

  def compute_on_matrix(matrix, name=None):
    iso_fn = functools.partial(
        gen_nn_ops.isotonic_regression, output_dtype=output_dtype, name=name)
    if decreasing:
      return iso_fn(matrix)
    else:
      output, segments = iso_fn(-matrix)
      return -output, segments

  return _wrap_2d_function(inputs, compute_on_matrix, axis)
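A short usage note on the decreasing flag: increasing problems are solved by negating inputs and outputs, exactly as compute_on_matrix does above. A sketch:

    import tensorflow as tf

    x = tf.constant([[3.0, 1.0, 2.0]])

    # Decreasing (the default): y_1 >= y_2 >= y_3.
    y_dec, _ = tf.nn.isotonic_regression(x)                    # [[3., 1.5, 1.5]]

    # Increasing: y_1 <= y_2 <= y_3, solved internally as -isotonic(-x).
    y_inc, _ = tf.nn.isotonic_regression(x, decreasing=False)  # [[2., 2., 2.]]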
@@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_impl
@@ -1701,5 +1702,88 @@ class RaggedEmbeddingTest(test_lib.TestCase):
                                     actual)


class IsotonicTest(parameterized.TestCase, test_lib.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_increasing_and_decreasing(self):
    x = constant_op.constant([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
                             dtype=dtypes.float64)
    y, segments = nn_ops.isotonic_regression(x, decreasing=False)
    self.assertAllClose(y, x)
    self.assertAllClose(segments, [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])

    y, segments = nn_ops.isotonic_regression(x, decreasing=True)
    self.assertAllClose(
        y,
        [
            [2, 2, 2, 2, 2],  # Average of the inputs.
            [7, 7, 7, 7, 7]
        ])
    self.assertAllClose(segments, array_ops.zeros((2, 5)))

    y, segments = nn_ops.isotonic_regression(-x, decreasing=True)
    self.assertAllClose(segments, [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
    self.assertAllClose(y, -x)

    y, segments = nn_ops.isotonic_regression(-x, decreasing=False)
    self.assertAllClose(
        -y,
        [
            [2, 2, 2, 2, 2],  # Average of the inputs.
            [7, 7, 7, 7, 7]
        ])
    self.assertAllClose(segments, array_ops.zeros((2, 5)))

  @test_util.run_in_graph_and_eager_modes
  def test_different_axis(self):
    x = constant_op.constant([[0, 6, 2, 8, 4], [5, 1, 7, 3, 9]],
                             dtype=dtypes.float64)
    y, segments = nn_ops.isotonic_regression(x, decreasing=True, axis=0)
    self.assertAllClose(
        y,
        [
            [2.5, 6, 4.5, 8, 6.5],  # Either identity or average.
            [2.5, 1, 4.5, 3, 6.5]
        ])
    self.assertAllClose(segments, [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0]])

  @test_util.run_v2_only
  def testGradientV2(self, dtype=np.float64, batch_size=30, dimensions=50):

    @def_function.function
    def ComputeIsotonicFn(x):
      y, _ = nn_ops.isotonic_regression(x)  # No gradient wrt segments.
      return y

    np.random.seed(0)
    x_init = np.random.randn(batch_size, dimensions).astype(dtype)
    grad_theoretical, grad_numerical = gradient_checker_v2.compute_gradient(
        ComputeIsotonicFn, [x_init], delta=1e-5)
    self.assertAllClose(grad_theoretical, grad_numerical)

  @test_util.run_v1_only("compute_gradient_error is v1 only")
  def testGradientV1(self, dtype=np.float64, batch_size=30, dimensions=50):
    np.random.seed(0)
    x_init = np.random.randn(batch_size, dimensions).astype(dtype)
    with self.cached_session():
      x = array_ops.placeholder(dtype, (batch_size, dimensions))
      y, _ = nn_ops.isotonic_regression(x)  # Segments have no gradient.
      max_error = gradient_checker.compute_gradient_error(
          x, (batch_size, dimensions), y, (batch_size, dimensions), x_init)
    self.assertAllClose(max_error, 0.)

  @parameterized.parameters([[dtypes.half, dtypes.half],
                             [dtypes.bfloat16, dtypes.bfloat16],
                             [dtypes.float32, dtypes.float32],
                             [dtypes.float64, dtypes.float64],
                             [dtypes.int32, dtypes.float64],
                             [dtypes.int16, dtypes.float32]])
  def testTypePromotion(self, dtype_in, expected_dtype_out):
    x = constant_op.constant([[0, 6, 2, 8, 4], [5, 1, 7, 3, 9]], dtype=dtype_in)
    y, segments = nn_ops.isotonic_regression(x)
    self.assertEqual(y.dtype, expected_dtype_out)
    self.assertEqual(segments.dtype, dtypes.int32)


if __name__ == "__main__":
  test_lib.main()
@@ -1996,6 +1996,10 @@ tf_module {
    name: "IsVariableInitialized"
    argspec: "args=[\'ref\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "IsotonicRegression"
    argspec: "args=[\'input\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
  }
  member_method {
    name: "Iterator"
    argspec: "args=[\'shared_name\', \'container\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -176,6 +176,10 @@ tf_module {
    name: "in_top_k"
    argspec: "args=[\'targets\', \'predictions\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "isotonic_regression"
    argspec: "args=[\'inputs\', \'decreasing\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'-1\'], "
  }
  member_method {
    name: "l2_loss"
    argspec: "args=[\'t\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -1996,6 +1996,10 @@ tf_module {
    name: "IsVariableInitialized"
    argspec: "args=[\'ref\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "IsotonicRegression"
    argspec: "args=[\'input\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
  }
  member_method {
    name: "Iterator"
    argspec: "args=[\'shared_name\', \'container\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "