Override reduce and reduce-window ops in MlirHloBuilder

Test these overrides with the Bucketize, LRN, and LRNGrad ops

PiperOrigin-RevId: 316211359
Change-Id: Ief1d518e3c0baddc21d86bb03ae022056be7fb13
Smit Hinsu 2020-06-12 18:04:00 -07:00 committed by TensorFlower Gardener
parent 3862d97f85
commit 6d73ffb374
8 changed files with 119 additions and 14 deletions


@@ -132,6 +132,52 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
// Reduce takes two sets of variadic operands: inputs and init_values.
// all_operands contains both of these, so split it into two halves.
int64_t num_args = all_operands.size() / 2;
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
loc_, GetValues(all_operands.first(num_args)),
GetValues(all_operands.subspan(num_args)),
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}
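// Editor's aside (sketch, not part of this change): the packing convention the
// comment above relies on. XlaBuilder::Reduce appends the N reduce inputs
// first and the N init values second, so this callee can recover each half by
// splitting `all_operands` at the midpoint. Self-contained illustration, with
// a plain int standing in for the opaque XlaOp handle:
//
//   #include <vector>
//   using XlaOp = int;  // placeholder for the real handle type
//
//   std::vector<XlaOp> PackReduceOperands(
//       const std::vector<XlaOp>& operands,
//       const std::vector<XlaOp>& init_values) {
//     std::vector<XlaOp> all(operands);
//     all.insert(all.end(), init_values.begin(), init_values.end());
//     return all;  // size() == 2 * num_args; first half holds the inputs
//   }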
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
llvm::SmallVector<int64, 8> padding;
for (const auto& dim : window.dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
auto padding_ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_.getIntegerType(64));
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
loc_, ty, GetValue(operand), GetValue(init_value),
GetI64ElementsAttr(sizes, &builder_),
GetI64ElementsAttr(strides, &builder_),
GetI64ElementsAttr(base_dilations, &builder_),
GetI64ElementsAttr(win_dilations, &builder_),
mlir::DenseIntElementsAttr::get(padding_ty, padding));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
return MakeXlaOp(op);
}
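// Editor's aside (worked example with illustrative numbers, not from this
// change): for a 2D reduce-window with 3x3 windows, stride 2, no dilation,
// and one element of low/high padding on both spatial dimensions, the loop
// above yields
//   sizes = [3, 3], strides = [2, 2],
//   base_dilations = [1, 1], win_dilations = [1, 1],
//   padding (flattened) = [1, 1, 1, 1],
// so padding_ty is tensor<2x2xi64> and the attribute reads [[1, 1], [1, 1]],
// where row i holds (padding_low, padding_high) of dimension i.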
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(


@@ -124,6 +124,16 @@ class MlirHloBuilder : public XlaBuilder {
FftType fft_type,
absl::Span<const int64> fft_length) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) override;
StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
XlaOp init_value,
const XlaComputation& computation,
Window window) override;
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
StatusOr<XlaOp> TransposeInternal(


@@ -236,6 +236,21 @@ func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
return %1 : tensor<4x7xcomplex<f64>>
}
// CHECK-LABEL: bucketize
func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
// CHECK-NOT: tf.Bucketize
%0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
return %0 : tensor<2x5xi32>
}
// CHECK-LABEL: arg_min
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
// CHECK-NOT: ArgMin
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}
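Aside on the bucketize test above: per the tf.Bucketize op definition, an
element's bucket index is the number of boundaries less than or equal to it.
A minimal reference implementation (editor's illustration, not part of this
change; the helper name is hypothetical):

#include <algorithm>
#include <vector>

// Bucket index of x = count of boundaries <= x, i.e. the std::upper_bound
// insertion point into the sorted boundary list.
int Bucketize(float x, const std::vector<float>& boundaries) {
  return static_cast<int>(std::upper_bound(boundaries.begin(),
                                           boundaries.end(), x) -
                          boundaries.begin());
}

// With boundaries {0, 3, 8, 11} as in the test: -5 -> 0, 5 -> 2, 12 -> 4.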


@@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
TypeID::get<TF::AsinhOp>(),
TypeID::get<TF::AsinOp>(),
TypeID::get<TF::Atan2Op>(),
@@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::BitwiseAndOp>(),
TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::BucketizeOp>(),
TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(),
@@ -132,6 +135,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::IRFFTOp>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
TypeID::get<TF::LeakyReluOp>(),
TypeID::get<TF::LeftShiftOp>(),
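Aside: IsOpWhitelisted gates which TF ops the MLIR bridge may compile through
the tf2xla fallback; this change adds ArgMax/ArgMin, Bucketize, LRN, and
LRNGrad to that list. A condensed sketch of the lookup pattern (editor's
illustration; std::type_index stands in for MLIR TypeIDs and the op types
here are hypothetical):

#include <typeindex>
#include <typeinfo>
#include <unordered_set>

struct OpBase { virtual ~OpBase() = default; };
struct BucketizeOp : OpBase {};
struct LRNOp : OpBase {};

// Membership of the op's dynamic type in a static set decides eligibility.
bool IsOpAllowed(const OpBase& op) {
  static const std::unordered_set<std::type_index> kAllowed = {
      typeid(BucketizeOp), typeid(LRNOp)};
  return kAllowed.count(typeid(op)) > 0;
}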


@@ -185,6 +185,7 @@ tf_xla_py_test(
name = "argminmax_test",
size = "small",
srcs = ["argminmax_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -253,6 +254,7 @@ tf_xla_py_test(
name = "bucketize_op_test",
size = "small",
srcs = ["bucketize_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@@ -806,6 +808,7 @@ tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
srcs = ["lrn_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip


@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
expected_out, sess.run(op,
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
@test_util.disable_mlir_bridge("Error handling")
def testInvalidBoundariesOrder(self):
with self.session() as sess:
p = array_ops.placeholder(dtypes.int32)


@@ -2040,8 +2040,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
@@ -2060,6 +2058,17 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
Shape shape,
ShapeInference::InferReduceShape(
operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
return ReduceInternal(shape, all_operands, computation,
dimensions_to_reduce);
});
}
StatusOr<XlaOp> XlaBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : dimensions_to_reduce) {
@@ -2067,7 +2076,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
}
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
});
}
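The xla_builder.cc change above is a template-method split: Reduce keeps the
shared shape inference and delegates emission to the virtual ReduceInternal
hook, which MlirHloBuilder (first file in this commit) overrides to build MLIR
ops instead of HLO protos. A self-contained sketch of the pattern, with
hypothetical names and a toy stand-in for shape inference:

#include <iostream>

class BuilderBase {
 public:
  int Reduce(int operand) {
    int shape = operand * 2;       // stands in for shared shape inference
    return ReduceInternal(shape);  // virtual emission hook
  }

 protected:
  virtual int ReduceInternal(int shape) { return shape; }  // HLO-proto path
};

class MlirBuilder : public BuilderBase {
 protected:
  int ReduceInternal(int shape) override { return -shape; }  // MLIR path
};

int main() {
  MlirBuilder b;
  std::cout << b.Reduce(3) << "\n";  // prints -6: shared checks, MLIR emission
  return 0;
}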
@@ -2110,26 +2118,33 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(auto window,
ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding,
/*lhs_dilation=*/base_dilations,
/*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferReduceWindowShape(
*operand_shape, *init_shape, window, to_apply_shape));
return ReduceWindowInternal(shape, operand, init_value, computation,
std::move(window));
});
}
StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_window() = std::move(window);
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
{operand, init_value});
}
XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,


@@ -542,6 +542,11 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce);
virtual StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce);
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
const XlaComputation& computation);
@@ -558,6 +563,10 @@ class XlaBuilder {
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
virtual StatusOr<XlaOp> ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window);
XlaOp CrossReplicaSum(XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups = {});