Add Matrix{Diag,SetDiag,DiagPart}V3 ops with alignment options.

V2 ops always align the diagonals to the left (LEFT_LEFT) in the compact format. V3 ops support 4 alignments: RIGHT_LEFT, LEFT_RIGHT, LEFT_LEFT, and RIGHT_RIGHT. We would like to use RIGHT_LEFT as the default alignment. This contradicts with v2's behavior so we need new a version.

V2 has never been exposed to the public APIs. We will skip V2 and go from V1 to V3 directly. V3 features are currently under forward compatibility guards and will be enabled automatically in ~3 weeks from now.

This commit contains
- V3 API definitions.
- Modifications to C++ Matrix{Diag,SetDiag,DiagPart}Op kernels (CPU, GPU, XLA) and shape inference functions to support v3.
- Additional tests and gradient implementations in Python for v3.
- Pfor and TFLite TOCO converters for v3.
- The TFLite MLIR converter for MatrixDiagV3 is intentionally left out because of an MLIR test infrastructure issue and will be added in a separate commit.

Notes:
- Python changes cannot be in a separate follow-up commit because all kernel tests are in Python. (No C++ tests.)
- All three ops have to be in the same commit because their gradients call each other.
PiperOrigin-RevId: 280527550
Change-Id: I88e91abab5c4b50419204807ede4fa60657f048a
This commit is contained in:
Penporn Koanantakool 2019-11-14 15:29:16 -08:00 committed by TensorFlower Gardener
parent 39e4f9be03
commit 90e67385a4
41 changed files with 2273 additions and 992 deletions

View File

