Roll back to fix broken test.

PiperOrigin-RevId: 299338882
Change-Id: I3714947038653b3567cb562880b6ca577a3e74fa
This commit is contained in:
A. Unique TensorFlower 2020-03-06 06:04:22 -08:00 committed by TensorFlower Gardener
parent 5baff53e4e
commit 5cd4fc0d0c
5 changed files with 210 additions and 0 deletions

View File

@ -1911,6 +1911,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"LinSpace",
"ListDiff",
"LogMatrixDeterminant",
"LowerBound",
"MatMul",
"MatrixBandPart",
"MatrixDiag",
@ -2037,6 +2038,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorScatterUpdate",
"TridiagonalSolve",
"TruncatedNormal",
"UpperBound",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",

View File

@ -338,6 +338,21 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "searchsorted_op_test",
size = "small",
timeout = "moderate",
srcs = ["searchsorted_op_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "svd_op_test",
size = "medium",

View File

@ -0,0 +1,75 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for XLA implementation of tf.searchsorted."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class SearchSorteddOpTest(xla_test.XLATestCase):
def test1D(self):
# Test against NumPy implementation (which is 1D only).
np.random.seed(1)
for side in ['left', 'right']:
for dtype in [np.float32, np.int32]:
values = np.random.uniform(
low=-1000, high=1000, size=(10,)).astype(dtype)
unsorted = np.random.uniform(
low=-1000, high=1000, size=(20,)).astype(dtype)
sorted_sequence = np.sort(unsorted)
np_ans = np.searchsorted(sorted_sequence, values, side=side)
with self.session() as session:
with self.test_scope():
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
tf_out = session.run(tf_ans)
self.assertAllEqual(np_ans, tf_out)
def _test2DExample(self, dtype, side, sorted_sequence, values, correct_ans):
with self.session() as session:
with self.test_scope():
tf_ans = array_ops.searchsorted(sorted_sequence, values, side=side)
tf_out = session.run(tf_ans)
self.assertAllEqual(correct_ans, tf_out)
def testLowerBound2DExample(self):
# 2D TensorFlow documentation example.
for dtype in self.float_types | self.int_types:
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
correct_ans = np.array([[1, 2, 2], [0, 1, 5]], dtype)
self._test2DExample(dtype, 'left', sorted_sequence, values, correct_ans)
def testUpperBound2DExample(self):
# 2D TensorFlow documentation example.
for dtype in self.float_types | self.int_types:
sorted_sequence = np.array([[0, 3, 9, 9, 10], [1, 2, 3, 4, 5]], dtype)
values = np.array([[2, 4, 9], [0, 2, 6]], dtype)
correct_ans = np.array([[1, 2, 4], [0, 2, 5]], dtype)
self._test2DExample(dtype, 'right', sorted_sequence, values, correct_ans)
if __name__ == '__main__':
test.main()

View File

@ -55,6 +55,7 @@ tf_kernel_library(
"index_ops.cc",
"l2loss_op.cc",
"listdiff_op.cc",
"lower_upper_bound_ops.cc",
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
@ -149,6 +150,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",

View File

@ -0,0 +1,116 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
// Builds a LowerBound or UpperBound op, the distinction lying in
// comparison_direction: GT => LowerBoundOp, GE => UpperBoundOp.
// Note that this is an O(MN) algorithm: all entries in each sorted_inputs row
// are considered, and their sorted nature is not fully exploited.
void BuildLowerUpperBoundOp(XlaOpKernelContext* ctx, DataType out_dtype,
xla::ComparisonDirection comparison_direction) {
const TensorShape sorted_inputs_shape = ctx->InputShape("sorted_inputs");
const TensorShape values_shape = ctx->InputShape("values");
const xla::XlaOp sorted_inputs = ctx->Input("sorted_inputs");
const xla::XlaOp values = ctx->Input("values");
// We are assuming both inputs are 2D, which they will be given the current
// implementation of tf.searchsorted.
OP_REQUIRES(ctx, sorted_inputs_shape.dims() == 2,
errors::FailedPrecondition("sorted_inputs must be 2D"));
OP_REQUIRES(ctx, values_shape.dims() == 2,
errors::FailedPrecondition("values must be 2D"));
// Add a new inner dimension to values, to allow broadcasting along the inner
// dimension of sorted_sequence.
auto new_values_shape = values_shape;
new_values_shape.InsertDim(/* d */ 2, /* size */ 1);
auto values_reshaped = xla::Reshape(values, new_values_shape.dim_sizes());
// Add a new penultimate dimension to sorted_inputs, to allow broadcasting of
// sorted_sequence entries for each value.
auto new_sorted_inputs_shape = sorted_inputs_shape;
new_sorted_inputs_shape.InsertDim(/* d */ 1, /* size */ 1);
auto sorted_inputs_reshaped =
xla::Reshape(sorted_inputs, new_sorted_inputs_shape.dim_sizes());
// We are relying on broadcasting to compare each value against each entry in
// the associated sorted_inputs row.
// The reshapes above leave the tensors with equal rank of 3, so broadcast
// dimensions are not explicitly specified.
auto comparison = xla::Compare(values_reshaped, sorted_inputs_reshaped, {},
comparison_direction);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype);
// Convert boolean comparison results to integers so we can sum them.
auto comparison_int =
XlaHelpers::ConvertElementType(comparison, accumulation_type);
// Sum the comparison results over the inner dimension to find the index for
// each value.
xla::XlaBuilder* builder = ctx->builder();
auto reduced =
xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), {2});
ctx->SetOutput(0, reduced);
}
class LowerBoundOp : public XlaOpKernel {
public:
explicit LowerBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGt);
}
private:
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("LowerBound"), LowerBoundOp);
class UpperBoundOp : public XlaOpKernel {
public:
explicit UpperBoundOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
BuildLowerUpperBoundOp(ctx, out_dtype_, xla::ComparisonDirection::kGe);
}
private:
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("UpperBound"), UpperBoundOp);
} // namespace
} // namespace tensorflow