Override reduce and reduce-window ops in MlirHloBuilder
Test these ops with Bucketize, LRN and LRNGrad ops PiperOrigin-RevId: 316211359 Change-Id: Ief1d518e3c0baddc21d86bb03ae022056be7fb13
This commit is contained in:
parent
3862d97f85
commit
6d73ffb374
@ -132,6 +132,52 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce) {
|
||||
// Reduce takes two set of variadic operands inputs and init_values.
|
||||
// all_operands contains both of these so split operands into two parts.
|
||||
int64_t num_args = all_operands.size() / 2;
|
||||
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
|
||||
loc_, GetValues(all_operands.first(num_args)),
|
||||
GetValues(all_operands.subspan(num_args)),
|
||||
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
|
||||
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
|
||||
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
|
||||
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
|
||||
return MakeXlaOp(tuple);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
|
||||
const Shape& shape, XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation, Window window) {
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||
shape, builder_));
|
||||
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
|
||||
llvm::SmallVector<int64, 8> padding;
|
||||
for (const auto& dim : window.dimensions()) {
|
||||
sizes.push_back(dim.size());
|
||||
strides.push_back(dim.stride());
|
||||
base_dilations.push_back(dim.base_dilation());
|
||||
win_dilations.push_back(dim.window_dilation());
|
||||
padding.push_back(dim.padding_low());
|
||||
padding.push_back(dim.padding_high());
|
||||
}
|
||||
auto padding_ty =
|
||||
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
|
||||
builder_.getIntegerType(64));
|
||||
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
|
||||
loc_, ty, GetValue(operand), GetValue(init_value),
|
||||
GetI64ElementsAttr(sizes, &builder_),
|
||||
GetI64ElementsAttr(strides, &builder_),
|
||||
GetI64ElementsAttr(base_dilations, &builder_),
|
||||
GetI64ElementsAttr(win_dilations, &builder_),
|
||||
mlir::DenseIntElementsAttr::get(padding_ty, padding));
|
||||
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
@ -124,6 +124,16 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
FftType fft_type,
|
||||
absl::Span<const int64> fft_length) override;
|
||||
|
||||
StatusOr<XlaOp> ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce) override;
|
||||
|
||||
StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
|
||||
XlaOp init_value,
|
||||
const XlaComputation& computation,
|
||||
Window window) override;
|
||||
|
||||
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
|
||||
|
||||
StatusOr<XlaOp> TransposeInternal(
|
||||
|
@ -236,6 +236,21 @@ func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
|
||||
return %1 : tensor<4x7xcomplex<f64>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: bucketize
|
||||
func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
|
||||
// CHECK-NOT: tf.Bucketize
|
||||
%0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
|
||||
return %0 : tensor<2x5xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: arg_min
|
||||
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
|
||||
// CHECK-NOT: ArgMin
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
|
||||
|
||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||
// available but doesn't support this instance.
|
||||
}
|
||||
|
@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::AddV2Op>(),
|
||||
TypeID::get<TF::AngleOp>(),
|
||||
TypeID::get<TF::ApproximateEqualOp>(),
|
||||
TypeID::get<TF::ArgMaxOp>(),
|
||||
TypeID::get<TF::ArgMinOp>(),
|
||||
TypeID::get<TF::AsinhOp>(),
|
||||
TypeID::get<TF::AsinOp>(),
|
||||
TypeID::get<TF::Atan2Op>(),
|
||||
@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::BitwiseAndOp>(),
|
||||
TypeID::get<TF::BitwiseOrOp>(),
|
||||
TypeID::get<TF::BitwiseXorOp>(),
|
||||
TypeID::get<TF::BucketizeOp>(),
|
||||
TypeID::get<TF::CastOp>(),
|
||||
TypeID::get<TF::ClipByValueOp>(),
|
||||
TypeID::get<TF::ComplexAbsOp>(),
|
||||
@ -132,6 +135,8 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::IRFFTOp>(),
|
||||
TypeID::get<TF::InvertOp>(),
|
||||
TypeID::get<TF::InvOp>(),
|
||||
TypeID::get<TF::LRNOp>(),
|
||||
TypeID::get<TF::LRNGradOp>(),
|
||||
TypeID::get<TF::LeakyReluGradOp>(),
|
||||
TypeID::get<TF::LeakyReluOp>(),
|
||||
TypeID::get<TF::LeftShiftOp>(),
|
||||
|
@ -185,6 +185,7 @@ tf_xla_py_test(
|
||||
name = "argminmax_test",
|
||||
size = "small",
|
||||
srcs = ["argminmax_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -253,6 +254,7 @@ tf_xla_py_test(
|
||||
name = "bucketize_op_test",
|
||||
size = "small",
|
||||
srcs = ["bucketize_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
@ -806,6 +808,7 @@ tf_xla_py_test(
|
||||
name = "lrn_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["lrn_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
|
||||
expected_out, sess.run(op,
|
||||
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
|
||||
|
||||
@test_util.disable_mlir_bridge("Error handling")
|
||||
def testInvalidBoundariesOrder(self):
|
||||
with self.session() as sess:
|
||||
p = array_ops.placeholder(dtypes.int32)
|
||||
|
@ -2040,8 +2040,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
|
||||
computation.GetProgramShape());
|
||||
|
||||
@ -2060,6 +2058,17 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
|
||||
Shape shape,
|
||||
ShapeInference::InferReduceShape(
|
||||
operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
|
||||
return ReduceInternal(shape, all_operands, computation,
|
||||
dimensions_to_reduce);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
for (int64 dim : dimensions_to_reduce) {
|
||||
@ -2067,7 +2076,6 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
|
||||
}
|
||||
|
||||
AddCalledComputation(computation, &instr);
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
|
||||
});
|
||||
}
|
||||
@ -2110,28 +2118,35 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
|
||||
absl::Span<const int64> window_dilations,
|
||||
absl::Span<const std::pair<int64, int64>> padding) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
|
||||
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
|
||||
computation.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
|
||||
TF_ASSIGN_OR_RETURN(auto window,
|
||||
ShapeInference::InferWindowFromDimensions(
|
||||
window_dimensions, window_strides, padding,
|
||||
/*lhs_dilation=*/base_dilations,
|
||||
/*rhs_dilation=*/window_dilations));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
|
||||
*operand_shape, *init_shape,
|
||||
instr.window(), to_apply_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
AddCalledComputation(computation, &instr);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
|
||||
{operand, init_value});
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, ShapeInference::InferReduceWindowShape(
|
||||
*operand_shape, *init_shape, window, to_apply_shape));
|
||||
return ReduceWindowInternal(shape, operand, init_value, computation,
|
||||
std::move(window));
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
|
||||
const Shape& shape, XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation, Window window) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
*instr.mutable_window() = std::move(window);
|
||||
|
||||
AddCalledComputation(computation, &instr);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
|
||||
{operand, init_value});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
|
||||
float epsilon, int64 feature_index) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -542,6 +542,11 @@ class XlaBuilder {
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce);
|
||||
|
||||
virtual StatusOr<XlaOp> ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce);
|
||||
|
||||
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation);
|
||||
|
||||
@ -558,6 +563,10 @@ class XlaBuilder {
|
||||
absl::Span<const int64> window_dilations,
|
||||
absl::Span<const std::pair<int64, int64>> padding);
|
||||
|
||||
virtual StatusOr<XlaOp> ReduceWindowInternal(
|
||||
const Shape& shape, XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation, Window window);
|
||||
|
||||
XlaOp CrossReplicaSum(XlaOp operand,
|
||||
absl::Span<const ReplicaGroup> replica_groups = {});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user