Differentiable isotonic regression in TensorFlow.

PiperOrigin-RevId: 327257991
Change-Id: Ided6e7f1d295bcd74c87e3a5601dd2bebde173a0
This commit is contained in:
Josip Djolonga 2020-08-18 10:37:14 -07:00 committed by TensorFlower Gardener
parent 58747588d2
commit 2c5e31114c
14 changed files with 690 additions and 33 deletions

View File

@ -1593,6 +1593,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric,
color palette of the frame. This has been fixed now
* image.resize now considers proper pixel centers and has new kernels
(incl. anti-aliasing).
* Added an isotonic regression solver (tf.nn.isotonic_regression).
* Performance
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
dispatches the best kernel implementation based on CPU vector

View File

@ -1011,6 +1011,7 @@ cc_library(
"//tensorflow/core/kernels:grappler",
"//tensorflow/core/kernels:histogram_op",
"//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:isotonic_regression_op",
"//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
"//tensorflow/core/kernels:manip",

View File

@ -0,0 +1,24 @@
op {
graph_op_name: "IsotonicRegression"
visibility: HIDDEN
in_arg {
name: "input"
description: <<END
A (batch_size, dim)-tensor holding a batch of inputs.
END
}
out_arg {
name: "output"
description: <<END
A (batch_size, dim)-tensor holding the per-batch element solutions.
END
}
out_arg {
name: "segments"
description: <<END
An int32 (batch_size, dim)-tensor with the segments.
END
}
attr { name: "output_dtype" description: "Dtype of output." }
summary: "Solves a batch of isotonic regression problems."
}

View File

@ -7533,3 +7533,32 @@ tf_kernel_library(
name = "einsum_op",
deps = ["//tensorflow/core/kernels/linalg:einsum_op"],
)
tf_kernel_library(
name = "isotonic_regression_op",
srcs = [
"isotonic_regression_op.cc",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "isotonic_regression_op_test",
size = "small",
srcs = ["isotonic_regression_op_test.cc"],
deps = [
":isotonic_regression_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

View File

@ -0,0 +1,226 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/platform/threadpool.h"
namespace {
using tensorflow::int32;
using tensorflow::int64;
// The # of ops estimated for the isotonic regression solver is the size of the
// array multiplied by this constant. This is used by the thread pool executor
// when deciding how many threads to use.
constexpr int kCostMultiplier = 100;
// In separable chain-constrained problems, i.e., those of the form
//
// min_{y_1 >= y_2 >= ... >= y_n} \sum_{i=1}^n h_i(y_i)
//
// for any set of convex functions h_i, of particular importance are contiguous
// segments of coordinates, which this class represents. The interval is assumed
// to be half-closed and equal to [col_start(), col_limit()).
class Segment {
public:
// Creates the singleton segment [col_index, col_index+1).
explicit Segment(int col_index)
: col_start_(col_index), col_limit_(col_index + 1) {}
// Returns the number of points in the segment.
int num_points() const { return col_limit_ - col_start_; }
// Merges another segment into this one.
void merge_with(const Segment& other) {
col_start_ = std::min(col_start_, other.col_start());
col_limit_ = std::max(col_limit_, other.col_limit());
}
int col_start() const { return col_start_; }
int col_limit() const { return col_limit_; }
private:
int col_start_;
int col_limit_;
};
// If we can solve for each segment {j, j+1, ..., j+m} the interval problem
//
// argmin_y \sum_{i=j}^{j+m} h_i(y),
//
// we can use such an oracle to solve the general problem. The following class
// implements such an oracle for the case when h_i is the squared (l2) loss,
// or formally h_i(y) = (y - x_i)^2, where x_i is the i-th input.
//
// TODO(josipd): We know how and can extend this to other functions if needed.
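// Note that for the squared loss the interval minimizer has a closed form:
// setting the derivative 2 * \sum_i (y - x_i) to zero gives y equal to the
// arithmetic mean of the x_i in the segment, which merge_with() below keeps
// cached in minimum_.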
template <typename T>
class L2PavaSegment : public Segment {
public:
L2PavaSegment(T y, int col_index)
: Segment(col_index), y_sum_(y), minimum_(y) {}
void merge_with(const L2PavaSegment& other) {
Segment::merge_with(other);
y_sum_ += other.y_sum_;
minimum_ = y_sum_ / static_cast<T>(num_points());
}
T minimum() const { return minimum_; }
private:
T y_sum_; // The sum of the inputs within the segment.
T minimum_; // The minimum, cached to avoid expensive divisions.
};
// Solve one of the problems in the batch (the row_index'th one) using the
// pool-adjacent violators algorithm (PAVA).
//
// The PAVA algorithm goes back to
//
// Nonmetric Multidimensional Scaling: A numerical method
// Kruskal, J. B. (1964), Psychometrika (1964)
//
// For a more recent analysis, please refer to
//
// Active set algorithms for isotonic regression; a unifying framework
// Best, Michael J., and Nilotpal Chakravarti
// Mathematical Programming 47.1-3 (1990)
//
// Intuitively, the algorithm splits the inputs into blocks (starting from
// singleton ones), and then whenever there are two consecutive blocks whose
// minima violate the inequality constraint, they are merged. The solution is
// then block-wise constant, each block equal to the corresponding minimum.
//
// The tensors should be two dimensional, and the segment objects should
// support the minimum() and merge_with() methods.
template <typename SegmentType, typename FloatTensor, typename IntTensor>
void solve_pava(const std::function<SegmentType(int, int)>& make_segment,
FloatTensor* solution, IntTensor* segments, int row_index) {
const size_t n = solution->dimensions()[1];
std::vector<SegmentType> pools;
pools.reserve(n);
for (size_t col_index = 0; col_index < n; ++col_index) {
pools.push_back(make_segment(row_index, col_index));
// While the last two pools violate the decreasing constraint, merge them.
while (pools.size() > 1 &&
pools.rbegin()->minimum() > (pools.rbegin() + 1)->minimum()) {
(pools.rbegin() + 1)->merge_with(*pools.rbegin());
pools.pop_back();
}
}
int segment_id = 0;
for (const auto& pool : pools) {
const auto pool_minimum = pool.minimum();
// The matrices are row major, so we can scan the memory linearly.
auto* solution_ptr = &(*solution)(row_index, pool.col_start());
auto* segments_ptr = &(*segments)(row_index, pool.col_start());
for (int i = pool.col_start(); i < pool.col_limit(); ++i) {
*solution_ptr++ = pool_minimum;
*segments_ptr++ = segment_id;
}
++segment_id;
}
}
// Solve a batch of problems using the pool-adjacent violators algorithm.
// The problems are solved in parallel using tensorflow's thread pool.
template <typename SegmentType, typename FloatTensor, typename IntTensor>
void solve_pava_batch(const std::function<SegmentType(int, int)>& make_segment,
FloatTensor* solution, IntTensor* segments,
tensorflow::OpKernelContext* context) {
const int batch_size = solution->dimensions()[0];
const int problem_size = solution->dimensions()[1];
auto thread_pool =
context->device()->tensorflow_cpu_worker_threads()->workers;
thread_pool->ParallelFor(
batch_size, kCostMultiplier * problem_size,
[&make_segment, &solution, &segments](int64 row_start, int64 row_limit) {
// Casting to int is safe, as we do boundary checks in `Compute`.
for (int row_index = static_cast<int>(row_start);
row_index < static_cast<int>(row_limit); ++row_index) {
solve_pava(make_segment, solution, segments, row_index);
}
});
}
} // namespace
template <typename Tin, typename Tout>
class IsotonicRegressionOp : public tensorflow::OpKernel {
public:
explicit IsotonicRegressionOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {}
void Compute(tensorflow::OpKernelContext* context) override {
// Grab the input tensor.
const tensorflow::Tensor& input_tensor = context->input(0);
const auto input = input_tensor.flat_inner_dims<Tin, 2>();
int int_max = std::numeric_limits<int32>::max();
OP_REQUIRES(context,
tensorflow::FastBoundsCheck(input.dimensions()[0], int_max) &&
tensorflow::FastBoundsCheck(input.dimensions()[1], int_max),
tensorflow::errors::InvalidArgument("Tensor too large"));
// Create the output tensor holding the minimizers.
const auto shape = input_tensor.shape();
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, shape, &output_tensor));
auto output = output_tensor->flat_inner_dims<Tout, 2>();
// Create the output tensor holding the segment memberships.
tensorflow::Tensor* segments_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(1, shape, &segments_tensor));
auto segments = segments_tensor->flat_inner_dims<int>();
auto make_l2_segment = [&input](int row_index, int col_index) {
return L2PavaSegment<Tout>(input(row_index, col_index), col_index);
};
solve_pava_batch<L2PavaSegment<Tout>>(make_l2_segment, &output, &segments,
context);
}
};
#define REGISTER_CPU_KERNEL(Tin, Tout) \
REGISTER_KERNEL_BUILDER(Name("IsotonicRegression") \
.Device(tensorflow::DEVICE_CPU) \
.TypeConstraint<Tin>("T") \
.TypeConstraint<Tout>("output_dtype"), \
IsotonicRegressionOp<Tin, Tout>);
// Float types have the same input and output types.
#define REGISTER_CPU_SAME_KERNEL(T) REGISTER_CPU_KERNEL(T, T)
TF_CALL_FLOAT_TYPES(REGISTER_CPU_SAME_KERNEL);
// 8 and 16 bit integers get converted to 32 bit floats.
#define REGISTER_CPU_KERNEL_FLOAT(Tin) REGISTER_CPU_KERNEL(Tin, float)
TF_CALL_int16(REGISTER_CPU_KERNEL_FLOAT);
TF_CALL_int8(REGISTER_CPU_KERNEL_FLOAT);
// 32 and 64 bit integers get converted to 64 bit floats.
#define REGISTER_CPU_KERNEL_DOUBLE(Tin) REGISTER_CPU_KERNEL(Tin, double)
TF_CALL_int64(REGISTER_CPU_KERNEL_DOUBLE);
TF_CALL_int32(REGISTER_CPU_KERNEL_DOUBLE);
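
For orientation, the merging loop in solve_pava above can be sketched in plain NumPy as follows. This is an illustrative, hypothetical reference (the name pava_decreasing and the list-based pool representation are not part of the kernel), solving a single row under the decreasing constraint:

import numpy as np

def pava_decreasing(x):
  # Solves argmin_{y_1 >= y_2 >= ... >= y_n} sum_i (x_i - y_i)^2 for one row.
  # Each pool is [sum, count, minimum]; the last two pools are merged while
  # they violate the decreasing constraint, as in solve_pava.
  pools = []
  for value in x:
    pools.append([float(value), 1, float(value)])
    while len(pools) > 1 and pools[-1][2] > pools[-2][2]:
      merged = pools.pop()
      pools[-1][0] += merged[0]
      pools[-1][1] += merged[1]
      pools[-1][2] = pools[-1][0] / pools[-1][1]
  solution, segments = [], []
  for segment_id, (_, count, minimum) in enumerate(pools):
    solution.extend([minimum] * count)
    segments.extend([segment_id] * count)
  return np.array(solution), np.array(segments)

For example, pava_decreasing([1.0, 2.0, 3.0]) returns ([2.0, 2.0, 2.0], [0, 0, 0]), matching the per-row behaviour exercised by the IncreasingInput test below.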

View File

@ -0,0 +1,139 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
class IsotonicRegressionOpTest : public OpsTestBase {
public:
void MakeOp(DataType type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "IsotonicRegression")
.Input(FakeInput(type))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
class BenchmarkHelper : public IsotonicRegressionOpTest {
public:
void TestBody() override {}
void AddIncreasingInput(int batch_size, int input_size) {
std::vector<float> input_data(input_size * batch_size, 0);
for (int i = 0; i < input_data.size(); i++) {
input_data[i] = i;
}
AddInputFromArray<float>(TensorShape({batch_size, input_size}), input_data);
}
};
TEST_F(IsotonicRegressionOpTest, Constant) {
MakeOp(DT_FLOAT_REF);
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
test::ExpectClose(expected, *GetOutput((0)));
}
TEST_F(IsotonicRegressionOpTest, IncreasingInput) {
MakeOp(DT_FLOAT_REF);
AddInputFromArray<float>(TensorShape({5, 3}),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected,
{2, 2, 2, 5, 5, 5, 8, 8, 8, 11, 11, 11, 14, 14, 14});
test::ExpectClose(expected, *GetOutput((0)));
Tensor expected_ord(allocator(), DT_INT32, TensorShape({5, 3}));
test::FillValues<int>(&expected_ord,
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
test::ExpectTensorEqual<int>(expected_ord, *GetOutput((1)));
}
TEST_F(IsotonicRegressionOpTest, Decreasing) {
MakeOp(DT_FLOAT_REF);
AddInputFromArray<float>(TensorShape({5, 3}),
{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected,
{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1});
test::ExpectClose(expected, *GetOutput((0)));
Tensor expected_ord(allocator(), DT_INT32, TensorShape({5, 3}));
test::FillValues<int>(&expected_ord,
{0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2});
test::ExpectTensorEqual<int>(expected_ord, *GetOutput((1)));
}
static void BM_IncreasingSequence(benchmark::State& state) {
int batch_size = state.range(0);
int input_size = state.range(1);
for (auto _ : state) {
state.PauseTiming();
BenchmarkHelper helper;
helper.MakeOp(DT_FLOAT_REF);
helper.AddIncreasingInput(batch_size, input_size);
state.ResumeTiming();
Status stat = helper.RunOpKernel();
}
state.SetItemsProcessed(
static_cast<int64>(batch_size * input_size * state.iterations()));
}
BENCHMARK(BM_IncreasingSequence)
->Args({1, 1 << 0})
->Args({1, 1 << 5})
->Args({1, 1 << 8})
->Args({1, 1 << 10})
->Args({1, 1 << 20})
->Args({1, 2 << 20})
->Args({1 << 0, 1 << 10})
->Args({1 << 1, 1 << 10})
->Args({1 << 4, 1 << 10})
->Args({1 << 6, 1 << 10})
->Args({1 << 9, 1 << 10})
->Args({1 << 10, 1 << 10});
} // namespace
} // namespace tensorflow

View File

@ -3406,4 +3406,16 @@ REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
.Attr("padding_list: list(int) = []")
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
REGISTER_OP("IsotonicRegression")
.Input("input: T")
.Output("output: output_dtype")
.Output("segments: int32")
.Attr("T: realnumbertype")
.Attr("output_dtype: {half, bfloat16, float, double} = DT_FLOAT")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* context) {
context->set_output(0, context->input(0));
context->set_output(1, context->input(0));
return tensorflow::Status::OK();
});
} // namespace tensorflow

View File

@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 350> a = {{
static std::array<OpIndexInfo, 351> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
@ -160,6 +160,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"Inv"},
{"Invert"},
{"InvertPermutation"},
{"IsotonicRegression"},
{"LMDBReader"},
{"LeakyReluGrad", 1, {0}},
{"LeftShift"},
@ -413,7 +414,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 466> a = {{
static std::array<OpIndexInfo, 467> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@ -577,6 +578,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"InvGrad"},
{"Invert"},
{"InvertPermutation"},
{"IsotonicRegression", 1, {0}},
{"L2Loss"},
{"LMDBReader"},
{"LeakyRelu"},

View File

@ -1142,3 +1142,48 @@ def _NthElementGrad(op, grad):
num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)
return [math_ops.divide(indicators, num_selected) * grad, None]
def _MeanAggregator(inputs, segments):
"""Replaces each segment with its mean along the last axis.
Specifically, each value in the `inputs` tensor gets replaced by the mean
value computed from the values that belong to the same segment.
Args:
inputs: A 2-tensor. Aggregation is done over dimension 1.
segments: A 2-tensor, same shape as `inputs`.
Returns:
The result, same shape and type as `inputs`.
"""
result = []
for inputs_i, segments_i in zip(
array_ops.split(inputs, inputs.shape[0]),
array_ops.split(segments, segments.shape[0])):
# Note that we do not use tf.math.segment_mean, as it has no TPU support.
means_i = math_ops.unsorted_segment_mean(
inputs_i, segments_i, num_segments=math_ops.reduce_max(segments_i) + 1)
result.append(
array_ops.reshape(array_ops.gather(means_i, segments_i), [-1]))
return array_ops.stack(result, axis=0)
# We register the gradient for this op so that TensorFlow knows how to
# differentiate it.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
"""Gradient for the isotonic regression function.
Args:
op: The IsotonicRegression tensorflow op.
grad_output: Tensor of incoming gradients with respect to the output.
grad_segments: Tensor of incoming gradients with respect to the segments.
Returns:
A tensor, same size as `grad_output` with the gradient with respect to
the input.
"""
del grad_segments # Discrete, non-differentiable.
segments = op.outputs[1]
return _MeanAggregator(grad_output, segments)
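
The registration above follows from the chain rule: within a segment, every output coordinate equals the mean of the inputs in that segment, so each partial derivative is 1/(segment size) and the incoming gradient is simply averaged over the segment. A small NumPy analogue of _MeanAggregator (illustrative only; mean_aggregator_np is a hypothetical name) makes this concrete:

import numpy as np

def mean_aggregator_np(grad_output, segments):
  # Replace each segment of the incoming gradient with its mean along the
  # last axis, mirroring _MeanAggregator above.
  result = np.empty_like(grad_output)
  for row in range(grad_output.shape[0]):
    for segment_id in np.unique(segments[row]):
      mask = segments[row] == segment_id
      result[row, mask] = grad_output[row, mask].mean()
  return result

For example, np.array([[3.0, 1.0, 2.0]]) with segments np.array([[0, 1, 1]]) yields [[3.0, 1.5, 1.5]].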

View File

@ -3566,46 +3566,49 @@ def _flatten_outer_dims(logits):
return output
def _softmax(logits, compute_op, dim=-1, name=None):
"""Helper function for softmax and log_softmax.
def _wrap_2d_function(inputs, compute_op, dim=-1, name=None):
"""Helper function for ops that accept and return 2d inputs of same shape.
It reshapes and transposes the input logits into a 2-D Tensor and then invokes
the tf.nn._softmax or tf.nn._log_softmax function. The output would be
transposed and reshaped back.
It reshapes and transposes the inputs into a 2-D Tensor and then invokes
the given function. The output is then transposed and reshaped back.
If the given function returns a tuple of tensors, each of them is
transposed and reshaped in the same way.
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
inputs: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
compute_op: Either gen_nn_ops.softmax or gen_nn_ops.log_softmax
compute_op: The function to wrap. Must accept the input tensor as its first
argument, and a second keyword argument `name`.
dim: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
A `Tensor`. Has the same shape as inputs. If compute_op returns multiple
tensors, each of them has the same shape as the input.
Raises:
InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
dimension of `logits`.
InvalidArgumentError: if `inputs` is empty or `dim` is beyond the last
dimension of `inputs`.
"""
def _swap_axis(logits, dim_index, last_index, name=None):
def _swap_axis(input_tensor, dim_index, last_index, name=None):
"""Swaps logits's dim_index and last_index."""
return array_ops.transpose(
logits,
input_tensor,
array_ops.concat([
math_ops.range(dim_index), [last_index],
math_ops.range(dim_index + 1, last_index), [dim_index]
], 0),
name=name)
logits = ops.convert_to_tensor(logits)
inputs = ops.convert_to_tensor(inputs)
# We need its original shape for shape inference.
shape = logits.get_shape()
shape = inputs.get_shape()
is_last_dim = (dim == -1) or (dim == shape.ndims - 1)
if is_last_dim:
return compute_op(logits, name=name)
return compute_op(inputs, name=name)
dim_val = dim
if isinstance(dim, ops.Tensor):
@ -3618,10 +3621,10 @@ def _softmax(logits, compute_op, dim=-1, name=None):
shape.ndims))
# If dim is not the last dimension, we have to do a transpose so that we can
# still perform softmax on its last dimension.
# still perform the op on its last dimension.
# In case dim is negative (and is not last dimension -1), add shape.ndims
ndims = array_ops.rank(logits)
ndims = array_ops.rank(inputs)
if not isinstance(dim, ops.Tensor):
if dim < 0:
dim += ndims
@ -3629,20 +3632,24 @@ def _softmax(logits, compute_op, dim=-1, name=None):
dim = array_ops.where(math_ops.less(dim, 0), dim + ndims, dim)
# Swap logits' dimension of dim and its last dimension.
input_rank = array_ops.rank(logits)
input_rank = array_ops.rank(inputs)
dim_axis = dim % shape.ndims
logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1))
inputs = _swap_axis(inputs, dim_axis, math_ops.subtract(input_rank, 1))
# Do the actual softmax on its last dimension.
output = compute_op(logits)
# Do the actual call on its last dimension.
def fix_output(output):
output = _swap_axis(
output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
output = _swap_axis(
output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
# Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
return output
# Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
return output
outputs = compute_op(inputs)
if isinstance(outputs, tuple):
return tuple(fix_output(output) for output in outputs)
else:
return fix_output(outputs)
@tf_export(v1=["nn.softmax", "math.softmax"])
@ -3687,7 +3694,7 @@ def softmax(logits, axis=None, name=None, dim=None):
axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
if axis is None:
axis = -1
return _softmax(logits, gen_nn_ops.softmax, axis, name)
return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)
@tf_export("nn.softmax", "math.softmax", v1=[])
@ -3715,7 +3722,7 @@ def softmax_v2(logits, axis=None, name=None):
"""
if axis is None:
axis = -1
return _softmax(logits, gen_nn_ops.softmax, axis, name)
return _wrap_2d_function(logits, gen_nn_ops.softmax, axis, name)
@tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@ -3746,7 +3753,7 @@ def log_softmax(logits, axis=None, name=None, dim=None):
axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
if axis is None:
axis = -1
return _softmax(logits, gen_nn_ops.log_softmax, axis, name)
return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)
@tf_export("nn.log_softmax", "math.log_softmax", v1=[])
@ -3774,7 +3781,7 @@ def log_softmax_v2(logits, axis=None, name=None):
"""
if axis is None:
axis = -1
return _softmax(logits, gen_nn_ops.log_softmax, axis, name)
return _wrap_2d_function(logits, gen_nn_ops.log_softmax, axis, name)
def _ensure_xent_args(name, sentinel, labels, logits):
@ -5674,3 +5681,78 @@ tf_export(v1=["nn.quantized_relu_x"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x))
tf_export(v1=["nn.quantized_max_pool"])(
dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool))
@tf_export("nn.isotonic_regression", v1=[])
@dispatch.add_dispatch_support
def isotonic_regression(inputs, decreasing=True, axis=-1):
r"""Solves isotonic regression problems along the given axis.
For each vector x, the problem solved is
$$\argmin_{y_1 >= y_2 >= ... >= y_n} \sum_i (x_i - y_i)^2.$$
As the solution is piecewise constant, a second tensor is returned that
encodes the segments. The problems are solved over the given axis.
Consider the following example, where we solve a batch of two problems. The
first input is [3, 1, 2], while the second is [1, 3, 4] (as the axis is 1).
>>> x = tf.constant([[3, 1, 2], [1, 3, 4]], dtype=tf.float32)
>>> y, segments = tf.nn.isotonic_regression(x, axis=1)
>>> y # The solution.
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[3. , 1.5 , 1.5 ],
[2.6666667, 2.6666667, 2.6666667]], dtype=float32)>
Note that the first solution has two blocks, [3] and [1.5, 1.5]. The second
solution is constant, and thus has a single segment. These segments are
exactly what the second returned tensor encodes:
>>> segments
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[0, 1, 1],
[0, 0, 0]], dtype=int32)>
Args:
inputs: A tensor holding the inputs.
decreasing: If set to False, the inequalities in the optimization
constraints are flipped, and a non-decreasing solution is returned instead.
axis: The axis along which the problems should be solved.
Returns:
output: The solutions, same shape as the input; integer inputs are promoted
to a floating point dtype.
segments: An int32 tensor, same shape as the input, encoding the segments.
Positions that share a segment id belong to the same constant block of the
solution. The ids start at zero and are monotonically increasing within
each solved problem.
"""
type_promotions = {
# Float types get mapped to themselves, int8/16 to float32, rest to double
dtypes.float32:
dtypes.float32,
dtypes.half:
dtypes.half,
dtypes.bfloat16:
dtypes.bfloat16,
dtypes.int8:
dtypes.float32,
dtypes.int16:
dtypes.float32,
}
inputs = ops.convert_to_tensor(inputs)
try:
output_dtype = type_promotions[inputs.dtype]
except KeyError:
output_dtype = dtypes.float64
def compute_on_matrix(matrix, name=None):
iso_fn = functools.partial(
gen_nn_ops.isotonic_regression, output_dtype=output_dtype, name=name)
if decreasing:
return iso_fn(matrix)
else:
output, segments = iso_fn(-matrix)
return -output, segments
return _wrap_2d_function(inputs, compute_on_matrix, axis)
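
A minimal usage sketch, assuming the op above is available as tf.nn.isotonic_regression in an eager TensorFlow 2.x build; it simply checks that each returned row is monotone along the solved axis:

import tensorflow as tf

x = tf.random.normal([4, 10])
# decreasing=True and axis=-1 are the defaults documented above.
y, segments = tf.nn.isotonic_regression(x)
# Every row of the solution is non-increasing along the last axis.
assert bool(tf.reduce_all(y[:, :-1] >= y[:, 1:]))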

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_impl
@ -1701,5 +1702,88 @@ class RaggedEmbeddingTest(test_lib.TestCase):
actual)
class IsotonicTest(parameterized.TestCase, test_lib.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_increasing_and_decreasing(self):
x = constant_op.constant([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
dtype=dtypes.float64)
y, segments = nn_ops.isotonic_regression(x, decreasing=False)
self.assertAllClose(y, x)
self.assertAllClose(segments, [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
y, segments = nn_ops.isotonic_regression(x, decreasing=True)
self.assertAllClose(
y,
[
[2, 2, 2, 2, 2], # Average of the inputs.
[7, 7, 7, 7, 7]
])
self.assertAllClose(segments, array_ops.zeros((2, 5)))
y, segments = nn_ops.isotonic_regression(-x, decreasing=True)
self.assertAllClose(segments, [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
self.assertAllClose(y, -x)
y, segments = nn_ops.isotonic_regression(-x, decreasing=False)
self.assertAllClose(
-y,
[
[2, 2, 2, 2, 2], # Average of the inputs.
[7, 7, 7, 7, 7]
])
self.assertAllClose(segments, array_ops.zeros((2, 5)))
@test_util.run_in_graph_and_eager_modes
def test_different_axis(self):
x = constant_op.constant([[0, 6, 2, 8, 4], [5, 1, 7, 3, 9]],
dtype=dtypes.float64)
y, segments = nn_ops.isotonic_regression(x, decreasing=True, axis=0)
self.assertAllClose(
y,
[
[2.5, 6, 4.5, 8, 6.5], # Either identity or average.
[2.5, 1, 4.5, 3, 6.5]
])
self.assertAllClose(segments, [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0]])
@test_util.run_v2_only
def testGradientV2(self, dtype=np.float64, batch_size=30, dimensions=50):
@def_function.function
def ComputeIsotonicFn(x):
y, _ = nn_ops.isotonic_regression(x) # No gradient wrt segments.
return y
np.random.seed(0)
x_init = np.random.randn(batch_size, dimensions).astype(dtype)
grad_theoretical, grad_numerical = gradient_checker_v2.compute_gradient(
ComputeIsotonicFn, [x_init], delta=1e-5)
self.assertAllClose(grad_theoretical, grad_numerical)
@test_util.run_v1_only("compute_gradient_error is v1 only")
def testGradientV1(self, dtype=np.float64, batch_size=30, dimensions=50):
np.random.seed(0)
x_init = np.random.randn(batch_size, dimensions).astype(dtype)
with self.cached_session():
x = array_ops.placeholder(dtype, (batch_size, dimensions))
y, _ = nn_ops.isotonic_regression(x) # Segments have no gradient.
max_error = gradient_checker.compute_gradient_error(
x, (batch_size, dimensions), y, (batch_size, dimensions), x_init)
self.assertAllClose(max_error, 0.)
@parameterized.parameters([[dtypes.half, dtypes.half],
[dtypes.bfloat16, dtypes.bfloat16],
[dtypes.float32, dtypes.float32],
[dtypes.float64, dtypes.float64],
[dtypes.int32, dtypes.float64],
[dtypes.int16, dtypes.float32]])
def testTypePromotion(self, dtype_in, expected_dtype_out):
x = constant_op.constant([[0, 6, 2, 8, 4], [5, 1, 7, 3, 9]], dtype=dtype_in)
y, segments = nn_ops.isotonic_regression(x)
self.assertEqual(y.dtype, expected_dtype_out)
self.assertEqual(segments.dtype, dtypes.int32)
if __name__ == "__main__":
test_lib.main()

View File

@ -1996,6 +1996,10 @@ tf_module {
name: "IsVariableInitialized"
argspec: "args=[\'ref\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "IsotonicRegression"
argspec: "args=[\'input\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "Iterator"
argspec: "args=[\'shared_name\', \'container\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -176,6 +176,10 @@ tf_module {
name: "in_top_k"
argspec: "args=[\'targets\', \'predictions\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "isotonic_regression"
argspec: "args=[\'inputs\', \'decreasing\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'-1\'], "
}
member_method {
name: "l2_loss"
argspec: "args=[\'t\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -1996,6 +1996,10 @@ tf_module {
name: "IsVariableInitialized"
argspec: "args=[\'ref\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "IsotonicRegression"
argspec: "args=[\'input\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "Iterator"
argspec: "args=[\'shared_name\', \'container\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "