Roll back to fix broken test.
PiperOrigin-RevId: 299338882 Change-Id: I3714947038653b3567cb562880b6ca577a3e74fa
This commit is contained in:
parent
5baff53e4e
commit
5cd4fc0d0c
@ -1911,6 +1911,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
|||||||
"LinSpace",
|
"LinSpace",
|
||||||
"ListDiff",
|
"ListDiff",
|
||||||
"LogMatrixDeterminant",
|
"LogMatrixDeterminant",
|
||||||
|
"LowerBound",
|
||||||
"MatMul",
|
"MatMul",
|
||||||
"MatrixBandPart",
|
"MatrixBandPart",
|
||||||
"MatrixDiag",
|
"MatrixDiag",
|
||||||
@ -2037,6 +2038,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
|||||||
"TensorScatterUpdate",
|
"TensorScatterUpdate",
|
||||||
"TridiagonalSolve",
|
"TridiagonalSolve",
|
||||||
"TruncatedNormal",
|
"TruncatedNormal",
|
||||||
|
"UpperBound",
|
||||||
"UnsortedSegmentMax",
|
"UnsortedSegmentMax",
|
||||||
"UnsortedSegmentMin",
|
"UnsortedSegmentMin",
|
||||||
"UnsortedSegmentProd",
|
"UnsortedSegmentProd",
|
||||||
|
@ -338,6 +338,21 @@ tf_xla_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_xla_py_test(
|
||||||
|
name = "searchsorted_op_test",
|
||||||
|
size = "small",
|
||||||
|
timeout = "moderate",
|
||||||
|
srcs = ["searchsorted_op_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":xla_test",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_xla_py_test(
|
tf_xla_py_test(
|
||||||
name = "svd_op_test",
|
name = "svd_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
75
tensorflow/compiler/tests/searchsorted_op_test.py
Normal file
75
tensorflow/compiler/tests/searchsorted_op_test.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test for XLA implementation of tf.searchsorted."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.compiler.tests import xla_test
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class SearchSorteddOpTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
def test1D(self):
|
||||||
|
# Test against NumPy implementation (which is 1D only).
|
||||||
|
np.random.seed(1)
|
||||||
|
for side in ['left', 'right']:
|
||||||
|
for dtype in [np.float32, np.int32]:
|
||||||
|
values = np.random.uniform(
|
||||||
|
low=-1000, high=1000, size=(10,)).astype(dtype)
|
||||||
|
unsorted = np.random.uniform(
|
||||||
|
low=-1000, high=1000, size=(20,)).astype(dtype)
|
||||||
|
|
||||||
|
sorted_sequence = np.sort(unsorted)
|
||||||
|
np_ans = np.searchsorted(sorted_sequence, values, side=side)
|
||||||
|
|
||||||
|
with self.session() as session:
|
||||||
|
with self.test_scope():
|
||||||
|
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
|
||||||
|
tf_out = session.run(tf_ans)
|
||||||
|
self.assertAllEqual(np_ans, tf_out)
|
||||||
|
|
||||||
|
def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans):
|
||||||
|
|
||||||
|
with self.session() as session:
|
||||||
|
with self.test_scope():
|
||||||
|
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
|
||||||
|
tf_out = session.run(tf_ans)
|
||||||
|
self.assertAllEqual(correct_ans, tf_out)
|
||||||
|
|
||||||
|
def testLowerBound2DExample(self):
|
||||||
|
# 2D TensorFlow documentation example.
|
||||||
|
for dtype in self.float_types | self.int_types:
|
||||||
|
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
|
||||||
|
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
|
||||||
|
correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype)
|
||||||
|
self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans)
|
||||||
|
|
||||||
|
def testUpperBound2DExample(self):
|
||||||
|
# 2D TensorFlow documentation example.
|
||||||
|
for dtype in self.float_types | self.int_types:
|
||||||
|
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
|
||||||
|
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
|
||||||
|
correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype)
|
||||||
|
self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -55,6 +55,7 @@ tf_kernel_library(
|
|||||||
"index_ops.cc",
|
"index_ops.cc",
|
||||||
"l2loss_op.cc",
|
"l2loss_op.cc",
|
||||||
"listdiff_op.cc",
|
"listdiff_op.cc",
|
||||||
|
"lower_upper_bound_ops.cc",
|
||||||
"lrn_ops.cc",
|
"lrn_ops.cc",
|
||||||
"matmul_op.cc",
|
"matmul_op.cc",
|
||||||
"matrix_band_part_op.cc",
|
"matrix_band_part_op.cc",
|
||||||
@ -149,6 +150,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/compiler/tf2xla/lib:util",
|
"//tensorflow/compiler/tf2xla/lib:util",
|
||||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||||
"//tensorflow/compiler/xla:array4d",
|
"//tensorflow/compiler/xla:array4d",
|
||||||
|
"//tensorflow/compiler/xla:comparison_util",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
116
tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc
Normal file
116
tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Builds a LowerBound or UpperBound op, the distinction lying in
|
||||||
|
// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp.
|
||||||
|
// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row
|
||||||
|
// are considered, and their sorted nature is not fully exploited.
|
||||||
|
void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype,
|
||||||
|
xla::ComparisonDirection comparison_direction) {
|
||||||
|
const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs");
|
||||||
|
const TensorShape values_shape = ctx->InputShape("values");
|
||||||
|
const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs");
|
||||||
|
const xla::XlaOp values = ctx->Input("values");
|
||||||
|
|
||||||
|
// We are assuming both inputs are 2D, which they will be given the current
|
||||||
|
// implementation of tf.searchsorted.
|
||||||
|
OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2,
|
||||||
|
errors::FailedPrecondition("sorted_inputs must be 2D"));
|
||||||
|
OP_REQUIRES(ctx, values_shape.dims() == 2,
|
||||||
|
errors::FailedPrecondition("values must be 2D"));
|
||||||
|
|
||||||
|
// Add a new inner dimension to values, to allow broadcasting along the inner
|
||||||
|
// dimension of sorted_sequence.
|
||||||
|
auto new_values_shape = values_shape;
|
||||||
|
new_values_shape.InsertDim(/* d */ 2, /* size */ 1);
|
||||||
|
auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes());
|
||||||
|
|
||||||
|
// Add a new penultimate dimension to sorted_inputs, to allow broadcasting of
|
||||||
|
// sorted_sequence entries for each value.
|
||||||
|
auto new_sorted_inputs_shape = sorted_inputs_shape;
|
||||||
|
new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1);
|
||||||
|
auto sorted_inputs_reshaped =
|
||||||
|
xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes());
|
||||||
|
|
||||||
|
// We are relying on broadcasting to compare each value against each entry in
|
||||||
|
// the associated sorted_inputs row.
|
||||||
|
// The reshapes above leave the tensors with equal rank of 3, so broadcast
|
||||||
|
// dimensions are not explicitly specified.
|
||||||
|
auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {},
|
||||||
|
comparison_direction);
|
||||||
|
|
||||||
|
const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype);
|
||||||
|
|
||||||
|
// Convert boolean comparison results to integers so we can sum them.
|
||||||
|
auto comparison_int =
|
||||||
|
XlaHelpers::ConvertElementType(comparison, accumulation_type);
|
||||||
|
|
||||||
|
// Sum the comparison results over the inner dimension to find the index for
|
||||||
|
// each value.
|
||||||
|
xla::XlaBuilder* builder = ctx->builder();
|
||||||
|
auto reduced =
|
||||||
|
xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type),
|
||||||
|
*ctx->GetOrCreateAdd(accumulation_type), {2});
|
||||||
|
|
||||||
|
ctx->SetOutput(0, reduced);
|
||||||
|
}
|
||||||
|
|
||||||
|
class LowerBoundOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataType out_dtype_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp);
|
||||||
|
|
||||||
|
class UpperBoundOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataType out_dtype_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user