Override reduce and reduce-window ops in MlirHloBuilder

Test these ops with Bucketize, LRN and LRNGrad ops

PiperOrigin-RevId: 316211359
Change-Id: Ief1d518e3c0baddc21d86bb03ae022056be7fb13
This commit is contained in:
Smit Hinsu 2020-06-12 18:04:00 -07:00 committed by TensorFlower Gardener
parent 3862d97f85
commit 6d73ffb374
8 changed files with 119 additions and 14 deletions

View File

@ -132,6 +132,52 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
// Reduce takes two set of variadic operands inputs and init_values.
// all_operands contains both of these so split operands into two parts.
int64_t num_args = all_operands.size() / 2;
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
loc_, GetValues(all_operands.first(num_args)),
GetValues(all_operands.subspan(num_args)),
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
llvm::SmallVector<int64, 8> padding;
for (const auto& dim : window.dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
auto padding_ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_.getIntegerType(64));
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
loc_, ty, GetValue(operand), GetValue(init_value),
GetI64ElementsAttr(sizes, &builder_),
GetI64ElementsAttr(strides, &builder_),
GetI64ElementsAttr(base_dilations, &builder_),
GetI64ElementsAttr(win_dilations, &builder_),
mlir::DenseIntElementsAttr::get(padding_ty, padding));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
return MakeXlaOp(op);
}
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(

View File

@ -124,6 +124,16 @@ class MlirHloBuilder : public XlaBuilder {
FftType fft_type,
absl::Span<const int64> fft_length) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) override;
StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
XlaOp init_value,
const XlaComputation& computation,
Window window) override;
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
StatusOr<XlaOp> TransposeInternal(

View File

@ -236,6 +236,21 @@ func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
return %1 : tensor<4x7xcomplex<f64>>
}
// CHECK-LABEL: bucketize
func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
// CHECK-NOT: tf.Bucketize
%0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
return %0 : tensor<2x5xi32>
}
// CHECK-LABEL: arg_min
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
// CHECK-NOT: ArgMin
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
TypeID::get<TF::AsinhOp>(),
TypeID::get<TF::AsinOp>(),
TypeID::get<TF::Atan2Op>(),
@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::BitwiseAndOp>(),
TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::BucketizeOp>(),
TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(),
@ -132,6 +135,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::IRFFTOp>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
TypeID::get<TF::LeakyReluOp>(),
TypeID::get<TF::LeftShiftOp>(),

View File

@ -185,6 +185,7 @@ tf_xla_py_test(
name = "argminmax_test",
size = "small",
srcs = ["argminmax_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -253,6 +254,7 @@ tf_xla_py_test(
name = "bucketize_op_test",
size = "small",
srcs = ["bucketize_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -806,6 +808,7 @@ tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
srcs = ["lrn_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
expected_out, sess.run(op,
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
@test_util.disable_mlir_bridge("Error handling")
def testInvalidBoundariesOrder(self):
with self.session() as sess:
p = array_ops.placeholder(dtypes.int32)

View File

@ -2040,8 +2040,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
@ -2060,6 +2058,17 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
Shape shape,
ShapeInference::InferReduceShape(
operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
return ReduceInternal(shape, all_operands, computation,
dimensions_to_reduce);
});
}
StatusOr<XlaOp> XlaBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : dimensions_to_reduce) {
@ -2067,7 +2076,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
}
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
});
}
@ -2110,28 +2118,35 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
TF_ASSIGN_OR_RETURN(auto window,
ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding,
/*lhs_dilation=*/base_dilations,
/*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
*operand_shape, *init_shape,
instr.window(), to_apply_shape));
*instr.mutable_shape() = shape.ToProto();
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
{operand, init_value});
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferReduceWindowShape(
*operand_shape, *init_shape, window, to_apply_shape));
return ReduceWindowInternal(shape, operand, init_value, computation,
std::move(window));
});
}
StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_window() = std::move(window);
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
{operand, init_value});
}
XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
float epsilon, int64 feature_index) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {

View File

@ -542,6 +542,11 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce);
virtual StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce);
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
const XlaComputation& computation);
@ -558,6 +563,10 @@ class XlaBuilder {
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
virtual StatusOr<XlaOp> ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window);
XlaOp CrossReplicaSum(XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups = {});