diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 08dc1b13db6..b36fe6ae5e9 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -1911,6 +1911,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() {
                                      "LinSpace",
                                      "ListDiff",
                                      "LogMatrixDeterminant",
+                                     "LowerBound",
                                      "MatMul",
                                      "MatrixBandPart",
                                      "MatrixDiag",
@@ -2037,6 +2038,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() {
                                      "TensorScatterUpdate",
                                      "TridiagonalSolve",
                                      "TruncatedNormal",
+                                     "UpperBound",
                                      "UnsortedSegmentMax",
                                      "UnsortedSegmentMin",
                                      "UnsortedSegmentProd",
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index f3ee4e38f31..77cd3dc074c 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -338,6 +338,22 @@ tf_xla_py_test(
     ],
 )
 
+tf_xla_py_test(
+    name = "searchsorted_op_test",
+    size = "small",
+    timeout = "moderate",
+    srcs = ["searchsorted_op_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
+    ],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_xla_py_test(
     name = "svd_op_test",
     size = "medium",
diff --git a/tensorflow/compiler/tests/searchsorted_op_test.py b/tensorflow/compiler/tests/searchsorted_op_test.py
new file mode 100644
index 00000000000..d77bd0902d3
--- /dev/null
+++ b/tensorflow/compiler/tests/searchsorted_op_test.py
@@ -0,0 +1,72 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for XLA implementation of tf.searchsorted."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SearchSortedOpTest(xla_test.XLATestCase):
+  """Compares XLA-compiled tf.searchsorted with known-good answers."""
+
+  def _runSearchSorted(self, sorted_sequence, values, side):
+    """Builds tf.searchsorted under the XLA test scope and evaluates it."""
+    with self.session() as session:
+      with self.test_scope():
+        tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
+      return session.run(tf_ans)
+
+  def test1D(self):
+    # Test against NumPy implementation (which is 1D only).
+    np.random.seed(1)
+    for side in ['left', 'right']:
+      for dtype in [np.float32, np.int32]:
+        values = np.random.uniform(
+            low=-1000, high=1000, size=(10,)).astype(dtype)
+        unsorted = np.random.uniform(
+            low=-1000, high=1000, size=(20,)).astype(dtype)
+
+        sorted_sequence = np.sort(unsorted)
+        np_ans = np.searchsorted(sorted_sequence, values, side=side)
+
+        tf_out = self._runSearchSorted(sorted_sequence, values, side)
+        self.assertAllEqual(np_ans, tf_out)
+
+  def _test2DExample(self, side, correct_ans):
+    """Checks the 2D documentation example for every float and int type."""
+    for dtype in self.float_types | self.int_types:
+      sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
+      values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
+      tf_out = self._runSearchSorted(sorted_sequence, values, side)
+      self.assertAllEqual(correct_ans, tf_out)
+
+  def testLowerBound2DExample(self):
+    # 2D TensorFlow documentation example.
+    self._test2DExample('left', [[1, 2, 2], [0, 1, 5]])
+
+  def testUpperBound2DExample(self):
+    # 2D TensorFlow documentation example.
+    self._test2DExample('right', [[1, 2, 4], [0, 2, 5]])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 8571c503299..5f1c2f28ba4 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -55,6 +55,7 @@ tf_kernel_library(
         "index_ops.cc",
         "l2loss_op.cc",
         "listdiff_op.cc",
+        "lower_upper_bound_ops.cc",
         "lrn_ops.cc",
         "matmul_op.cc",
         "matrix_band_part_op.cc",
@@ -149,6 +150,7 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:array4d",
+        "//tensorflow/compiler/xla:comparison_util",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc
new file mode 100644
index 00000000000..0eacf8812f1
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc
@@ -0,0 +1,132 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/comparison_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+// Builds a LowerBound or UpperBound op, the distinction lying in
+// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp.
+//
+// For every entry of values, the result is the number of entries in the
+// same-index row of sorted_inputs that the value is greater than (GT) or
+// greater than or equal to (GE).
+//
+// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row
+// are considered, and their sorted nature is not fully exploited.
+void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype,
+                            xla::ComparisonDirection comparison_direction) {
+  const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs");
+  const TensorShape values_shape = ctx->InputShape("values");
+  const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs");
+  const xla::XlaOp values = ctx->Input("values");
+
+  // We are assuming both inputs are 2D, which they will be given the current
+  // implementation of tf.searchsorted.
+  OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2,
+              errors::FailedPrecondition("sorted_inputs must be 2D"));
+  OP_REQUIRES(ctx, values_shape.dims() == 2,
+              errors::FailedPrecondition("values must be 2D"));
+
+  // Each row of values is searched in the sorted_inputs row with the same
+  // index, so the leading dimensions must agree. Without this check a
+  // leading dimension of 1 would be silently broadcast by xla::Compare.
+  OP_REQUIRES(
+      ctx, sorted_inputs_shape.dim_size(0) == values_shape.dim_size(0),
+      errors::InvalidArgument("Leading dim_size of both tensors must match."));
+
+  // Add a new inner dimension to values, to allow broadcasting along the inner
+  // dimension of sorted_inputs.
+  auto new_values_shape = values_shape;
+  new_values_shape.InsertDim(/* d */ 2, /* size */ 1);
+  auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes());
+
+  // Add a new penultimate dimension to sorted_inputs, to allow broadcasting of
+  // sorted_inputs entries for each value.
+  auto new_sorted_inputs_shape = sorted_inputs_shape;
+  new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1);
+  auto sorted_inputs_reshaped =
+      xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes());
+
+  // We are relying on broadcasting to compare each value against each entry in
+  // the associated sorted_inputs row.
+  // The reshapes above leave the tensors with equal rank of 3, so broadcast
+  // dimensions are not explicitly specified.
+  auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {},
+                                 comparison_direction);
+
+  const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype);
+
+  // Convert boolean comparison results to integers so we can sum them.
+  auto comparison_int =
+      XlaHelpers::ConvertElementType(comparison, accumulation_type);
+
+  // Sum the comparison results over the inner dimension to find the index for
+  // each value.
+  xla::XlaBuilder* builder = ctx->builder();
+  auto reduced =
+      xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type),
+                  *ctx->GetOrCreateAdd(accumulation_type), {2});
+
+  ctx->SetOutput(0, reduced);
+}
+
+// XLA kernel for the LowerBound op: for each value, counts the entries of the
+// same-index sorted_inputs row that the value is strictly greater than.
+class LowerBoundOp : public XlaOpKernel {
+ public:
+  explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt);
+  }
+
+ private:
+  DataType out_dtype_;
+};
+
+REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp);
+
+// XLA kernel for the UpperBound op: for each value, counts the entries of the
+// same-index sorted_inputs row that the value is greater than or equal to.
+class UpperBoundOp : public XlaOpKernel {
+ public:
+  explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+  }
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe);
+  }
+
+ private:
+  DataType out_dtype_;
+};
+
+REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp);
+
+}  // namespace
+}  // namespace tensorflow