Enable fallback legalization for MaxPoolGradGrad and MaxPool3DGradGrad ops
Requires, * Override for ReducePrecision in HloMlirBuilder * Sinking of constants for ReduceWindow op PiperOrigin-RevId: 342330848 Change-Id: I5d9793c3700b7bf4f338565af98ab5262751632b
This commit is contained in:
parent
a968a52db7
commit
19e6114760
@ -50,6 +50,8 @@ class SinkConstantsToControlFlowPass
|
||||
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
|
||||
SinkToRegion(&if_op.true_branch());
|
||||
SinkToRegion(&if_op.false_branch());
|
||||
} else if (auto reduce_window_op = llvm::dyn_cast<ReduceWindowOp>(op)) {
|
||||
SinkToRegion(&reduce_window_op.body());
|
||||
} else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) {
|
||||
SinkToRegion(&sort_op.comparator());
|
||||
}
|
||||
|
@ -273,6 +273,17 @@ StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
|
||||
const Shape& shape, XlaOp operand, const int exponent_bits,
|
||||
const int mantissa_bits) {
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||
shape, builder_));
|
||||
auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
|
||||
loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
|
||||
builder_.getI32IntegerAttr(mantissa_bits));
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
|
||||
const Shape& shape, XlaOp input, XlaOp start_indices,
|
||||
const GatherDimensionNumbers& dimension_numbers,
|
||||
|
@ -171,6 +171,10 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
const XlaComputation& body,
|
||||
XlaOp init) override;
|
||||
|
||||
StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand,
|
||||
const int exponent_bits,
|
||||
const int mantissa_bits) override;
|
||||
|
||||
StatusOr<XlaOp> GatherInternal(
|
||||
const Shape& shape, XlaOp input, XlaOp start_indices,
|
||||
const GatherDimensionNumbers& dimension_numbers,
|
||||
|
@ -175,6 +175,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::MatrixSetDiagV3Op>(),
|
||||
TypeID::get<TF::MatrixSolveOp>(),
|
||||
TypeID::get<TF::MatrixTriangularSolveOp>(),
|
||||
TypeID::get<TF::MaxPool3DGradGradOp>(),
|
||||
TypeID::get<TF::MaxPoolGradGradOp>(),
|
||||
TypeID::get<TF::MirrorPadOp>(),
|
||||
TypeID::get<TF::MirrorPadGradOp>(),
|
||||
TypeID::get<TF::MulOp>(),
|
||||
|
@ -23,7 +23,6 @@ import numpy as np
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -155,8 +154,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
padding="SAME",
|
||||
expected=expected_output.flatten())
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testKernelSmallerThanStride(self):
|
||||
self._VerifyValues(
|
||||
nn_ops.max_pool3d,
|
||||
@ -314,8 +311,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
atol=1e-6)
|
||||
self.assertShapeEqual(actual_grad_gradients_vals, outputs)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradValidPadding1_1_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
@ -326,8 +321,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
padding="VALID",
|
||||
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradValidPadding2_1_6_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
@ -350,8 +343,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
strides=[1, 1, 1],
|
||||
padding="VALID")
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradValidPadding2_2_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
@ -362,8 +353,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
padding="VALID",
|
||||
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradSamePadding1_1_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
@ -374,8 +363,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
padding="SAME",
|
||||
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradSamePadding2_1_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
@ -386,8 +373,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
padding="SAME",
|
||||
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradSamePadding2_2_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
@ -398,8 +383,6 @@ class Pooling3DTest(xla_test.XLATestCase):
|
||||
padding="SAME",
|
||||
pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPoolGradSamePadding3_1_3d(self):
|
||||
self._VerifyGradient(
|
||||
nn_ops.max_pool3d,
|
||||
|
@ -23,7 +23,6 @@ import numpy as np
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -543,20 +542,12 @@ class PoolGradTest(xla_test.XLATestCase):
|
||||
padding="SAME",
|
||||
pool_grad_grad_func=pool_grad_grad_func)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
|
||||
"MaxPoolGradGrad op in MLIR-based bridge")
|
||||
def testMaxPool(self):
|
||||
self._TestPooling(
|
||||
nn_ops.max_pool,
|
||||
gen_nn_ops.max_pool_grad,
|
||||
pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad)
|
||||
|
||||
# TODO(b/159845178): Remove this once MLIR bridge supports MaxPoolGradGrad
|
||||
# (then `testMaxPool` test will be sufficient)
|
||||
def testMaxPoolNoGradGrad(self):
|
||||
self._TestPooling(
|
||||
nn_ops.max_pool, gen_nn_ops.max_pool_grad, pool_grad_grad_func=None)
|
||||
|
||||
def testAvgPool(self):
|
||||
# Wrapper around AvgPoolGrad that ignores extra arguments needed by
|
||||
# MaxPoolGrad.
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -3002,19 +3003,27 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
|
||||
XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
|
||||
const int mantissa_bits) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferReducePrecisionShape(
|
||||
*operand_shape, exponent_bits, mantissa_bits));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_exponent_bits(exponent_bits);
|
||||
instr.set_mantissa_bits(mantissa_bits);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
|
||||
{operand});
|
||||
return ReducePrecisionInternal(shape, operand, exponent_bits,
|
||||
mantissa_bits);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
|
||||
XlaOp operand,
|
||||
const int exponent_bits,
|
||||
const int mantissa_bits) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_exponent_bits(exponent_bits);
|
||||
instr.set_mantissa_bits(mantissa_bits);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
|
||||
{operand});
|
||||
}
|
||||
|
||||
void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
|
||||
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
// Send HLO takes two operands: a data operand and a token. Generate the
|
||||
|
@ -819,6 +819,10 @@ class XlaBuilder {
|
||||
|
||||
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
|
||||
const int mantissa_bits);
|
||||
virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
|
||||
XlaOp operand,
|
||||
const int exponent_bits,
|
||||
const int mantissa_bits);
|
||||
|
||||
XlaOp Gather(XlaOp input, XlaOp start_indices,
|
||||
const GatherDimensionNumbers& dimension_numbers,
|
||||
|
Loading…
Reference in New Issue
Block a user