Roll back to fix broken test.
PiperOrigin-RevId: 299338882 Change-Id: I3714947038653b3567cb562880b6ca577a3e74fa
This commit is contained in:
parent
5baff53e4e
commit
5cd4fc0d0c
@ -1911,6 +1911,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"LinSpace",
|
||||
"ListDiff",
|
||||
"LogMatrixDeterminant",
|
||||
"LowerBound",
|
||||
"MatMul",
|
||||
"MatrixBandPart",
|
||||
"MatrixDiag",
|
||||
@ -2037,6 +2038,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorScatterUpdate",
|
||||
"TridiagonalSolve",
|
||||
"TruncatedNormal",
|
||||
"UpperBound",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
|
@ -338,6 +338,21 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "searchsorted_op_test",
|
||||
size = "small",
|
||||
timeout = "moderate",
|
||||
srcs = ["searchsorted_op_test.py"],
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "svd_op_test",
|
||||
size = "medium",
|
||||
|
75
tensorflow/compiler/tests/searchsorted_op_test.py
Normal file
75
tensorflow/compiler/tests/searchsorted_op_test.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test for XLA implementation of tf.searchsorted."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SearchSorteddOpTest(xla_test.XLATestCase):
|
||||
|
||||
def test1D(self):
|
||||
# Test against NumPy implementation (which is 1D only).
|
||||
np.random.seed(1)
|
||||
for side in ['left', 'right']:
|
||||
for dtype in [np.float32, np.int32]:
|
||||
values = np.random.uniform(
|
||||
low=-1000, high=1000, size=(10,)).astype(dtype)
|
||||
unsorted = np.random.uniform(
|
||||
low=-1000, high=1000, size=(20,)).astype(dtype)
|
||||
|
||||
sorted_sequence = np.sort(unsorted)
|
||||
np_ans = np.searchsorted(sorted_sequence, values, side=side)
|
||||
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
|
||||
tf_out = session.run(tf_ans)
|
||||
self.assertAllEqual(np_ans, tf_out)
|
||||
|
||||
def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans):
|
||||
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
|
||||
tf_out = session.run(tf_ans)
|
||||
self.assertAllEqual(correct_ans, tf_out)
|
||||
|
||||
def testLowerBound2DExample(self):
|
||||
# 2D TensorFlow documentation example.
|
||||
for dtype in self.float_types | self.int_types:
|
||||
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
|
||||
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
|
||||
correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype)
|
||||
self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans)
|
||||
|
||||
def testUpperBound2DExample(self):
|
||||
# 2D TensorFlow documentation example.
|
||||
for dtype in self.float_types | self.int_types:
|
||||
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
|
||||
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
|
||||
correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype)
|
||||
self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -55,6 +55,7 @@ tf_kernel_library(
|
||||
"index_ops.cc",
|
||||
"l2loss_op.cc",
|
||||
"listdiff_op.cc",
|
||||
"lower_upper_bound_ops.cc",
|
||||
"lrn_ops.cc",
|
||||
"matmul_op.cc",
|
||||
"matrix_band_part_op.cc",
|
||||
@ -149,6 +150,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
116
tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc
Normal file
116
tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc
Normal file
@ -0,0 +1,116 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Builds a LowerBound or UpperBound op, the distinction lying in
|
||||
// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp.
|
||||
// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row
|
||||
// are considered, and their sorted nature is not fully exploited.
|
||||
void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype,
|
||||
xla::ComparisonDirection comparison_direction) {
|
||||
const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs");
|
||||
const TensorShape values_shape = ctx->InputShape("values");
|
||||
const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs");
|
||||
const xla::XlaOp values = ctx->Input("values");
|
||||
|
||||
// We are assuming both inputs are 2D, which they will be given the current
|
||||
// implementation of tf.searchsorted.
|
||||
OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2,
|
||||
errors::FailedPrecondition("sorted_inputs must be 2D"));
|
||||
OP_REQUIRES(ctx, values_shape.dims() == 2,
|
||||
errors::FailedPrecondition("values must be 2D"));
|
||||
|
||||
// Add a new inner dimension to values, to allow broadcasting along the inner
|
||||
// dimension of sorted_sequence.
|
||||
auto new_values_shape = values_shape;
|
||||
new_values_shape.InsertDim(/* d */ 2, /* size */ 1);
|
||||
auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes());
|
||||
|
||||
// Add a new penultimate dimension to sorted_inputs, to allow broadcasting of
|
||||
// sorted_sequence entries for each value.
|
||||
auto new_sorted_inputs_shape = sorted_inputs_shape;
|
||||
new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1);
|
||||
auto sorted_inputs_reshaped =
|
||||
xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes());
|
||||
|
||||
// We are relying on broadcasting to compare each value against each entry in
|
||||
// the associated sorted_inputs row.
|
||||
// The reshapes above leave the tensors with equal rank of 3, so broadcast
|
||||
// dimensions are not explicitly specified.
|
||||
auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {},
|
||||
comparison_direction);
|
||||
|
||||
const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype);
|
||||
|
||||
// Convert boolean comparison results to integers so we can sum them.
|
||||
auto comparison_int =
|
||||
XlaHelpers::ConvertElementType(comparison, accumulation_type);
|
||||
|
||||
// Sum the comparison results over the inner dimension to find the index for
|
||||
// each value.
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto reduced =
|
||||
xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type),
|
||||
*ctx->GetOrCreateAdd(accumulation_type), {2});
|
||||
|
||||
ctx->SetOutput(0, reduced);
|
||||
}
|
||||
|
||||
class LowerBoundOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType out_dtype_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp);
|
||||
|
||||
class UpperBoundOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType out_dtype_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user