@ -26,9 +26,82 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
# LINT.IfChange
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
# LINT.ThenChange(
# //tensorflow/python/kernel_tests/diag_op_test.py,
# //tensorflow/python/ops/array_ops.py,
# //tensorflow/python/ops/parallel_for/array_test.py
# )
default_v2_alignment = "LEFT_LEFT"
alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
def zip_to_first_list_length(a, b):
if len(b) > len(a):
return zip(a, b[:len(a)])
return zip(a, b + [None] * (len(a) - len(b)))
# Routines to convert test cases to have diagonals in a specified alignment.
# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py
def repack_diagonals(packed_diagonals,
diag_index,
num_rows,
num_cols,
align=None):
# The original test cases are LEFT_LEFT aligned.
if align == default_v2_alignment or align is None:
return packed_diagonals
align = align.split("_")
d_lower, d_upper = diag_index
batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1)
max_diag_len = packed_diagonals.shape[-1]
index = (slice(None),) * batch_dims
repacked_diagonals = np.zeros_like(packed_diagonals)
# Aligns each diagonal row-by-row.
for diag_index in range(d_lower, d_upper + 1):
diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index))
row_index = d_upper - diag_index
padding_len = max_diag_len - diag_len
left_align = (diag_index >= 0 and
align[0] == "LEFT") or (diag_index <= 0 and
align[1] == "LEFT")
# Prepares index tuples.
extra_dim = tuple() if d_lower == d_upper else (row_index,)
packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
repacked_last_dim = (slice(None),) if left_align else (slice(
padding_len, max_diag_len, 1),)
packed_index = index + extra_dim + packed_last_dim
repacked_index = index + extra_dim + repacked_last_dim
# Repacks the diagonal.
repacked_diagonals[repacked_index] = packed_diagonals[packed_index]
return repacked_diagonals
def repack_diagonals_in_tests(tests, align=None):
# The original test cases are LEFT_LEFT aligned.
if align == default_v2_alignment or align is None:
return tests
new_tests = dict()
# Loops through each case.
for diag_index, (packed_diagonals, padded_diagonals) in tests.items():
num_rows, num_cols = padded_diagonals.shape[-2:]
repacked_diagonals = repack_diagonals(
packed_diagonals, diag_index, num_rows, num_cols, align=align)
new_tests[diag_index] = (repacked_diagonals, padded_diagonals)
return new_tests
# Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py
def square_cases():
def square_cases(align=None):
# pyformat: disable
mat = np.array([[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 1],
@ -103,10 +176,10 @@ def square_cases():
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]))
# pyformat: enable
return (mat, tests)
return (mat, repack_diagonals_in_tests(tests, align))
def tall_cases():
def tall_cases(align=None):
# pyformat: disable
mat = np.array([[[1, 2, 3],
[4, 5, 6],
@ -191,10 +264,10 @@ def tall_cases():
[0, 0, 0],
[0, 0, 0]]]))
# pyformat: enable
return (mat, tests)
return (mat, repack_diagonals_in_tests(tests, align))
def fat_cases():
def fat_cases(align=None):
# pyformat: disable
mat = np.array([[[1, 2, 3, 4],
[5, 6, 7, 8],
@ -259,7 +332,11 @@ def fat_cases():
[0, 9, 1, 2],
[0, 0, 5, 6]]]))
# pyformat: enable
return (mat, tests)
return (mat, repack_diagonals_in_tests(tests, align))
def all_tests(align=None):
return [square_cases(align), tall_cases(align), fat_cases(align)]
class MatrixDiagTest(xla_test.XLATestCase):
@ -327,39 +404,31 @@ class MatrixDiagTest(xla_test.XLATestCase):
# From here onwards are v2-only tests.
def testSquare(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases()]:
for diag_index, (vecs, solution) in tests.items():
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs[0],
"k": diag_index
}, solution[0])
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for align in alignment_list:
for _, tests in [square_cases(align)]:
for diag_index, (vecs, solution) in tests.items():
params = {"diagonal": vecs[0], "k": diag_index, "align": align}
self._assertOpOutputMatchesExpected(params, solution[0])
def testSquareBatch(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases()]:
for diag_index, (vecs, solution) in tests.items():
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index
}, solution)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for align in alignment_list:
for _, tests in [square_cases(align)]:
for diag_index, (vecs, solution) in tests.items():
params = {"diagonal": vecs, "k": diag_index, "align": align}
self._assertOpOutputMatchesExpected(params, solution)
def testRectangularBatch(self):
# LINT.IfChange
if not compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
return
# Stores expected num_rows and num_cols (when the other is given).
# expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
test_list = list()
# Do not align the test cases here. Re-alignment needs to happen after the
# solution shape is updated.
# Square cases:
expected = {
(-1, -1): (5, 4),
@ -389,43 +458,65 @@ class MatrixDiagTest(xla_test.XLATestCase):
test_list.append((expected, fat_cases()))
# Giving both num_rows and num_cols
for _, tests in [tall_cases(), fat_cases()]:
align = alignment_list[0]
for _, tests in [tall_cases(align), fat_cases(align)]:
for diag_index, (vecs, solution) in tests.items():
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"k": diag_index,
"num_rows": solution.shape[-2],
"num_cols": solution.shape[-1]
"num_cols": solution.shape[-1],
"align": align
}, solution)
# We go through each alignment in a round-robin manner.
align_index = 0
# Giving just num_rows or num_cols.
for expected, (_, tests) in test_list:
for diag_index, (new_num_rows, new_num_cols) in expected.items():
align = alignment_list[align_index]
align_index = (align_index + 1) % len(alignment_list)
vecs, solution = tests[diag_index]
solution_given_num_rows = solution.take(
indices=range(new_num_cols), axis=-1)
# Repacks the diagonal input according to the new solution shape.
vecs_given_num_rows = repack_diagonals(
vecs,
diag_index,
solution_given_num_rows.shape[-2],
new_num_cols,
align=align)
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"diagonal": vecs_given_num_rows,
"k": diag_index,
"num_rows": solution_given_num_rows.shape[-2]
"num_rows": solution_given_num_rows.shape[-2],
"align": align
}, solution_given_num_rows)
solution_given_num_cols = solution.take(
indices=range(new_num_rows), axis=-2)
# Repacks the diagonal input according to the new solution shape.
vecs_given_num_cols = repack_diagonals(
vecs,
diag_index,
new_num_rows,
solution_given_num_cols.shape[-1],
align=align)
self._assertOpOutputMatchesExpected(
{
"diagonal": vecs,
"diagonal": vecs_given_num_cols,
"k": diag_index,
"num_cols": solution_given_num_cols.shape[-1]
"num_cols": solution_given_num_cols.shape[-1],
"align": align
}, solution_given_num_cols)
def testPadding(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for padding_value in [555, -11]:
for _, tests in [square_cases(), tall_cases(), fat_cases()]:
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for padding_value, align in zip_to_first_list_length([555, -11],
alignment_list):
for _, tests in all_tests(align):
for diag_index, (vecs, solution) in tests.items():
mask = (solution == 0)
solution = solution + (mask * padding_value)
@ -435,7 +526,8 @@ class MatrixDiagTest(xla_test.XLATestCase):
"k": diag_index,
"num_rows": solution.shape[-2],
"num_cols": solution.shape[-1],
"padding_value": padding_value
"padding_value": padding_value,
"align": align
}, solution)
@ -542,36 +634,36 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
# From here onwards are v2-only tests.
def testSingleMatrix(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (vecs, banded_mat) in tests.items():
mask = (banded_mat[0] == 0)
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat[0]
self._assertOpOutputMatchesExpected(
{
"input": input_mat,
"diagonal": vecs[0],
"k": diag_index
}, solution)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for align in alignment_list:
for _, tests in all_tests(align):
for diag_index, (vecs, banded_mat) in tests.items():
mask = (banded_mat[0] == 0)
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat[0]
self._assertOpOutputMatchesExpected(
{
"input": input_mat,
"diagonal": vecs[0],
"k": diag_index,
"align": align
}, solution)
def testBatch(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for _, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (vecs, banded_mat) in tests.items():
mask = (banded_mat == 0)
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat
self._assertOpOutputMatchesExpected(
{
"input": input_mat,
"diagonal": vecs,
"k": diag_index
}, solution)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for align in alignment_list:
for _, tests in all_tests(align):
for diag_index, (vecs, banded_mat) in tests.items():
mask = (banded_mat == 0)
input_mat = np.random.randint(10, size=mask.shape)
solution = input_mat * mask + banded_mat
self._assertOpOutputMatchesExpected(
{
"input": input_mat,
"diagonal": vecs,
"k": diag_index,
"align": align
}, solution)
class MatrixDiagPartTest(xla_test.XLATestCase):
@ -613,33 +705,35 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
# From here onwards are v2-only tests.
def testSingleMatrix(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (solution, _) in tests.items():
self._assertOpOutputMatchesExpected({
"input": mat[0],
"k": diag_index
}, solution[0])
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for align in alignment_list:
test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
for mat, tests in test_list:
for diag_index, (solution, _) in tests.items():
self._assertOpOutputMatchesExpected(
{
"input": mat[0],
"k": diag_index,
"align": align
}, solution[0])
def testBatch(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
for diag_index, (solution, _) in tests.items():
self._assertOpOutputMatchesExpected({
"input": mat,
"k": diag_index
}, solution)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for align in alignment_list:
for mat, tests in all_tests(align):
for diag_index, (solution, _) in tests.items():
self._assertOpOutputMatchesExpected(
{
"input": mat,
"k": diag_index,
"align": align
}, solution)
def testPadding(self):
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
for padding_value in [555, -11]:
for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
for padding_value, align in zip_to_first_list_length([555, -11],
alignment_list):
for mat, tests in all_tests(align):
for diag_index, (solution, _) in tests.items():
mask = (solution == 0)
solution = solution + (mask * padding_value)
@ -647,7 +741,8 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
{
"input": mat,
"k": diag_index,
"padding_value": padding_value
"padding_value": padding_value,
"align": align
}, solution)

View File

@ -26,6 +26,30 @@ limitations under the License.
namespace tensorflow {
namespace {
// Calculates the diagonal length of a diagonal.
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
return std::min(num_rows + std::min(0, diag_index),
num_cols - std::max(0, diag_index));
}
// Checks if a diagonal is to be aligned left or right.
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
bool left_align_subdiagonal) {
return (diag_index >= 0 && left_align_superdiagonal) ||
(diag_index <= 0 && left_align_subdiagonal);
}
// Reads the diagonal packing alignment.
void ReadAlignment(OpKernelConstruction* context,
bool* left_align_superdiagonal,
bool* left_align_subdiagonal) {
string align;
OP_REQUIRES_OK(context, context->GetAttr("align", &align));
*left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
*left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
}
// Reads or infers lower_diag_index and upper_diag_index from kernel's input
// parameter "k". Also validates their values.
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
@ -94,7 +118,9 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
const TensorShape& input_shape, const int64 diag_rank,
const int64 num_diags, const int64 lower_diag_index,
const int64 upper_diag_index, const int64 max_diag_len,
const int64 num_rows, const int64 num_cols) {
const int64 num_rows, const int64 num_cols,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
// Creates a padding config.
const int input_rank = input_shape.dims();
xla::PaddingConfig padding_config;
@ -108,29 +134,26 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
// XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
//
// For example,
// diag = [[2, 3, 0], k = (-1, 1), and num_rows = 4.
// diag = [[0, 2, 3], k = (-1, 1), num_cols = 4, and align="RIGHT_LEFT".
// [4, 5, 6],
// [7, 8, 9]]
// The expected output is [[4, 2, 0],
// [7, 5, 4],
// [0, 8, 6],
// [0, 0, 9]]
// The expected output is [[7, 4, 2, 0],
// [0, 8, 5, 3],
// [0, 0, 9, 6]].
// The 1st diagonal is created by:
// 1) Extracting diag_slice = [1, 2, 0].
// 2) Padding the vector to be as long as num_rows,
// diag_slice = [1, 2, 0, 0],
// then broadcasting diag_slice row-wise to a full matrix,
// diag_broadcast = [[1, 1, 1],
// [2, 2, 2],
// [0, 0, 0],
// [0, 0, 0]]
// 1) Extracting diag_slice = [0, 2, 3] which is right-aligned.
// 2) Padding the vector (in the same direction) to be as long as num_cols,
// diag_slice = [0, 0, 2, 3],
// then broadcasting diag_slice column-wise to a full matrix,
// diag_broadcast = [[0, 0, 2, 3],
// [0, 0, 2, 3],
// [0, 0, 2, 3]].
// The padding value can be anything because it will not appear in the
// results after masking. Here, we use zero.
// 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
// mask = [[0, 1, 0], --> output = [[x, 2, x],
// [0, 0, 1], [x, x, 3],
// [0, 0, 0], [x, x, x],
// [0, 0, 0]] [x, x, x]],
// mask = [[0, 0, 1, 0], --> output = [[x, x, 2, x],
// [0, 0, 0, 1], [x, x, x, 3],
// [0, 0, 0, 0]] [x, x, x, x]],
// where x denotes the existing input contents.
std::vector<int64> broadcast_dimensions(input_rank - 1);
absl::c_iota(broadcast_dimensions, 0);
@ -140,6 +163,8 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
// Extracts a single diagonal.
auto diag_slice = diag;
if (num_diags > 1) {
// The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
// We call Collapse to make the dims: [<batch_dim>, max_diag_len].
const int64 mapped_diag_index = upper_diag_index - diag_index;
diag_slice = xla::Collapse(
xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
@ -147,20 +172,51 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
{diag_rank - 2, diag_rank - 1});
}
// Pads if necessary. Always pad at the end because shorter diagonals in
// the input come padded at the end.
const int64 padding_length =
((diag_index <= 0) ? num_cols : num_rows) - max_diag_len;
const xla::XlaOp zero = xla::ScalarLike(input, 0);
if (padding_length > 0) {
// Pad if necessary.
// - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
// no padding is necessary. It is broadcast column-wise if it is a sub-
// diagonal, row-wise if superdiagonal.
// - Otherwise, pad and keep the old alignment (shorter diagonals in the
// input come pre-padded). max_len in the table refers to max_diag_len.
// -------------------------------------------------------------------
// | Diag | Align | Broadcast | padding_low | padding_high |
// -------------------------------------------------------------------
// | Super | Left | Row-wise | 0 | #rows - max_len |
// | | Right | Column-wise | #cols - max_len | 0 |
// -------------------------------------------------------------------
// | Sub | Left | Column-wise | 0 | #cols - max_len |
// | | Right | Row-wise | #rows - max_len | 0 |
// -------------------------------------------------------------------
if (num_cols - num_rows <= diag_index && diag_index <= 0) {
broadcast_dimensions.back() = input_rank - 1; // Column-wise.
} else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
broadcast_dimensions.back() = input_rank - 2; // Row-wise.
} else {
int length_to_pad_to;
if ((diag_index > 0 && left_align_superdiagonal) ||
(diag_index < 0 && !left_align_subdiagonal)) {
length_to_pad_to = num_rows;
broadcast_dimensions.back() = input_rank - 2; // Row-wise.
} else {
length_to_pad_to = num_cols;
broadcast_dimensions.back() = input_rank - 1; // Column-wise.
}
int padding_low = length_to_pad_to - max_diag_len;
int padding_high = 0;
if (IsLeftAligned(diag_index, left_align_superdiagonal,
left_align_subdiagonal)) {
std::swap(padding_low, padding_high);
}
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(padding_length);
->set_edge_padding_low(padding_low);
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(padding_high);
const xla::XlaOp zero = xla::ScalarLike(input, 0);
diag_slice = xla::Pad(diag_slice, zero, padding_config);
}
// Broadcasts column-wise for subdiagonals; row-wise for superdiagonals.
broadcast_dimensions.back() =
(diag_index <= 0) ? input_rank - 1 : input_rank - 2;
// Broadcast and mask.
xla::XlaOp diag_broadcast = xla::BroadcastInDim(
diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
const auto mask = xla::GetDiagonalMask(output, diag_index);
@ -173,11 +229,17 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
class MatrixDiagOp : public XlaOpKernel {
public:
explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
// MatrixDiagV3-specific.
if (context->HasAttr("align")) {
ReadAlignment(context, &left_align_superdiagonal_,
&left_align_subdiagonal_);
}
}
void Compile(XlaOpKernelContext* context) override {
OP_REQUIRES(
context, context->num_inputs() >= 1,
context, context->num_inputs() >= kNumV1Inputs,
errors::InvalidArgument("MatrixDiag op must have at least one input"));
const TensorShape diag_shape = context->InputShape(0);
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
@ -199,7 +261,7 @@ class MatrixDiagOp : public XlaOpKernel {
// MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
// one input, so we have to check the number of inputs before reading
// additional parameters for MatrixDiagV2.
if (context->num_inputs() > 1) {
if (context->num_inputs() > kNumV1Inputs) {
std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
@ -252,8 +314,14 @@ class MatrixDiagOp : public XlaOpKernel {
context->SetOutput(
0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
lower_diag_index, upper_diag_index, max_diag_len,
num_rows, num_cols));
num_rows, num_cols, left_align_superdiagonal_,
left_align_subdiagonal_));
}
private:
bool left_align_superdiagonal_ = true;
bool left_align_subdiagonal_ = true;
static constexpr int kNumV1Inputs = 1;
};
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
@ -263,12 +331,24 @@ REGISTER_XLA_OP(Name("MatrixDiagV2")
.CompileTimeConstantInput("num_cols")
.CompileTimeConstantInput("padding_value"),
MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
.CompileTimeConstantInput("k")
.CompileTimeConstantInput("num_rows")
.CompileTimeConstantInput("num_cols")
.CompileTimeConstantInput("padding_value"),
MatrixDiagOp);
class MatrixDiagPartOp : public XlaOpKernel {
public:
explicit MatrixDiagPartOp(OpKernelConstruction* context)
: XlaOpKernel(context),
is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
// MatrixDiagPartV3-specific.
if (context->HasAttr("align")) {
ReadAlignment(context, &left_align_superdiagonal_,
&left_align_subdiagonal_);
}
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
@ -290,7 +370,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
// MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
// MatrixDiagPart only has one input, so we have to check the number of
// inputs before reading additional parameters in MatrixDiagV2.
if (context->num_inputs() > 1) {
if (context->num_inputs() > kNumV1Inputs) {
std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
padding_value = context->Input(2);
}
@ -314,25 +394,34 @@ class MatrixDiagPartOp : public XlaOpKernel {
// Computes output.
xla::XlaOp input = context->Input(0);
std::vector<xla::XlaOp> diag_list;
xla::PaddingConfig padding_config;
xla::PaddingConfig padding_config =
xla::MakeNoPaddingConfig(input_rank - 1);
if (num_diags == 1) {
context->SetOutput(
0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
: xla::GetMatrixDiagonal(input, upper_diag_index));
return;
}
padding_config = xla::MakeNoPaddingConfig(input_rank - 1);
for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
--diag_index) {
xla::XlaOp single_diag =
is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
: xla::GetMatrixDiagonal(input, diag_index);
const int64 diag_length =
(diag_index >= 0) ? (num_cols - diag_index) : (num_rows + diag_index);
const int64 padding_length = max_diag_len - diag_length;
if (padding_length > 0) {
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(padding_length);
const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
const int64 padding_len = max_diag_len - diag_len;
if (padding_len > 0) {
if (IsLeftAligned(diag_index, left_align_superdiagonal_,
left_align_subdiagonal_)) {
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_low(0);
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(padding_len);
} else {
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_low(padding_len);
padding_config.mutable_dimensions(input_rank - 2)
->set_edge_padding_high(0);
}
single_diag = xla::Pad(single_diag, padding_value, padding_config);
}
diag_list.emplace_back(single_diag);
@ -344,6 +433,9 @@ class MatrixDiagPartOp : public XlaOpKernel {
private:
const bool is_gpu_;
bool left_align_superdiagonal_ = true;
bool left_align_subdiagonal_ = true;
static constexpr int kNumV1Inputs = 1;
};
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
@ -351,11 +443,21 @@ REGISTER_XLA_OP(Name("MatrixDiagPartV2")
.CompileTimeConstantInput("k")
.CompileTimeConstantInput("padding_value"),
MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
.CompileTimeConstantInput("k")
.CompileTimeConstantInput("padding_value"),
MatrixDiagPartOp);
class MatrixSetDiagOp : public XlaOpKernel {
public:
explicit MatrixSetDiagOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
: XlaOpKernel(context) {
// MatrixSetDiagV3-specific.
if (context->HasAttr("align")) {
ReadAlignment(context, &left_align_superdiagonal_,
&left_align_subdiagonal_);
}
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
@ -378,7 +480,7 @@ class MatrixSetDiagOp : public XlaOpKernel {
// reading additional parameters in MatrixSetDiagV2.
int64 lower_diag_index = 0;
int64 upper_diag_index = 0;
if (context->num_inputs() > 2) {
if (context->num_inputs() > kNumV1Inputs) {
std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
}
@ -419,15 +521,21 @@ class MatrixSetDiagOp : public XlaOpKernel {
context->SetOutput(
0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
lower_diag_index, upper_diag_index, max_diag_len,
num_rows, num_cols));
num_rows, num_cols, left_align_superdiagonal_,
left_align_subdiagonal_));
}
private:
bool left_align_superdiagonal_ = true;
bool left_align_subdiagonal_ = true;
static constexpr int kNumV1Inputs = 2;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
MatrixSetDiagOp);
} // namespace tensorflow

View File

@ -0,0 +1,141 @@
op {
graph_op_name: "MatrixDiagPartV3"
in_arg {
name: "input"
description: "Rank `r` tensor where `r >= 2`."
}
in_arg {
name: "k"
description: <<END
Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
diagonal, and negative value means subdiagonals. `k` can be a single integer
(for a single diagonal) or a pair of integers specifying the low and high ends
of a matrix band. `k[0]` must not be larger than `k[1]`.
END
}
in_arg {
name: "padding_value"
description: <<END
The value to fill the area outside the specified diagonal band with.
Default is 0.
END
}
out_arg {
name: "diagonal"
description: "The extracted diagonal(s)."
}
attr {
name: "align"
description: <<END
Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is
a string specifying how superdiagonals and subdiagonals should be aligned,
respectively. There are four possible alignments: "RIGHT_LEFT" (default),
"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals
to the right (left-pads the row) and subdiagonals to the left (right-pads the
row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is
the opposite alignment.
END
}
summary: "Returns the batched diagonal part of a batched tensor."
description: <<END
Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
`input`.
Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`.
Let `max_diag_len` be the maximum length among all diagonals to be extracted,
`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
Let `num_diags` be the number of diagonals to extract,
`num_diags = k[1] - k[0] + 1`.
If `num_diags == 1`, the output tensor is of rank `r - 1` with shape
`[I, J, ..., L, max_diag_len]` and values:
```
diagonal[i, j, ..., l, n]
= input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
padding_value ; otherwise.
```
where `y = max(-k[1], 0)`, `x = max(k[1], 0)`.
Otherwise, the output tensor has rank `r` with dimensions
`[I, J, ..., L, num_diags, max_diag_len]` with values:
```
diagonal[i, j, ..., l, m, n]
= input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
padding_value ; otherwise.
```
where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
`offset` is zero except when the alignment of the diagonal is to the right.
```
offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
and `d >= 0`) or
(`align` in {LEFT_RIGHT, RIGHT_RIGHT}
and `d <= 0`)
0 ; otherwise
```
where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
The input must be at least a matrix.
For example:
```
input = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4)
[5, 6, 7, 8],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8]]])
# A main diagonal from each batch.
tf.matrix_diag_part(input) ==> [[1, 6, 7], # Output shape: (2, 3)
[5, 2, 7]]
# A superdiagonal from each batch.
tf.matrix_diag_part(input, k = 1)
==> [[2, 7, 6], # Output shape: (2, 3)
[4, 3, 8]]
# A band from each batch.
tf.matrix_diag_part(input, k = (-1, 2))
==> [[[0, 3, 8], # Output shape: (2, 4, 3)
[2, 7, 6],
[1, 6, 7],
[5, 8, 0]],
[[0, 3, 4],
[4, 3, 8],
[5, 2, 7],
[1, 6, 0]]]
# LEFT_RIGHT alignment.
tf.matrix_diag_part(input, k = (-1, 2), align="LEFT_RIGHT")
==> [[[3, 8, 0], # Output shape: (2, 4, 3)
[2, 7, 6],
[1, 6, 7],
[0, 5, 8]],
[[3, 4, 0],
[4, 3, 8],
[5, 2, 7],
[0, 1, 6]]]
# max_diag_len can be shorter than the main diagonal.
tf.matrix_diag_part(input, k = (-2, -1))
==> [[[5, 8],
[9, 0]],
[[1, 6],
[5, 0]]]
# padding_value = 9
tf.matrix_diag_part(input, k = (1, 3), padding_value = 9)
==> [[[9, 9, 4], # Output shape: (2, 3, 3)
[9, 3, 8],
[2, 7, 6]],
[[9, 9, 2],
[9, 3, 4],
[4, 3, 8]]]
```
END
}

View File

@ -0,0 +1,177 @@
op {
graph_op_name: "MatrixDiagV3"
in_arg {
name: "diagonal"
description: "Rank `r`, where `r >= 1`"
}
in_arg {
name: "k"
description: <<END
Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
diagonal, and negative value means subdiagonals. `k` can be a single integer
(for a single diagonal) or a pair of integers specifying the low and high ends
of a matrix band. `k[0]` must not be larger than `k[1]`.
END
}
in_arg {
name: "num_rows"
description: <<END
The number of rows of the output matrix. If it is not provided, the op assumes
the output matrix is a square matrix and infers the matrix size from k and the
innermost dimension of `diagonal`.
END
}
in_arg {
name: "num_cols"
description: <<END
The number of columns of the output matrix. If it is not provided, the op
assumes the output matrix is a square matrix and infers the matrix size from
k and the innermost dimension of `diagonal`.
END
}
in_arg {
name: "padding_value"
description: <<END
The number to fill the area outside the specified diagonal band with.
Default is 0.
END
}
out_arg {
name: "output"
description: <<END
Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise.
END
}
attr {
name: "align"
description: <<END
Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is
a string specifying how superdiagonals and subdiagonals should be aligned,
respectively. There are four possible alignments: "RIGHT_LEFT" (default),
"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals
to the right (left-pads the row) and subdiagonals to the left (right-pads the
row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is
the opposite alignment.
END
}
summary:
"Returns a batched diagonal tensor with given batched diagonal values."
description: <<END
Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
diagonals of a matrix, with everything else padded with `padding`. `num_rows`
and `num_cols` specify the dimension of the innermost matrix of the output. If
both are not specified, the op assumes the innermost matrix is square and infers
its size from `k` and the innermost dimension of `diagonal`. If only one of them
is specified, the op assumes the unspecified value is the smallest possible
based on other criteria.
Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor has
rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only one
diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has rank
`r` with shape `[I, J, ..., L, num_rows, num_cols]`.
The second innermost dimension of `diagonal` has double meaning.
When `k` is scalar or `k[0] == k[1]`, `M` is part of the batch size
[I, J, ..., M], and the output tensor is:
```
output[i, j, ..., l, m, n]
= diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper
padding_value ; otherwise
```
Otherwise, `M` is treated as the number of diagonals for the matrix in the
same batch (`M = k[1]-k[0]+1`), and the output tensor is:
```
output[i, j, ..., l, m, n]
= diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
padding_value ; otherwise
```
where `d = n - m`, `diag_index = [k] - d`, and
`index_in_diag = n - max(d, 0) + offset`.
`offset` is zero except when the alignment of the diagonal is to the right.
```
offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
and `d >= 0`) or
(`align` in {LEFT_RIGHT, RIGHT_RIGHT}
and `d <= 0`)
0 ; otherwise
```
where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
For example:
```
# The main diagonal.
diagonal = np.array([[1, 2, 3, 4], # Input shape: (2, 4)
[5, 6, 7, 8]])
tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0], # Output shape: (2, 4, 4)
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]],
[[5, 0, 0, 0],
[0, 6, 0, 0],
[0, 0, 7, 0],
[0, 0, 0, 8]]]
# A superdiagonal (per batch).
diagonal = np.array([[1, 2, 3], # Input shape: (2, 3)
[4, 5, 6]])
tf.matrix_diag(diagonal, k = 1)
==> [[[0, 1, 0, 0], # Output shape: (2, 4, 4)
[0, 0, 2, 0],
[0, 0, 0, 3],
[0, 0, 0, 0]],
[[0, 4, 0, 0],
[0, 0, 5, 0],
[0, 0, 0, 6],
[0, 0, 0, 0]]]
# A tridiagonal band (per batch).
diagonals = np.array([[[0, 8, 9], # Input shape: (2, 2, 3)
[1, 2, 3],
[4, 5, 0]],
[[0, 2, 3],
[6, 7, 9],
[9, 1, 0]]])
tf.matrix_diag(diagonals, k = (-1, 1))
==> [[[1, 8, 0], # Output shape: (2, 3, 3)
[4, 2, 9],
[0, 5, 3]],
[[6, 2, 0],
[9, 7, 3],
[0, 1, 9]]]
# LEFT_RIGHT alignment.
diagonals = np.array([[[8, 9, 0], # Input shape: (2, 2, 3)
[1, 2, 3],
[0, 4, 5]],
[[2, 3, 0],
[6, 7, 9],
[0, 9, 1]]])
tf.matrix_diag(diagonals, k = (-1, 1), align="LEFT_RIGHT")
==> [[[1, 8, 0], # Output shape: (2, 3, 3)
[4, 2, 9],
[0, 5, 3]],
[[6, 2, 0],
[9, 7, 3],
[0, 1, 9]]]
# Rectangular matrix.
diagonal = np.array([1, 2]) # Input shape: (2)
tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4)
==> [[0, 0, 0, 0], # Output shape: (3, 4)
[1, 0, 0, 0],
[0, 2, 0, 0]]
# Rectangular matrix with inferred num_cols and padding_value = 9.
tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
==> [[9, 9], # Output shape: (3, 2)
[1, 9],
[9, 2]]
```
END
}

View File

@ -0,0 +1,148 @@
op {
graph_op_name: "MatrixSetDiagV3"
in_arg {
name: "input"
description: "Rank `r+1`, where `r >= 1`."
}
in_arg {
name: "diagonal"
description: <<END
Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`.
`k >= 1`.
END
}
in_arg {
name: "k"
description: <<END
Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
diagonal, and negative value means subdiagonals. `k` can be a single integer
(for a single diagonal) or a pair of integers specifying the low and high ends
of a matrix band. `k[0]` must not be larger than `k[1]`.
END
}
out_arg {
name: "output"
description: <<END
Rank `r+1`, with `output.shape = input.shape`.
END
}
attr {
name: "align"
description: <<END
Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is
a string specifying how superdiagonals and subdiagonals should be aligned,
respectively. There are four possible alignments: "RIGHT_LEFT" (default),
"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals
to the right (left-pads the row) and subdiagonals to the left (right-pads the
row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is
the opposite alignment.
END
}
summary: "Returns a batched matrix tensor with new batched diagonal values."
description: <<END
Given `input` and `diagonal`, this operation returns a tensor with the
same shape and values as `input`, except for the specified diagonals of the
innermost matrices. These will be overwritten by the values in `diagonal`.
`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
`k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
`num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
`max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
The output is a tensor of rank `k+1` with dimensions `[I, J, ..., L, M, N]`.
If `k` is scalar or `k[0] == k[1]`:
```
output[i, j, ..., l, m, n]
= diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
input[i, j, ..., l, m, n] ; otherwise
```
Otherwise,
```
output[i, j, ..., l, m, n]
= diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
input[i, j, ..., l, m, n] ; otherwise
```
where `d = n - m`, `diag_index = k[1] - d`, and
`index_in_diag = n - max(d, 0) + offset`.
`offset` is zero except when the alignment of the diagonal is to the right.
```
offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
and `d >= 0`) or
(`align` in {LEFT_RIGHT, RIGHT_RIGHT}
and `d <= 0`)
0 ; otherwise
```
where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
For example:
```
# The main diagonal.
input = np.array([[[7, 7, 7, 7], # Input shape: (2, 3, 4)
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[7, 7, 7, 7],
[7, 7, 7, 7],
[7, 7, 7, 7]]])
diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3)
[4, 5, 6]])
tf.matrix_set_diag(input, diagonal)
==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
[7, 2, 7, 7],
[7, 7, 3, 7]],
[[4, 7, 7, 7],
[7, 5, 7, 7],
[7, 7, 6, 7]]]
# A superdiagonal (per batch).
tf.matrix_set_diag(input, diagonal, k = 1)
==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4)
[7, 7, 2, 7],
[7, 7, 7, 3]],
[[7, 4, 7, 7],
[7, 7, 5, 7],
[7, 7, 7, 6]]]
# A band of diagonals.
diagonals = np.array([[[0, 9, 1], # Diagonal shape: (2, 4, 3)
[6, 5, 8],
[1, 2, 3],
[4, 5, 0]],
[[0, 1, 2],
[5, 6, 4],
[6, 1, 2],
[3, 4, 0]]])
tf.matrix_set_diag(input, diagonals, k = (-1, 2))
==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4)
[4, 2, 5, 1],
[7, 5, 3, 8]],
[[6, 5, 1, 7],
[3, 1, 6, 2],
[7, 4, 2, 4]]]
# LEFT_RIGHT alignment.
diagonals = np.array([[[9, 1, 0], # Diagonal shape: (2, 4, 3)
[6, 5, 8],
[1, 2, 3],
[0, 4, 5]],
[[1, 2, 0],
[5, 6, 4],
[6, 1, 2],
[0, 3, 4]]])
tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT")
==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4)
[4, 2, 5, 1],
[7, 5, 3, 8]],
[[6, 5, 1, 7],
[3, 1, 6, 2],
[7, 4, 2, 4]]]
```
END
}

View File

@ -1,11 +1,4 @@
op {
graph_op_name: "MatrixDiagPartV2"
endpoint {
name: "linalg.diag_part"
}
endpoint {
name: "matrix_diag_part"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -0,0 +1,11 @@
op {
graph_op_name: "MatrixDiagPartV3"
endpoint {
name: "linalg.diag_part"
}
endpoint {
name: "matrix_diag_part"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -1,11 +1,4 @@
op {
graph_op_name: "MatrixDiagV2"
endpoint {
name: "linalg.diag"
}
endpoint {
name: "matrix_diag"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -0,0 +1,11 @@
op {
graph_op_name: "MatrixDiagV3"
endpoint {
name: "linalg.diag"
}
endpoint {
name: "matrix_diag"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -1,11 +1,4 @@
op {
graph_op_name: "MatrixSetDiagV2"
endpoint {
name: "linalg.set_diag"
}
endpoint {
name: "matrix_set_diag"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -0,0 +1,11 @@
op {
graph_op_name: "MatrixSetDiagV3"
endpoint {
name: "linalg.set_diag"
}
endpoint {
name: "matrix_set_diag"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -1189,6 +1189,254 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
int32* lower_diag_index, int32* upper_diag_index) {
// This function assumes that the shape of diag_index_tensor is fully defined.
if (diag_index_tensor->dims() == 0) {
*lower_diag_index = diag_index_tensor->scalar<int32>()();
*upper_diag_index = *lower_diag_index;
} else {
int32 num_elements = diag_index_tensor->dim_size(0);
if (num_elements == 1) {
*lower_diag_index = diag_index_tensor->vec<int32>()(0);
*upper_diag_index = *lower_diag_index;
} else if (num_elements == 2) {
*lower_diag_index = diag_index_tensor->vec<int32>()(0);
*upper_diag_index = diag_index_tensor->vec<int32>()(1);
} else {
return errors::InvalidArgument(
"diag_index must be a vector with one or two elements. It has ",
num_elements, " elements.");
}
}
return Status::OK();
}
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
ShapeHandle input_shape, diag_index_shape, unused_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
const Tensor* diag_index_tensor = c->input_tensor(1);
if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
diag_index_tensor == nullptr) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
int32 lower_diag_index = 0;
int32 upper_diag_index = 0;
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
&upper_diag_index));
if (lower_diag_index > upper_diag_index) {
return errors::InvalidArgument(
"lower_diag_index is greater than upper_diag_index");
}
// Validates lower_diag_index and upper_diag_index.
const int32 input_rank = c->Rank(input_shape);
const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
if (num_rows != InferenceContext::kUnknownDim &&
num_cols != InferenceContext::kUnknownDim) {
if (lower_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
return errors::InvalidArgument("lower_diag_index is out of bound.");
}
if (upper_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
return errors::InvalidArgument("upper_diag_index is out of bound.");
}
}
const int32 max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
num_cols - std::max(lower_diag_index, 0));
std::vector<DimensionHandle> dims;
dims.reserve(input_rank - 2);
for (int i = 0; i < input_rank - 2; ++i) {
dims.push_back(c->Dim(input_shape, i));
}
if (lower_diag_index < upper_diag_index) {
dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
}
dims.push_back(c->MakeDim(max_diag_len));
c->set_output(0, c->MakeShape(dims));
return Status::OK();
}
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
// Checks input ranks.
ShapeHandle input_shape, diag_index_shape, unused_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
// Reads the diagonal indices.
const Tensor* diag_index_tensor = c->input_tensor(1);
if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
diag_index_tensor == nullptr) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
int32 lower_diag_index = 0;
int32 upper_diag_index = 0;
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
&upper_diag_index));
if (lower_diag_index > upper_diag_index) {
return errors::InvalidArgument(
"lower_diag_index is greater than upper_diag_index");
}
// Checks if the number of diagonals provided matches what we imply from
// lower_diag_index and upper_diag_index.
const int32 input_rank = c->Rank(input_shape);
if (lower_diag_index < upper_diag_index) {
const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
return errors::InvalidArgument(
"The number of rows of `diagonal` doesn't match the number of "
"diagonals implied from `d_lower` and `d_upper`.\n",
"num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim);
}
}
// Reads num_rows and num_cols.
const Tensor* num_rows_tensor = c->input_tensor(2);
const Tensor* num_cols_tensor = c->input_tensor(3);
int64 num_rows = -1;
int64 num_cols = -1;
if (num_rows_tensor != nullptr) {
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
}
if (num_cols_tensor != nullptr) {
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
}
// Infers the missing num_rows or num_cols: If both are missing, assume
// output is square. Otherwise, use the smallest possible value. Also
// validates the provided values.
const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
if (num_rows == -1 && num_cols == -1) { // Special case.
num_rows = std::max(min_num_rows, min_num_cols);
num_cols = num_rows;
}
if (num_rows == -1) {
num_rows = min_num_rows;
} else if (num_rows < min_num_rows) {
return errors::InvalidArgument("num_rows is too small");
}
if (num_cols == -1) {
num_cols = min_num_cols;
} else if (num_cols < min_num_cols) {
return errors::InvalidArgument("num_cols is too small.");
}
// At least one of them must match the minimum length.
if (num_rows != min_num_rows && num_cols != min_num_cols) {
return errors::InvalidArgument(
"num_rows and num_cols are not consistent with lower_diag_index, "
"upper_diag_index, and the length of the given diagonals.\n",
"num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
}
// Sets output shape.
ShapeHandle output_shape;
const DimensionHandle output_row_dim = c->MakeDim(num_rows);
const DimensionHandle output_col_dim = c->MakeDim(num_cols);
if (lower_diag_index == upper_diag_index) {
TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
output_row_dim, &output_shape));
TF_RETURN_IF_ERROR(
c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
} else {
TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
output_row_dim, &output_shape));
TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
output_col_dim, &output_shape));
}
c->set_output(0, output_shape);
return Status::OK();
}
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
ShapeHandle input_shape, diag_shape, diag_index_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
int32 lower_diag_index = 0;
int32 upper_diag_index = 0;
bool diag_index_known = false;
const Tensor* diag_index_tensor = c->input_tensor(2);
if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
diag_index_known = true;
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
&upper_diag_index));
if (lower_diag_index > upper_diag_index) {
return errors::InvalidArgument(
"lower_diag_index is greater than upper_diag_index");
}
}
// Do more checks when input rank is known.
if (c->RankKnown(input_shape)) {
int32 input_rank = c->Rank(input_shape);
// If diag_index is set, we know the exact rank of diagonal.
if (diag_index_known) {
TF_RETURN_IF_ERROR(c->WithRank(
c->input(1),
(lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
&diag_shape));
} else {
TF_RETURN_IF_ERROR(
c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
TF_RETURN_IF_ERROR(
c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
}
// Validates lower_diag_index and upper_diag_index.
const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
if (num_rows != InferenceContext::kUnknownDim &&
num_cols != InferenceContext::kUnknownDim) {
if (lower_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
return errors::InvalidArgument("lower_diag_index is out of bound.");
}
if (upper_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
return errors::InvalidArgument("upper_diag_index is out of bound.");
}
}
}
ShapeHandle output_shape = input_shape;
if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
// Try to infer parts of shape from diag.
ShapeHandle diag_prefix;
TF_RETURN_IF_ERROR(c->Subshape(
diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
&diag_prefix));
// The inner matrices can be rectangular, so we can't pinpoint their
// exact height and width by just lower_diag_index, upper_diag_index,
// and the longest length of given diagonals.
TF_RETURN_IF_ERROR(
c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
}
c->set_output(0, output_shape);
return Status::OK();
}
Status MaxPoolShape(shape_inference::InferenceContext* c) {
string data_format_str;
TensorFormat data_format;

View File

@ -270,6 +270,15 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c);
// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations.
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c);
// Shape function for MatrixDiagV2 and MatrixDiagV3 operations.
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);
// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);
// Shape function for MaxPool-like operations.
Status MaxPoolShape(shape_inference::InferenceContext* c);

View File

@ -1158,7 +1158,7 @@ tf_kernel_library(
tf_kernel_library(
name = "matrix_set_diag_op",
prefix = "matrix_set_diag_op",
deps = ARRAY_DEPS,
deps = ARRAY_DEPS + [":matrix_diag_op"],
)
tf_kernel_library(

View File

@ -46,8 +46,13 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T>
class MatrixDiagPartOp : public OpKernel {
public:
explicit MatrixDiagPartOp(OpKernelConstruction* context)
: OpKernel(context) {}
explicit MatrixDiagPartOp(OpKernelConstruction* context) : OpKernel(context) {
// MatrixDiagPartV3-specific.
if (context->HasAttr("align")) {
functor::ReadAlignment(context, &left_align_superdiagonal_,
&left_align_subdiagonal_);
}
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@ -60,7 +65,7 @@ class MatrixDiagPartOp : public OpKernel {
T padding_value(0);
// MatrixDiagPartV2-specific.
if (context->num_inputs() > 1) {
if (context->num_inputs() > kNumV1Inputs) {
auto& diag_index = context->input(1);
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(diag_index.shape()) ||
@ -132,17 +137,26 @@ class MatrixDiagPartOp : public OpKernel {
functor::MatrixDiagPart<Device, T>::Compute(
context, context->eigen_device<Device>(), input_reshaped,
output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
padding_value);
padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
}
private:
bool left_align_superdiagonal_ = true;
bool left_align_subdiagonal_ = true;
static constexpr int kNumV1Inputs = 1;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp);
};
template <typename Device, typename T>
class MatrixDiagOp : public OpKernel {
public:
explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {}
explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {
// MatrixDiagV3-specific.
if (context->HasAttr("align")) {
functor::ReadAlignment(context, &left_align_superdiagonal_,
&left_align_subdiagonal_);
}
}
void Compute(OpKernelContext* context) override {
const Tensor& diagonal = context->input(0);
@ -157,7 +171,7 @@ class MatrixDiagOp : public OpKernel {
T padding_value(0);
// MatrixDiagOpV2-specific.
if (context->num_inputs() > 1) {
if (context->num_inputs() > kNumV1Inputs) {
auto& diag_index = context->input(1);
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(diag_index.shape()) ||
@ -242,10 +256,13 @@ class MatrixDiagOp : public OpKernel {
functor::MatrixDiag<Device, T>::Compute(
context, context->eigen_device<Device>(), diag_reshaped,
output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
padding_value);
padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
}
private:
bool left_align_superdiagonal_ = true;
bool left_align_subdiagonal_ = true;
static constexpr int kNumV1Inputs = 1;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp);
};
@ -256,12 +273,19 @@ class MatrixDiagOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagPartV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagPartV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
#undef REGISTER_MATRIX_DIAG
@ -280,6 +304,28 @@ TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
// Implementation of the functor specialization for CPU.
namespace functor {
void ReadAlignment(OpKernelConstruction* context,
bool* left_align_superdiagonal,
bool* left_align_subdiagonal) {
string align;
OP_REQUIRES_OK(context, context->GetAttr("align", &align));
*left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
*left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
}
std::pair<int, int> ComputeDiagLenAndContentOffset(
int diag_index, int max_diag_len, int num_rows, int num_cols,
bool left_align_superdiagonal, bool left_align_subdiagonal) {
const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
(diag_index <= 0 && left_align_subdiagonal);
const int diag_len = std::min(num_rows + std::min(0, diag_index),
num_cols - std::max(0, diag_index));
const int content_offset = (left_align) ? 0 : (max_diag_len - diag_len);
return {diag_len, content_offset};
}
template <typename T>
struct MatrixDiag<CPUDevice, T> {
static void Compute(OpKernelContext* context, const CPUDevice& device,
@ -287,16 +333,21 @@ struct MatrixDiag<CPUDevice, T> {
typename TTypes<T, 3>::Tensor& output,
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len, const T padding_value) {
const Eigen::Index max_diag_len, const T padding_value,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
// 10 in cost_per_batch is from existing heuristic.
// TODO(penporn): Tune for the best constant in cost_per_batch.
const Eigen::Index num_batches = output.dimension(0);
const Eigen::Index cost_per_batch =
10 * output.dimension(1) * output.dimension(2);
const Eigen::Index num_rows = output.dimension(1);
const Eigen::Index num_cols = output.dimension(2);
const Eigen::Index cost_per_batch = 10 * num_rows * num_cols;
auto compute_shard = [&output, &diag, &lower_diag_index, &upper_diag_index,
&max_diag_len, &padding_value](Eigen::Index begin,
Eigen::Index end) {
auto compute_shard = [&output, &num_rows, &num_cols, &diag,
&lower_diag_index, &upper_diag_index, &max_diag_len,
&padding_value, &left_align_superdiagonal,
&left_align_subdiagonal](Eigen::Index begin,
Eigen::Index end) {
const int num_diags = upper_diag_index - lower_diag_index + 1;
const int diag_elements_in_batch = num_diags * max_diag_len;
Eigen::Index diag_batch_base_index = begin * diag_elements_in_batch;
@ -305,8 +356,12 @@ struct MatrixDiag<CPUDevice, T> {
for (Eigen::Index j = 0; j < output.dimension(2); ++j) {
const int diag_index = j - i;
const int diag_index_in_input = upper_diag_index - diag_index;
int diag_len, content_offset;
std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
diag_index, max_diag_len, num_rows, num_cols,
left_align_superdiagonal, left_align_subdiagonal);
const int index_in_the_diagonal =
j - std::max<Eigen::Index>(diag_index, 0);
j - std::max<Eigen::Index>(diag_index, 0) + content_offset;
if (lower_diag_index <= diag_index &&
diag_index <= upper_diag_index) {
output(batch, i, j) = diag(diag_batch_base_index +
@ -334,7 +389,9 @@ struct MatrixDiagPart<CPUDevice, T> {
typename TTypes<T>::Tensor& output,
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len, const T padding_value) {
const Eigen::Index max_diag_len, const T padding_value,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
// 10 in cost_per_batch is from existing heuristic.
// TODO(penporn): Tune for the best constant in cost_per_batch.
const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
@ -346,24 +403,32 @@ struct MatrixDiagPart<CPUDevice, T> {
auto compute_shard = [&output, &input, &num_rows, &num_cols,
&upper_diag_index, &max_diag_len, &num_diags,
&output_elements_in_batch, &padding_value](
&output_elements_in_batch, &padding_value,
&left_align_superdiagonal, &left_align_subdiagonal](
Eigen::Index begin, Eigen::Index end) {
Eigen::Index output_base_index = begin * output_elements_in_batch;
for (Eigen::Index batch = begin; batch < end; ++batch) {
for (Eigen::Index m = 0; m < num_diags; ++m) {
const Eigen::Index d = upper_diag_index - m;
Eigen::Index n = 0;
// Make two separate cases to save some index calculations.
if (d >= 0) {
for (; n < std::min(num_rows, num_cols - d); ++n) {
output(output_base_index + n) = input(batch, n, n + d);
}
} else {
for (; n < std::min(num_rows + d, num_cols); ++n) {
output(output_base_index + n) = input(batch, n - d, n);
}
const Eigen::Index diag_index = upper_diag_index - m;
Eigen::Index y_offset = std::max<Eigen::Index>(0, -diag_index);
Eigen::Index x_offset = std::max<Eigen::Index>(0, diag_index);
int diag_len, content_offset;
std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
diag_index, max_diag_len, num_rows, num_cols,
left_align_superdiagonal, left_align_subdiagonal);
// Fills the diagonal.
for (Eigen::Index n = 0; n < diag_len; ++n) {
output(output_base_index + content_offset + n) =
input(batch, n + y_offset, n + x_offset);
}
for (; n < max_diag_len; ++n) { // Padding.
// Padding.
const bool left_align = (content_offset == 0);
const Eigen::Index padding_start = (left_align) ? diag_len : 0;
const Eigen::Index padding_end =
(left_align) ? max_diag_len : content_offset;
for (Eigen::Index n = padding_start; n < padding_end; ++n) {
output(output_base_index + n) = padding_value;
}
output_base_index += max_diag_len;
@ -391,7 +456,8 @@ namespace functor {
typename TTypes<T, 3>::Tensor& output, \
const Eigen::Index lower_diag_index, \
const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \
const T padding_value); \
const T padding_value, const bool left_align_superdiagonal, \
const bool left_align_subdiagonal); \
extern template struct MatrixDiag<GPUDevice, T>; \
template <> \
void MatrixDiagPart<GPUDevice, T>::Compute( \
@ -399,7 +465,8 @@ namespace functor {
typename TTypes<T, 3>::ConstTensor& input, \
typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, \
const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \
const T padding_value); \
const T padding_value, const bool left_align_superdiagonal, \
const bool left_align_subdiagonal); \
extern template struct MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@ -422,6 +489,14 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.HostMemory("num_cols") \
.HostMemory("padding_value"), \
MatrixDiagOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("MatrixDiagV3") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("k") \
.HostMemory("num_rows") \
.HostMemory("num_cols") \
.HostMemory("padding_value"), \
MatrixDiagOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<GPUDevice, type>); \
@ -430,6 +505,12 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.TypeConstraint<type>("T") \
.HostMemory("k") \
.HostMemory("padding_value"), \
MatrixDiagPartOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV3") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("k") \
.HostMemory("padding_value"), \
MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);

View File

@ -19,6 +19,7 @@ limitations under the License.
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
@ -26,26 +27,43 @@ limitations under the License.
namespace tensorflow {
namespace functor {
// Reads the diagonal packing alignment.
void ReadAlignment(OpKernelConstruction* context,
bool* left_align_superdiagonal,
bool* left_align_subdiagonal);
// Calculates diagonal length and content offset (from aligning) of a diagonal.
// Returns a pair of integers {diag_len, content_offset}:
// - diag_len: The length of the diag_index-th diagonal.
// - content_offset: Each diagonal is stored as a row in the compact format.
// If the diagonal is shorter than max_diag_len, its content is aligned
// either to the left or right. content_offset is the index in the row
// where the first element of the diag-index-th diagonal is stored. It is
// always zero when the diagonal is left-aligned.
std::pair<int, int> ComputeDiagLenAndContentOffset(
int diag_index, int max_diag_len, int num_rows, int num_cols,
bool left_align_superdiagonal, bool left_align_subdiagonal);
template <typename Device, typename T>
struct MatrixDiagPart {
EIGEN_ALWAYS_INLINE static void Compute(
OpKernelContext* context, const Device& device,
typename TTypes<T, 3>::ConstTensor& input,
typename TTypes<T>::Tensor& output_original, const Eigen::Index d_lower,
const Eigen::Index d_upper, const Eigen::Index max_diag_len,
const T padding);
typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,
const T padding_value, const bool left_align_superdiagonal,
const bool left_align_subdiagonal);
};
template <typename Device, typename T>
struct MatrixDiag {
EIGEN_ALWAYS_INLINE static void Compute(OpKernelContext* context,
const Device& device,
typename TTypes<T>::ConstTensor& diag,
typename TTypes<T, 3>::Tensor& output,
const Eigen::Index d_lower,
const Eigen::Index d_upper,
const Eigen::Index max_diag_len,
const T padding);
EIGEN_ALWAYS_INLINE static void Compute(
OpKernelContext* context, const Device& device,
typename TTypes<T>::ConstTensor& diag,
typename TTypes<T, 3>::Tensor& output,
const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len, const T padding_value,
const bool left_align_superdiagonal, const bool left_align_subdiagonal);
};
} // namespace functor

View File

@ -26,13 +26,28 @@ namespace functor {
typedef Eigen::GpuDevice GPUDevice;
__device__ inline int ComputeContentOffset(const int diag_index,
const int max_diag_len,
const int num_rows,
const int num_cols,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
(diag_index <= 0 && left_align_subdiagonal);
if (left_align) return 0;
const int y_offset = min(0, diag_index);
const int x_offset = max(0, diag_index);
const int diag_len = min(num_rows + y_offset, num_cols - x_offset);
return max_diag_len - diag_len;
}
template <typename T>
__global__ void MatrixDiagKernel(const int num_threads, const int num_rows,
const int num_cols, const int num_diags,
const int max_diag_len,
const int lower_diag_index,
const int upper_diag_index, const T padding,
const T* diag_ptr, T* output_ptr) {
__global__ void MatrixDiagKernel(
const int num_threads, const int num_rows, const int num_cols,
const int num_diags, const int max_diag_len, const int lower_diag_index,
const int upper_diag_index, const T padding_value,
const bool left_align_superdiagonal, const bool left_align_subdiagonal,
const T* diag_ptr, T* output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int batch_and_row_index = index / num_cols;
const int col = index - batch_and_row_index * num_cols;
@ -40,13 +55,16 @@ __global__ void MatrixDiagKernel(const int num_threads, const int num_rows,
const int row = batch_and_row_index - batch * num_rows;
const int diag_index = col - row;
const int diag_index_in_input = upper_diag_index - diag_index;
const int index_in_the_diagonal = col - max(diag_index, 0);
const int content_offset =
ComputeContentOffset(diag_index, max_diag_len, num_rows, num_cols,
left_align_superdiagonal, left_align_subdiagonal);
const int index_in_the_diagonal = col - max(diag_index, 0) + content_offset;
if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) {
output_ptr[index] =
diag_ptr[batch * num_diags * max_diag_len +
diag_index_in_input * max_diag_len + index_in_the_diagonal];
} else {
output_ptr[index] = padding;
output_ptr[index] = padding_value;
}
}
}
@ -58,7 +76,9 @@ struct MatrixDiag<GPUDevice, T> {
typename TTypes<T, 3>::Tensor& output,
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len, const T padding) {
const Eigen::Index max_diag_len, const T padding_value,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
const int batch_size = output.dimension(0);
const int num_rows = output.dimension(1);
const int num_cols = output.dimension(2);
@ -72,19 +92,19 @@ struct MatrixDiag<GPUDevice, T> {
TF_CHECK_OK(GpuLaunchKernel(
MatrixDiagKernel<T>, config.block_count, config.thread_per_block, 0,
device.stream(), config.virtual_thread_count, num_rows, num_cols,
num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding,
num_diags, max_diag_len, lower_diag_index, upper_diag_index,
padding_value, left_align_superdiagonal, left_align_subdiagonal,
diag.data(), output.data()));
}
};
template <typename T>
__global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows,
const int num_cols, const int num_diags,
const int max_diag_len,
const int lower_diag_index,
const int upper_diag_index,
const T padding, const T* input_ptr,
T* output_ptr) {
__global__ void MatrixDiagPartKernel(
const int num_threads, const int num_rows, const int num_cols,
const int num_diags, const int max_diag_len, const int lower_diag_index,
const int upper_diag_index, const T padding_value,
const bool left_align_superdiagonal, const bool left_align_subdiagonal,
const T* input_ptr, T* output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int batch_and_mapped_diag_index = index / max_diag_len;
const int index_in_the_diagonal =
@ -93,13 +113,19 @@ __global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows,
const int mapped_diag_index =
batch_and_mapped_diag_index - batch * num_diags;
const int diag_index = upper_diag_index - mapped_diag_index;
const int y_index = index_in_the_diagonal + max(0, -diag_index);
const int x_index = index_in_the_diagonal + max(0, diag_index);
if (y_index < num_rows && x_index < num_cols) {
const int content_offset =
ComputeContentOffset(diag_index, max_diag_len, num_rows, num_cols,
left_align_superdiagonal, left_align_subdiagonal);
const int y_offset = max(0, -diag_index);
const int x_offset = max(0, diag_index);
const int y_index = index_in_the_diagonal + y_offset - content_offset;
const int x_index = index_in_the_diagonal + x_offset - content_offset;
if (0 <= y_index && y_index < num_rows && 0 <= x_index &&
x_index < num_cols) {
output_ptr[index] =
input_ptr[batch * num_rows * num_cols + y_index * num_cols + x_index];
} else {
output_ptr[index] = padding;
output_ptr[index] = padding_value;
}
}
}
@ -111,7 +137,9 @@ struct MatrixDiagPart<GPUDevice, T> {
typename TTypes<T>::Tensor& output,
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len, const T padding) {
const Eigen::Index max_diag_len, const T padding_value,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
const int batch_size = input.dimension(0);
const int num_rows = input.dimension(1);
const int num_cols = input.dimension(2);
@ -125,7 +153,8 @@ struct MatrixDiagPart<GPUDevice, T> {
TF_CHECK_OK(GpuLaunchKernel(
MatrixDiagPartKernel<T>, config.block_count, config.thread_per_block, 0,
device.stream(), config.virtual_thread_count, num_rows, num_cols,
num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding,
num_diags, max_diag_len, lower_diag_index, upper_diag_index,
padding_value, left_align_superdiagonal, left_align_subdiagonal,
input.data(), output.data()));
}
};

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/matrix_diag_op.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -42,7 +43,13 @@ typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T>
class MatrixSetDiagOp : public OpKernel {
public:
explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {}
explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {
// MatrixSetDiagV3-specific.
if (context->HasAttr("align")) {
functor::ReadAlignment(context, &left_align_superdiagonal_,
&left_align_subdiagonal_);
}
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@ -55,7 +62,7 @@ class MatrixSetDiagOp : public OpKernel {
int32 upper_diag_index = 0;
// MatrixSetDiagV2-specific.
if (context->num_inputs() > 2) {
if (context->num_inputs() > kNumV1Inputs) {
auto& diag_index = context->input(2);
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(diag_index.shape()) ||
@ -155,10 +162,14 @@ class MatrixSetDiagOp : public OpKernel {
auto output_reshaped = output->flat_inner_dims<T, 3>();
functor::MatrixSetDiag<Device, T>::Compute(
context, context->eigen_device<Device>(), input_reshaped, diag_reshaped,
output_reshaped, lower_diag_index, upper_diag_index, max_diag_len);
output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
left_align_superdiagonal_, left_align_subdiagonal_);
}
private:
bool left_align_superdiagonal_ = true;
bool left_align_subdiagonal_ = true;
static constexpr int kNumV1Inputs = 2;
TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
@ -168,7 +179,11 @@ class MatrixSetDiagOp : public OpKernel {
MatrixSetDiagOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixSetDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("MatrixSetDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG);
#undef REGISTER_MATRIX_SET_DIAG
@ -192,29 +207,38 @@ struct MatrixSetDiag<CPUDevice, T> {
typename TTypes<T, 3>::Tensor& output,
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len) {
const Eigen::Index max_diag_len,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
if (input.data() != output.data()) {
output.device(device) = input;
}
const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
auto compute_shard = [&output, &diag, &upper_diag_index, &max_diag_len,
&num_diags](Eigen::Index begin, Eigen::Index end) {
&num_diags, &left_align_superdiagonal,
&left_align_subdiagonal](Eigen::Index begin,
Eigen::Index end) {
const Eigen::Index num_rows = output.dimension(1);
const Eigen::Index num_cols = output.dimension(2);
Eigen::Index diag_base_index = begin * num_diags * max_diag_len;
for (Eigen::Index batch = begin; batch < end; ++batch) {
for (Eigen::Index m = 0; m < num_diags; ++m) {
const Eigen::Index d = upper_diag_index - m;
const Eigen::Index diag_index = upper_diag_index - m;
int diag_len, content_offset;
std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
diag_index, max_diag_len, num_rows, num_cols,
left_align_superdiagonal, left_align_subdiagonal);
// Make two separate cases to save some index calculations.
if (d >= 0) {
for (Eigen::Index n = 0; n < std::min(num_rows, num_cols - d);
++n) {
output(batch, n, n + d) = diag(diag_base_index + n);
if (diag_index >= 0) {
for (Eigen::Index n = 0; n < diag_len; ++n) {
output(batch, n, n + diag_index) =
diag(diag_base_index + n + content_offset);
}
} else {
for (Eigen::Index n = 0; n < std::min(num_rows + d, num_cols);
++n) {
output(batch, n - d, n) = diag(diag_base_index + n);
for (Eigen::Index n = 0; n < diag_len; ++n) {
output(batch, n - diag_index, n) =
diag(diag_base_index + n + content_offset);
}
}
diag_base_index += max_diag_len;
@ -236,15 +260,16 @@ struct MatrixSetDiag<CPUDevice, T> {
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void MatrixSetDiag<GPUDevice, T>::Compute( \
OpKernelContext* context, const GPUDevice& device, \
typename TTypes<T, 3>::ConstTensor& input, \
typename TTypes<T>::ConstTensor& diag, \
typename TTypes<T, 3>::Tensor& output, \
const Eigen::Index lower_diag_index, \
const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void MatrixSetDiag<GPUDevice, T>::Compute( \
OpKernelContext* context, const GPUDevice& device, \
typename TTypes<T, 3>::ConstTensor& input, \
typename TTypes<T>::ConstTensor& diag, \
typename TTypes<T, 3>::Tensor& output, \
const Eigen::Index lower_diag_index, \
const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \
const bool left_align_superdiagonal, const bool left_align_subdiagonal); \
extern template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@ -263,7 +288,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("k"), \
MatrixSetDiagOp<GPUDevice, type>); \
REGISTER_KERNEL_BUILDER(Name("MatrixSetDiagV3") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("k"), \
MatrixSetDiagOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);

View File

@ -29,8 +29,11 @@ struct MatrixSetDiag {
typename TTypes<T, 3>::ConstTensor& input,
typename TTypes<T>::ConstTensor& diag,
typename TTypes<T, 3>::Tensor& output,
const Eigen::Index d_lower, const Eigen::Index d_upper,
const Eigen::Index max_diag_len);
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal);
};
} // namespace functor

View File

@ -26,26 +26,42 @@ namespace functor {
typedef Eigen::GpuDevice GPUDevice;
// TODO(penporn): Merge this file with matrix_diag_op_gpu.cu.cc.
__device__ inline int ComputeContentOffset(const int diag_index,
const int max_diag_len,
const int num_rows,
const int num_cols,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
(diag_index <= 0 && left_align_subdiagonal);
if (left_align) return 0;
const int y_offset = min(0, diag_index);
const int x_offset = max(0, diag_index);
const int diag_len = min(num_rows + y_offset, num_cols - x_offset);
return max_diag_len - diag_len;
}
template <typename Scalar>
__global__ void MatrixSetDiagKernel(const int num_threads, const int m,
const int n, const int num_diags,
const int max_diag_len,
const int upper_diag_index,
const Scalar* __restrict__ diag_ptr,
Scalar* __restrict__ output_ptr) {
__global__ void MatrixSetDiagKernel(
const int num_threads, const int m, const int n, const int num_diags,
const int max_diag_len, const int upper_diag_index,
const bool left_align_superdiagonal, const bool left_align_subdiagonal,
const Scalar* __restrict__ diag_ptr, Scalar* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int batch_and_diag_index = index / max_diag_len;
const int index_in_the_diagonal =
index - batch_and_diag_index * max_diag_len;
int index_in_the_diagonal = index - batch_and_diag_index * max_diag_len;
const int batch = batch_and_diag_index / num_diags;
const int diag_index_in_input = batch_and_diag_index - batch * num_diags;
const int diag_index = upper_diag_index - diag_index_in_input;
const int y_index = index_in_the_diagonal + max(0, -diag_index);
index_in_the_diagonal -=
ComputeContentOffset(diag_index, max_diag_len, m, n,
left_align_superdiagonal, left_align_subdiagonal);
const int y_index = index_in_the_diagonal - min(0, diag_index);
const int x_index = index_in_the_diagonal + max(0, diag_index);
// Upper-bound checks for diagonals shorter than max_diag_len.
// y_index and x_index are nonnegative by construction.
if (y_index < m && x_index < n) {
if (index_in_the_diagonal >= 0 && y_index < m && x_index < n) {
const int out_index = batch * m * n + y_index * n + x_index;
output_ptr[out_index] = diag_ptr[index];
}
@ -56,17 +72,21 @@ template <typename Scalar>
__global__ void MatrixCopyInputAndSetDiagKernel(
const int num_threads, const int m, const int n, const int num_diags,
const int max_diag_len, const int lower_diag_index,
const int upper_diag_index, const Scalar* __restrict__ input_ptr,
const int upper_diag_index, const bool left_align_superdiagonal,
const bool left_align_subdiagonal, const Scalar* __restrict__ input_ptr,
const Scalar* __restrict__ diag_ptr, Scalar* __restrict__ output_ptr) {
GPU_1D_KERNEL_LOOP(index, num_threads) {
const int batch_and_row_index = index / n;
const int col = index - batch_and_row_index * n;
const int batch = batch_and_row_index / m;
const int row = batch_and_row_index - batch * m;
const int d = col - row;
const int diag_index_in_input = upper_diag_index - d;
const int index_in_the_diagonal = col - max(d, 0);
if (lower_diag_index <= d && d <= upper_diag_index) {
const int diag_index = col - row;
const int diag_index_in_input = upper_diag_index - diag_index;
const int index_in_the_diagonal =
col - max(0, diag_index) +
ComputeContentOffset(diag_index, max_diag_len, m, n,
left_align_superdiagonal, left_align_subdiagonal);
if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) {
output_ptr[index] =
diag_ptr[batch * num_diags * max_diag_len +
diag_index_in_input * max_diag_len + index_in_the_diagonal];
@ -84,7 +104,9 @@ struct MatrixSetDiag<GPUDevice, Scalar> {
typename TTypes<Scalar, 3>::Tensor& output,
const Eigen::Index lower_diag_index,
const Eigen::Index upper_diag_index,
const Eigen::Index max_diag_len) {
const Eigen::Index max_diag_len,
const bool left_align_superdiagonal,
const bool left_align_subdiagonal) {
const int batch_size = input.dimension(0);
const int m = input.dimension(1);
const int n = input.dimension(2);
@ -98,15 +120,16 @@ struct MatrixSetDiag<GPUDevice, Scalar> {
MatrixSetDiagKernel<Scalar>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, m, n, num_diags, max_diag_len,
upper_diag_index, diag.data(), output.data()));
upper_diag_index, left_align_superdiagonal, left_align_subdiagonal,
diag.data(), output.data()));
} else {
GpuLaunchConfig config = GetGpuLaunchConfig(batch_size * m * n, device);
TF_CHECK_OK(GpuLaunchKernel(
MatrixCopyInputAndSetDiagKernel<Scalar>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, m, n, num_diags, max_diag_len,
lower_diag_index, upper_diag_index, input.data(), diag.data(),
output.data()));
lower_diag_index, upper_diag_index, left_align_superdiagonal,
left_align_subdiagonal, input.data(), diag.data(), output.data()));
}
}
};

View File

@ -279,29 +279,6 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
return Status::OK();
}
Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
int32* lower_diag_index, int32* upper_diag_index) {
// This function assumes that the shape of diag_index_tensor is fully defined.
if (diag_index_tensor->dims() == 0) {
*lower_diag_index = diag_index_tensor->scalar<int32>()();
*upper_diag_index = *lower_diag_index;
} else {
int32 num_elements = diag_index_tensor->dim_size(0);
if (num_elements == 1) {
*lower_diag_index = diag_index_tensor->vec<int32>()(0);
*upper_diag_index = *lower_diag_index;
} else if (num_elements == 2) {
*lower_diag_index = diag_index_tensor->vec<int32>()(0);
*upper_diag_index = diag_index_tensor->vec<int32>()(1);
} else {
return errors::InvalidArgument(
"diag_index must be a vector with one or two elements. It has ",
num_elements, " elements.");
}
}
return Status::OK();
}
} // namespace
REGISTER_OP("ParallelConcat")
@ -861,107 +838,20 @@ REGISTER_OP("MatrixDiagV2")
.Input("padding_value: T")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
// Checks input ranks.
ShapeHandle input_shape, diag_index_shape, unused_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
.SetShapeFn(shape_inference::MatrixDiagV2Shape);
// Reads the diagonal indices.
const Tensor* diag_index_tensor = c->input_tensor(1);
if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
diag_index_tensor == nullptr) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
int32 lower_diag_index = 0;
int32 upper_diag_index = 0;
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
&upper_diag_index));
if (lower_diag_index > upper_diag_index) {
return errors::InvalidArgument(
"lower_diag_index is greater than upper_diag_index");
}
// Checks if the number of diagonals provided matches what we imply from
// lower_diag_index and upper_diag_index.
const int32 input_rank = c->Rank(input_shape);
if (lower_diag_index < upper_diag_index) {
const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
return errors::InvalidArgument(
"The number of rows of `diagonal` doesn't match the number of "
"diagonals implied from `d_lower` and `d_upper`.\n",
"num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
", d_upper = ", upper_diag_index, " ", input_rank, " ",
other_dim);
}
}
// Reads num_rows and num_cols.
const Tensor* num_rows_tensor = c->input_tensor(2);
const Tensor* num_cols_tensor = c->input_tensor(3);
int64 num_rows = -1;
int64 num_cols = -1;
if (num_rows_tensor != nullptr) {
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
}
if (num_cols_tensor != nullptr) {
TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
}
// Infers the missing num_rows or num_cols: If both are missing, assume
// output is square. Otherwise, use the smallest possible value. Also
// validates the provided values.
const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
if (num_rows == -1 && num_cols == -1) { // Special case.
num_rows = std::max(min_num_rows, min_num_cols);
num_cols = num_rows;
}
if (num_rows == -1) {
num_rows = min_num_rows;
} else if (num_rows < min_num_rows) {
return errors::InvalidArgument("num_rows is too small");
}
if (num_cols == -1) {
num_cols = min_num_cols;
} else if (num_cols < min_num_cols) {
return errors::InvalidArgument("num_cols is too small.");
}
// At least one of them must match the minimum length.
if (num_rows != min_num_rows && num_cols != min_num_cols) {
return errors::InvalidArgument(
"num_rows and num_cols are not consistent with lower_diag_index, "
"upper_diag_index, and the length of the given diagonals.\n",
"num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
}
// Sets output shape.
ShapeHandle output_shape;
const DimensionHandle output_row_dim = c->MakeDim(num_rows);
const DimensionHandle output_col_dim = c->MakeDim(num_cols);
if (lower_diag_index == upper_diag_index) {
TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
output_row_dim, &output_shape));
TF_RETURN_IF_ERROR(c->Concatenate(
output_shape, c->Vector(output_col_dim), &output_shape));
} else {
TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
output_row_dim, &output_shape));
TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
output_col_dim, &output_shape));
}
c->set_output(0, output_shape);
return Status::OK();
});
REGISTER_OP("MatrixDiagV3")
.Input("diagonal: T")
.Input("k: int32")
.Input("num_rows: int32")
.Input("num_cols: int32")
.Input("padding_value: T")
.Output("output: T")
.Attr("T: type")
.Attr(
"align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
"'RIGHT_LEFT'")
.SetShapeFn(shape_inference::MatrixDiagV2Shape);
// --------------------------------------------------------------------------
REGISTER_OP("MatrixSetDiag")
@ -995,84 +885,25 @@ REGISTER_OP("MatrixSetDiag")
c->set_output(0, output);
return Status::OK();
});
REGISTER_OP("MatrixSetDiagV2")
.Input("input: T")
.Input("diagonal: T")
.Input("k: int32")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input_shape, diag_shape, diag_index_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
.SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
int32 lower_diag_index = 0;
int32 upper_diag_index = 0;
bool diag_index_known = false;
const Tensor* diag_index_tensor = c->input_tensor(2);
if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
diag_index_known = true;
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor,
&lower_diag_index, &upper_diag_index));
if (lower_diag_index > upper_diag_index) {
return errors::InvalidArgument(
"lower_diag_index is greater than upper_diag_index");
}
}
// Do more checks when input rank is known.
if (c->RankKnown(input_shape)) {
int32 input_rank = c->Rank(input_shape);
// If diag_index is set, we know the exact rank of diagonal.
if (diag_index_known) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(1),
(lower_diag_index == upper_diag_index)
? input_rank - 1
: input_rank,
&diag_shape));
} else {
TF_RETURN_IF_ERROR(
c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
TF_RETURN_IF_ERROR(
c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
}
// Validates lower_diag_index and upper_diag_index.
const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
if (num_rows != InferenceContext::kUnknownDim &&
num_cols != InferenceContext::kUnknownDim) {
if (lower_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
return errors::InvalidArgument("lower_diag_index is out of bound.");
}
if (upper_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
return errors::InvalidArgument("upper_diag_index is out of bound.");
}
}
}
ShapeHandle output_shape = input_shape;
if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
// Try to infer parts of shape from diag.
ShapeHandle diag_prefix;
TF_RETURN_IF_ERROR(c->Subshape(
diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
&diag_prefix));
// The inner matrices can be rectangular, so we can't pinpoint their
// exact height and width by just lower_diag_index, upper_diag_index,
// and the longest length of given diagonals.
TF_RETURN_IF_ERROR(
c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
}
c->set_output(0, output_shape);
return Status::OK();
});
REGISTER_OP("MatrixSetDiagV3")
.Input("input: T")
.Input("diagonal: T")
.Input("k: int32")
.Output("output: T")
.Attr("T: type")
.Attr(
"align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
"'RIGHT_LEFT'")
.SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
// --------------------------------------------------------------------------
REGISTER_OP("MatrixDiagPart")
@ -1105,58 +936,18 @@ REGISTER_OP("MatrixDiagPartV2")
.Input("padding_value: T")
.Output("diagonal: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input_shape, diag_index_shape, unused_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
.SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
const Tensor* diag_index_tensor = c->input_tensor(1);
if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
diag_index_tensor == nullptr) {
c->set_output(0, c->UnknownShape());
return Status::OK();
}
int32 lower_diag_index = 0;
int32 upper_diag_index = 0;
TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
&upper_diag_index));
if (lower_diag_index > upper_diag_index) {
return errors::InvalidArgument(
"lower_diag_index is greater than upper_diag_index");
}
// Validates lower_diag_index and upper_diag_index.
const int32 input_rank = c->Rank(input_shape);
const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
if (num_rows != InferenceContext::kUnknownDim &&
num_cols != InferenceContext::kUnknownDim) {
if (lower_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
return errors::InvalidArgument("lower_diag_index is out of bound.");
}
if (upper_diag_index != 0 && // For when num_rows or num_cols == 0.
(-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
return errors::InvalidArgument("upper_diag_index is out of bound.");
}
}
const int32 max_diag_len =
std::min(num_rows + std::min(upper_diag_index, 0),
num_cols - std::max(lower_diag_index, 0));
std::vector<DimensionHandle> dims;
dims.reserve(input_rank - 2);
for (int i = 0; i < input_rank - 2; ++i) {
dims.push_back(c->Dim(input_shape, i));
}
if (lower_diag_index < upper_diag_index) {
dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
}
dims.push_back(c->MakeDim(max_diag_len));
c->set_output(0, c->MakeShape(dims));
return Status::OK();
});
REGISTER_OP("MatrixDiagPartV3")
.Input("input: T")
.Input("k: int32")
.Input("padding_value: T")
.Output("diagonal: T")
.Attr("T: type")
.Attr(
"align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
"'RIGHT_LEFT'")
.SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
// --------------------------------------------------------------------------
REGISTER_OP("MatrixBandPart")

View File

@ -180,8 +180,8 @@ cc_library(
name = "graph_transformations",
srcs = [
"graph_transformations/convert_expanddims_to_reshape.cc",
"graph_transformations/convert_matrix_diag_v2_to_v1.cc",
"graph_transformations/convert_matrix_set_diag_v2_to_v1.cc",
"graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc",
"graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc",
"graph_transformations/convert_pure_conv_to_depthwise.cc",
"graph_transformations/convert_reorder_axes.cc",
"graph_transformations/convert_squeeze_to_reshape.cc",

View File

@ -20,13 +20,16 @@ limitations under the License.
namespace toco {
::tensorflow::Status ConvertMatrixDiagV2ToV1::Run(Model* model,
std::size_t op_index,
bool* modified) {
// V3 is only different from V2 because it has an extra attribute (align).
// This attribute doesn't affect V1 so we don't have to keep track of it here.
::tensorflow::Status ConvertMatrixDiagV2OrV3ToV1::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* op = it->get();
if (op->type != OperatorType::kMatrixDiagV2) {
if (op->type != OperatorType::kMatrixDiagV2 &&
op->type != OperatorType::kMatrixDiagV3) {
return ::tensorflow::Status::OK();
}

View File

@ -26,13 +26,16 @@ limitations under the License.
namespace toco {
::tensorflow::Status ConvertMatrixSetDiagV2ToV1::Run(Model* model,
std::size_t op_index,
bool* modified) {
// V3 is only different from V2 because it has an extra attribute (align).
// This attribute doesn't affect V1 so we don't have to keep track of it here.
::tensorflow::Status ConvertMatrixSetDiagV2OrV3ToV1::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
auto it = model->operators.begin() + op_index;
const auto* op = it->get();
if (op->type != OperatorType::kMatrixSetDiagV2) {
if (op->type != OperatorType::kMatrixSetDiagV2 &&
op->type != OperatorType::kMatrixSetDiagV3) {
return ::tensorflow::Status::OK();
}

View File

@ -123,8 +123,8 @@ inline void RunGraphTransformations(
// List of all graph transformations
DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2ToV1)
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1)
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2OrV3ToV1)
DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2OrV3ToV1)
DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape)

View File

@ -2434,6 +2434,14 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
// MatrixDiagV2 operators are converted to MatrixDiag, after which their
// shapes are propagated.
break;
case OperatorType::kMatrixDiagV3:
// MatrixDiagV3 operators are converted to MatrixDiag, after which their
// shapes are propagated.
break;
case OperatorType::kMatrixSetDiagV3:
// MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
// their shapes are propagated.
break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);

View File

@ -2560,8 +2560,16 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"MatMul", ConvertMatMulOperator},
{"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
{"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
// `MatrixDiagV3` has an `align` attribute. However, Toco only converts
// `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols,
// padding_value` inputs. In this case, `align` can be ignored.
{"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>},
{"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
{"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
// `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts
// `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this
// case, `align` can be ignored.
{"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},

View File

@ -175,7 +175,9 @@ enum class OperatorType : uint8 {
kMatrixDiag,
kMatrixSetDiag,
kMatrixDiagV2,
kMatrixSetDiagV2
kMatrixSetDiagV2,
kMatrixDiagV3,
kMatrixSetDiagV3
};
// Helper to deal with TensorFlow arrays using a different ordering of
@ -2122,12 +2124,24 @@ struct MatrixDiagOperator : Operator {
// Matrix Diag Operator V2:
// Construct a batched diagonal tensor with given batched diagonal values.
// Not fully supported, constains 4 extra inputs compared to MatrixDiag, support
// default parameters settings which performs the same as MatrixDiag
// Not fully supported, contains 4 extra inputs compared to MatrixDiag. Behave
// like MatrixDiag when default parameters are used.
struct MatrixDiagV2Operator : Operator {
MatrixDiagV2Operator() : Operator(OperatorType::kMatrixDiagV2) {}
};
// Matrix Diag Operator V3:
// Construct a batched diagonal tensor with given batched diagonal values.
// Not fully supported, contains 5 extra inputs compared to MatrixDiag. Behave
// like MatrixDiag when default parameters are used.
// V3 is only different from V2 because it has an extra attribute (align) which
// controls the alignment of diagonals in the band matrix (compact) format.
// The alignment in V2 contradicts with the default alignment in V3 so V2 is
// skipped. (It has never been, and should never be, exposed in the public API.)
struct MatrixDiagV3Operator : Operator {
MatrixDiagV3Operator() : Operator(OperatorType::kMatrixDiagV3) {}
};
// Matrix Set Diag Operator:
// Construct a batched diagonal tensor with given input and diagonal values.
// Input is a rank (k+1) tensor of values.
@ -2140,12 +2154,24 @@ struct MatrixSetDiagOperator : Operator {
// Matrix Set Diag Operator V2:
// Construct a batched diagonal tensor with given input and diagonal values.
// Not fully supported, constains 1 extra inputs compared to MatrixSetDiag,
// support default parameters settings which performs the same as MatrixSetDiag
// Not fully supported, contains 1 extra inputs compared to MatrixSetDiag.
// Behave like MatrixSetDiag when default parameters are used.
struct MatrixSetDiagV2Operator : Operator {
MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {}
};
// Matrix Set Diag Operator V3:
// Construct a batched diagonal tensor with given input and diagonal values.
// Not fully supported, contains 2 extra inputs compared to MatrixSetDiag.
// Behave like MatrixSetDiag when default parameters are used.
// V3 is only different from V2 because it has an extra attribute (align) which
// controls the alignment of diagonals in the band matrix (compact) format.
// The alignment in V2 contradicts with the default alignment in V3 so V2 is
// skipped. (It has never been, and should never be, exposed in the public API.)
struct MatrixSetDiagV3Operator : Operator {
MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {}
};
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are

View File

@ -54,8 +54,8 @@ void MakeGeneralGraphTransformationsSet(
GraphTransformationsSet* transformations) {
CHECK(transformations->empty());
transformations->Add(new ConvertExpandDimsToReshape);
transformations->Add(new ConvertMatrixDiagV2ToV1);
transformations->Add(new ConvertMatrixSetDiagV2ToV1);
transformations->Add(new ConvertMatrixDiagV2OrV3ToV1);
transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1);
transformations->Add(new ConvertSqueezeToReshape);
transformations->Add(new ConvertTrivialAddNToAdd);
transformations->Add(new ConvertTrivialPackToReshape);

View File

@ -449,6 +449,8 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE

File diff suppressed because it is too large Load Diff

View File

@ -341,6 +341,12 @@ def _MatrixDiagV2Grad(op, grad):
grad, k=op.inputs[1]), None, None, None, None
@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
return array_ops.matrix_diag_part(
grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None
@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
matrix_shape = op.inputs[0].get_shape()[-2:]
@ -362,8 +368,25 @@ def _MatrixDiagPartV2Grad(op, grad):
num_cols=matrix_shape[1]), None, None
else:
return array_ops.matrix_set_diag(
array_ops.zeros_like(op.inputs[0]), grad,
k=op.inputs[1]), None, None
array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None
@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
"""Gradient for MatrixDiagPartV3."""
matrix_shape = op.inputs[0].get_shape()[-2:]
align = op.get_attr("align")
if matrix_shape.is_fully_defined():
return array_ops.matrix_diag(
grad,
k=op.inputs[1],
num_rows=matrix_shape[0],
num_cols=matrix_shape[1],
align=align), None, None
else:
return array_ops.matrix_set_diag(
array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
align=align), None, None
@ops.RegisterGradient("MatrixSetDiag")
@ -392,7 +415,7 @@ def _MatrixSetDiagGrad(op, grad):
@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
"""Gradient for MatrixSetDiag."""
"""Gradient for MatrixSetDiagV2."""
diag_shape = op.inputs[1].get_shape()
if not diag_shape.is_fully_defined():
# Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
@ -426,6 +449,46 @@ def _MatrixSetDiagGradV2(op, grad):
return (grad_input, grad_diag, None)
@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
"""Gradient for MatrixSetDiagV3."""
diag_shape = op.inputs[1].get_shape()
align = op.get_attr("align")
if not diag_shape.is_fully_defined():
# Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
grad_shape = array_ops.shape(grad)
batch_shape = grad_shape[:-2]
matrix_shape = grad_shape[-2:]
diag_index = array_ops.reshape(op.inputs[2], [-1]) # Converts to vector.
d_lower = diag_index[0]
d_upper = diag_index[-1] # Works both when len(diag_index) is 1 and 2.
y_offset = control_flow_ops.cond(
math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
x_offset = control_flow_ops.cond(
math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)
max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
matrix_shape[1] + x_offset)
# pylint: disable=g-long-lambda
# pyformat: disable
postfix = control_flow_ops.cond(
math_ops.equal(d_lower, d_upper),
lambda: ops.convert_to_tensor([max_diag_len]),
lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
max_diag_len]))
# pyformat: enable
# pylint: enable=g-long-lambda
diag_shape = array_ops.concat([batch_shape, postfix], 0)
grad_input = array_ops.matrix_set_diag(
grad,
array_ops.zeros(diag_shape, dtype=grad.dtype),
k=op.inputs[2],
align=align)
grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
return (grad_input, grad_diag, None)
@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
num_lower = op.inputs[1]

View File

@ -54,6 +54,13 @@ tf_export("newaxis").export_constant(__name__, "newaxis")
# existing 'slice' for later use in this module.
_BaseSlice = slice
# LINT.IfChange
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
# LINT.ThenChange(
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
# //tensorflow/python/kernel_tests/diag_op_test.py,
# //tensorflow/python/ops/parallel_for/array_test.py
# )
@tf_export("reshape", v1=["reshape", "manip.reshape"])
def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name
@ -2043,7 +2050,8 @@ def matrix_diag(diagonal,
k=0,
num_rows=-1,
num_cols=-1,
padding_value=0):
padding_value=0,
align="RIGHT_LEFT"):
"""Returns a batched diagonal tensor with given batched diagonal values.
Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
@ -2078,7 +2086,17 @@ def matrix_diag(diagonal,
padding_value ; otherwise
```
where `d = n - m`, `diag_index = k[1] - d`, and
`index_in_diag = n - max(d, 0)`.
`index_in_diag = n - max(d, 0) + offset`.
`offset` is zero except when the alignment of the diagonal is to the right.
```
offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
and `d >= 0`) or
(`align` in {LEFT_RIGHT, RIGHT_RIGHT}
and `d <= 0`)
0 ; otherwise
```
where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
For example:
@ -2108,17 +2126,34 @@ def matrix_diag(diagonal,
[0, 0, 0, 6],
[0, 0, 0, 0]]]
# A band of diagonals.
diagonals = np.array([[[1, 2, 3], # Input shape: (2, 2, 3)
[4, 5, 0]],
[[6, 7, 9],
[9, 1, 0]]])
tf.matrix_diag(diagonals, k = (-1, 0))
==> [[[1, 0, 0], # Output shape: (2, 3, 3)
[4, 2, 0],
# A tridiagonal band (per batch).
diagonals = np.array([[[8, 9, 0], # Input shape: (2, 2, 3)
[1, 2, 3],
[0, 4, 5]],
[[2, 3, 0],
[6, 7, 9],
[0, 9, 1]]])
tf.matrix_diag(diagonals, k = (-1, 1))
==> [[[1, 8, 0], # Output shape: (2, 3, 3)
[4, 2, 9],
[0, 5, 3]],
[[6, 0, 0],
[9, 7, 0],
[[6, 2, 0],
[9, 7, 3],
[0, 1, 9]]]
# RIGHT_LEFT alignment.
diagonals = np.array([[[0, 8, 9], # Input shape: (2, 2, 3)
[1, 2, 3],
[4, 5, 0]],
[[0, 2, 3],
[6, 7, 9],
[9, 1, 0]]])
tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT")
==> [[[1, 8, 0], # Output shape: (2, 3, 3)
[4, 2, 9],
[0, 5, 3]],
[[6, 2, 0],
[9, 7, 3],
[0, 1, 9]]]
# Rectangular matrix.
@ -2150,27 +2185,34 @@ def matrix_diag(diagonal,
size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
padding_value: The value to fill the area outside the specified diagonal
band with. Default is 0.
align: Some diagonals are shorter than `max_diag_len` and need to be padded.
`align` is a string specifying how superdiagonals and subdiagonals should
be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
(default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
aligns superdiagonals to the right (left-pads the row) and subdiagonals to
the left (right-pads the row). It is the packing format LAPACK uses.
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
Returns:
A Tensor. Has the same type as `diagonal`.
"""
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
# Special case to sidestep the tf.constant conversion error:
# TypeError: Expected bool, got 0 of type 'int' instead.
if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
padding_value = bool(padding_value)
return gen_array_ops.matrix_diag_v2(
return gen_array_ops.matrix_diag_v3(
diagonal=diagonal,
k=k,
num_rows=num_rows,
num_cols=num_cols,
padding_value=padding_value,
align=align,
name=name)
# Call v1 to maintain forward compatibility.
# (We skip v2 because its alignment conflicts with v3's default alignment.)
return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
@ -2181,7 +2223,8 @@ def matrix_diag_part(
input, # pylint:disable=redefined-builtin
name="diag_part",
k=0,
padding_value=0):
padding_value=0,
align="RIGHT_LEFT"):
"""Returns the batched diagonal part of a batched tensor.
Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
@ -2211,7 +2254,17 @@ def matrix_diag_part(
= input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
padding_value ; otherwise.
```
where `d = k[1] - m`, `y = max(-d, 0)`, and `x = max(d, 0)`.
where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
`offset` is zero except when the alignment of the diagonal is to the right.
```
offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
and `d >= 0`) or
(`align` in {LEFT_RIGHT, RIGHT_RIGHT}
and `d <= 0`)
0 ; otherwise
```
where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
The input must be at least a matrix.
@ -2234,16 +2287,36 @@ def matrix_diag_part(
==> [[2, 7, 6], # Output shape: (2, 3)
[4, 3, 8]]
# A tridiagonal band from each batch.
tf.matrix_diag_part(input, k = (-1, 1))
==> [[[2, 7, 6], # Output shape: (2, 3, 3)
# A band from each batch.
tf.matrix_diag_part(input, k = (-1, 2))
==> [[[3, 8, 0], # Output shape: (2, 4, 3)
[2, 7, 6],
[1, 6, 7],
[0, 5, 8]],
[[3, 4, 0],
[4, 3, 8],
[5, 2, 7],
[0, 1, 6]]]
# RIGHT_LEFT alignment.
tf.matrix_diag_part(input, k = (-1, 2), align="RIGHT_LEFT")
==> [[[0, 3, 8], # Output shape: (2, 4, 3)
[2, 7, 6],
[1, 6, 7],
[5, 8, 0]],
[[4, 3, 8],
[[0, 3, 4],
[4, 3, 8],
[5, 2, 7],
[1, 6, 0]]]
# Padding value = 9
# max_diag_len can be shorter than the main diagonal.
tf.matrix_diag_part(input, k = (-2, -1))
==> [[[5, 8],
[0, 9]],
[[1, 6],
[0, 5]]]
# padding_value = 9
tf.matrix_diag_part(input, k = (1, 3), padding_value = 9)
==> [[[4, 9, 9], # Output shape: (2, 3, 3)
[3, 8, 9],
@ -2251,6 +2324,7 @@ def matrix_diag_part(
[[2, 9, 9],
[3, 4, 9],
[4, 3, 8]]]
```
Args:
@ -2262,23 +2336,28 @@ def matrix_diag_part(
and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
padding_value: The value to fill the area outside the specified diagonal
band with. Default is 0.
align: Some diagonals are shorter than `max_diag_len` and need to be padded.
`align` is a string specifying how superdiagonals and subdiagonals should
be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
(default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
aligns superdiagonals to the right (left-pads the row) and subdiagonals to
the left (right-pads the row). It is the packing format LAPACK uses.
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
Returns:
A Tensor containing diagonals of `input`. Has the same type as `input`.
"""
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
# Special case to sidestep the tf.constant conversion error:
# TypeError: Expected bool, got 0 of type 'int' instead.
if hasattr(input, "dtype") and input.dtype == "bool":
padding_value = bool(padding_value)
return gen_array_ops.matrix_diag_part_v2(
input=input, k=k, padding_value=padding_value, name=name)
return gen_array_ops.matrix_diag_part_v3(
input=input, k=k, padding_value=padding_value, align=align, name=name)
# Call v1 to maintain forward compatibility.
# (We skip v2 because its alignment conflicts with v3's default alignment.)
return gen_array_ops.matrix_diag_part(input=input, name=name)
@ -2288,7 +2367,8 @@ def matrix_set_diag(
input, # pylint:disable=redefined-builtin
diagonal,
name="set_diag",
k=0):
k=0,
align="RIGHT_LEFT"):
"""Returns a batched matrix tensor with new batched diagonal values.
Given `input` and `diagonal`, this operation returns a tensor with the
@ -2319,7 +2399,17 @@ def matrix_set_diag(
input[i, j, ..., l, m, n] ; otherwise
```
where `d = n - m`, `diag_index = k[1] - d`, and
`index_in_diag = n - max(d, 0)`.
`index_in_diag = n - max(d, 0) + offset`.
`offset` is zero except when the alignment of the diagonal is to the right.
```
offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
and `d >= 0`) or
(`align` in {LEFT_RIGHT, RIGHT_RIGHT}
and `d <= 0`)
0 ; otherwise
```
where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
For example:
@ -2333,15 +2423,16 @@ def matrix_set_diag(
[7, 7, 7, 7]]])
diagonal = np.array([[1, 2, 3], # Diagonal shape: (2, 3)
[4, 5, 6]])
tf.matrix_set_diag(diagonal) ==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
[7, 2, 7, 7],
[7, 7, 3, 7]],
[[4, 7, 7, 7],
[7, 5, 7, 7],
[7, 7, 6, 7]]]
tf.matrix_set_diag(input, diagonal)
==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
[7, 2, 7, 7],
[7, 7, 3, 7]],
[[4, 7, 7, 7],
[7, 5, 7, 7],
[7, 7, 6, 7]]]
# A superdiagonal (per batch).
tf.matrix_set_diag(diagonal, k = 1)
tf.matrix_set_diag(input, diagonal, k = 1)
==> [[[7, 1, 7, 7], # Output shape: (2, 3, 4)
[7, 7, 2, 7],
[7, 7, 7, 3]],
@ -2350,17 +2441,38 @@ def matrix_set_diag(
[7, 7, 7, 6]]]
# A band of diagonals.
diagonals = np.array([[[1, 2, 3], # Diagonal shape: (2, 2, 3)
diagonals = np.array([[[9, 1, 0], # Diagonal shape: (2, 4, 3)
[6, 5, 8],
[1, 2, 3],
[0, 4, 5]],
[[1, 2, 0],
[5, 6, 4],
[6, 1, 2],
[0, 3, 4]]])
tf.matrix_set_diag(input, diagonals, k = (-1, 2))
==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4)
[4, 2, 5, 1],
[7, 5, 3, 8]],
[[6, 5, 1, 7],
[3, 1, 6, 2],
[7, 4, 2, 4]]]
# RIGHT_LEFT alignment.
diagonals = np.array([[[0, 9, 1], # Diagonal shape: (2, 4, 3)
[6, 5, 8],
[1, 2, 3],
[4, 5, 0]],
[[6, 1, 2],
[[0, 1, 2],
[5, 6, 4],
[6, 1, 2],
[3, 4, 0]]])
tf.matrix_set_diag(diagonals, k = (-1, 0))
==> [[[1, 7, 7, 7], # Output shape: (2, 3, 4)
[4, 2, 7, 7],
[0, 5, 3, 7]],
[[6, 7, 7, 7],
[3, 1, 7, 7],
[7, 4, 2, 7]]]
tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT")
==> [[[1, 6, 9, 7], # Output shape: (2, 3, 4)
[4, 2, 5, 1],
[7, 5, 3, 8]],
[[6, 5, 1, 7],
[3, 1, 6, 2],
[7, 4, 2, 4]]]
```
@ -2373,14 +2485,20 @@ def matrix_set_diag(
main diagonal, and negative value means subdiagonals. `k` can be a single
integer (for a single diagonal) or a pair of integers specifying the low
and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
align: Some diagonals are shorter than `max_diag_len` and need to be padded.
`align` is a string specifying how superdiagonals and subdiagonals should
be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
(default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
aligns superdiagonals to the right (left-pads the row) and subdiagonals to
the left (right-pads the row). It is the packing format LAPACK uses.
cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
"""
# LINT.IfChange
if compat.forward_compatible(2019, 11, 30):
# LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
return gen_array_ops.matrix_set_diag_v2(
input=input, diagonal=diagonal, k=k, name=name)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
return gen_array_ops.matrix_set_diag_v3(
input=input, diagonal=diagonal, k=k, align=align, name=name)
# Call v1 to maintain forward compatibility.
# (We skip v2 because its alignment conflicts with v3's default alignment.)
return gen_array_ops.matrix_set_diag(
input=input, diagonal=diagonal, name=name)

View File

@ -1111,10 +1111,17 @@ def _cholesky(input, name=None): # pylint:disable=redefined-builtin
# The signature has to match with the one in python/op/array_ops.py,
# so we have k and padding_value even though we don't use them here.
# so we have k, padding_value, and align even though we don't use them here.
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(input, name="diag_part", k=0, padding_value=0): # pylint:disable=redefined-builtin, unused-argument
def _diag_part(
input, # pylint:disable=redefined-builtin
name="diag_part",
k=0,
padding_value=0,
align="RIGHT_LEFT"):
return input.diag_part(name)
# pylint:enable=unused-argument
@dispatch.dispatch_for_types(linalg.det, LinearOperator)

View File

@ -33,6 +33,15 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.platform import test
# LINT.IfChange
matrix_diag_v3_forward_compat_date = (2019, 12, 6)
# LINT.ThenChange(
# //tensorflow/compiler/tests/matrix_diag_ops_test.py,
# //tensorflow/python/kernel_tests/diag_op_test.py,
# //tensorflow/python/ops/array_ops.py
# )
@test_util.run_all_in_graph_and_eager_modes
class ArrayTest(PForTestCase):
@ -336,8 +345,9 @@ class ArrayTest(PForTestCase):
def loop_fn(i):
diagonal = array_ops.gather(x, i)
if compat.forward_compatible(2019, 11, 30):
return array_ops.matrix_diag(diagonal, k=(0, 1), num_rows=4, num_cols=5)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
return array_ops.matrix_diag(
diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
return array_ops.matrix_diag(diagonal)
self._test_loop_fn(loop_fn, 3)
@ -347,8 +357,9 @@ class ArrayTest(PForTestCase):
def loop_fn(i):
input = array_ops.gather(x, i) # pylint: disable=redefined-builtin
if compat.forward_compatible(2019, 11, 30):
return array_ops.matrix_diag_part(input, k=(-2, 0), padding_value=3)
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
return array_ops.matrix_diag_part(
input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
return array_ops.matrix_diag_part(input)
self._test_loop_fn(loop_fn, 3)
@ -356,8 +367,7 @@ class ArrayTest(PForTestCase):
def test_matrix_set_diag(self):
matrices = random_ops.random_uniform([3, 4, 4])
diags = random_ops.random_uniform([3, 4])
if compat.forward_compatible(2019, 11, 30):
bands = random_ops.random_uniform([3, 3, 4])
bands = random_ops.random_uniform([3, 3, 4])
def loop_fn(i):
matrix_i = array_ops.gather(matrices, i)
@ -365,16 +375,20 @@ class ArrayTest(PForTestCase):
results = [
array_ops.matrix_set_diag(matrix_i, diag_i),
array_ops.matrix_set_diag(matrices[0, ...], diag_i),
array_ops.matrix_set_diag(matrix_i, diags[0, ...])
array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
]
if compat.forward_compatible(2019, 11, 30):
if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
k = (-1, 1)
band_i = array_ops.gather(bands, i)
results.extend([
array_ops.matrix_set_diag(matrix_i, band_i, k=k),
array_ops.matrix_set_diag(matrices[0, ...], band_i, k=k),
array_ops.matrix_set_diag(matrix_i, bands[0, ...], k=k)
])
for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
results.extend([
array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align),
array_ops.matrix_set_diag(
matrices[0, ...], band_i, k=k, align=align),
array_ops.matrix_set_diag(
matrix_i, bands[0, ...], k=k, align=align)
])
return results
self._test_loop_fn(loop_fn, 3)

View File

@ -1902,43 +1902,55 @@ def _convert_matrix_set_diag(pfor_input):
return wrap(array_ops.matrix_set_diag(t, diag), True)
# Registrations for MatrixDiagV2, MatrixDiagPartv2, and MatrixSetDiagV2.
# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
# The input orders defined in the OpKernel and the actual python API are
# different (for compatibility with V1), so we cannot use _convert_identity.
# v2 is not compatible with v3 and is never exposed on the public API.
@RegisterPFor("MatrixDiagV2")
@RegisterPFor("MatrixDiagV3")
def _convert_matrix_diag_v2(pfor_input):
diagonal = pfor_input.stacked_input(0)
k = pfor_input.unstacked_input(1)
num_rows = pfor_input.unstacked_input(2)
num_cols = pfor_input.unstacked_input(3)
padding_value = pfor_input.unstacked_input(4)
return wrap(
array_ops.matrix_diag(
diagonal,
k=k,
num_rows=num_rows,
num_cols=num_cols,
padding_value=padding_value), True)
params = {
"diagonal": pfor_input.stacked_input(0),
"k": pfor_input.unstacked_input(1),
"num_rows": pfor_input.unstacked_input(2),
"num_cols": pfor_input.unstacked_input(3),
"padding_value": pfor_input.unstacked_input(4)
}
if pfor_input.op_type == "MatrixDiagV2":
return wrap(array_ops.matrix_diag_v2(**params), True)
params["align"] = pfor_input.get_attr("align")
return wrap(array_ops.matrix_diag(**params), True)
# See notes for MatrixDiagV2
@RegisterPFor("MatrixDiagPartV2")
@RegisterPFor("MatrixDiagPartV3")
def _convert_matrix_diag_part_v2(pfor_input):
input = pfor_input.stacked_input(0) # pylint:disable=redefined-builtin
k = pfor_input.unstacked_input(1)
padding_value = pfor_input.unstacked_input(2)
return wrap(
array_ops.matrix_diag_part(input, k=k, padding_value=padding_value), True)
params = {
"input": pfor_input.stacked_input(0),
"k": pfor_input.unstacked_input(1),
"padding_value": pfor_input.unstacked_input(2)
}
if pfor_input.op_type == "MatrixDiagPartV2":
return wrap(array_ops.matrix_diag_part_v2(**params), True)
params["align"] = pfor_input.get_attr("align")
return wrap(array_ops.matrix_diag_part(**params), True)
# See notes for MatrixDiagV2
@RegisterPFor("MatrixSetDiagV2")
@RegisterPFor("MatrixSetDiagV3")
def _convert_matrix_set_diag_v2(pfor_input):
pfor_input.stack_inputs([0, 1])
input = pfor_input.stacked_input(0) # pylint:disable=redefined-builtin
diagonal = pfor_input.stacked_input(1)
k = pfor_input.unstacked_input(2)
return wrap(array_ops.matrix_set_diag(input, diagonal, k=k), True)
params = {
"input": pfor_input.stacked_input(0),
"diagonal": pfor_input.stacked_input(1),
"k": pfor_input.unstacked_input(2)
}
if pfor_input.op_type == "MatrixSetDiagV2":
return wrap(array_ops.matrix_set_diag_v2(**params), True)
params["align"] = pfor_input.get_attr("align")
return wrap(array_ops.matrix_set_diag(**params), True)
@RegisterPFor("OneHot")

View File

@ -102,11 +102,11 @@ tf_module {
}
member_method {
name: "diag"
argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], "
argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "diag_part"
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "eigh"
@ -202,7 +202,7 @@ tf_module {
}
member_method {
name: "set_diag"
argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], "
argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "slogdet"

View File

@ -1646,11 +1646,11 @@ tf_module {
}
member_method {
name: "matrix_diag"
argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], "
argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "matrix_diag_part"
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "matrix_inverse"
@ -1658,7 +1658,7 @@ tf_module {
}
member_method {
name: "matrix_set_diag"
argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], "
argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "matrix_solve"

View File

@ -2172,10 +2172,18 @@ tf_module {
name: "MatrixDiagPartV2"
argspec: "args=[\'input\', \'k\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "MatrixDiagPartV3"
argspec: "args=[\'input\', \'k\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
}
member_method {
name: "MatrixDiagV2"
argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "MatrixDiagV3"
argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
}
member_method {
name: "MatrixExponential"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -2196,6 +2204,10 @@ tf_module {
name: "MatrixSetDiagV2"
argspec: "args=[\'input\', \'diagonal\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "MatrixSetDiagV3"
argspec: "args=[\'input\', \'diagonal\', \'k\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
}
member_method {
name: "MatrixSolve"
argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "

View File

@ -102,11 +102,11 @@ tf_module {
}
member_method {
name: "diag"
argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], "
argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "diag_part"
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "eig"
@ -210,7 +210,7 @@ tf_module {
}
member_method {
name: "set_diag"
argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], "
argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], "
}
member_method {
name: "slogdet"

View File

@ -2172,10 +2172,18 @@ tf_module {
name: "MatrixDiagPartV2"
argspec: "args=[\'input\', \'k\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "MatrixDiagPartV3"
argspec: "args=[\'input\', \'k\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
}
member_method {
name: "MatrixDiagV2"
argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "MatrixDiagV3"
argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
}
member_method {
name: "MatrixExponential"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -2196,6 +2204,10 @@ tf_module {
name: "MatrixSetDiagV2"
argspec: "args=[\'input\', \'diagonal\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "MatrixSetDiagV3"
argspec: "args=[\'input\', \'diagonal\', \'k\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
}
member_method {
name: "MatrixSolve"
argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "