Enable fallback legalization for MaxPoolGradGrad and MaxPool3DGradGrad ops

Requires:
* Override for ReducePrecisionInternal in MlirHloBuilder
* Sinking of constants into the body region of the ReduceWindow op

PiperOrigin-RevId: 342330848
Change-Id: I5d9793c3700b7bf4f338565af98ab5262751632b
Author: Smit Hinsu (2020-11-13 13:48:17 -08:00), committed by TensorFlower Gardener
parent a968a52db7
commit 19e6114760
8 changed files with 38 additions and 32 deletions

tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc

@@ -50,6 +50,8 @@ class SinkConstantsToControlFlowPass
     } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
       SinkToRegion(&if_op.true_branch());
       SinkToRegion(&if_op.false_branch());
+    } else if (auto reduce_window_op = llvm::dyn_cast<ReduceWindowOp>(op)) {
+      SinkToRegion(&reduce_window_op.body());
     } else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) {
      SinkToRegion(&sort_op.comparator());
     }
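For context, "sinking" here means cloning constants that are defined outside a region but used inside it, so the region body is self-contained when exported to HLO. A minimal sketch of the idea using standard MLIR APIs (the helper name and details are illustrative, not the pass's actual source):

#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Illustrative sketch: clone each constant-like op that is defined outside
// `region` but used inside it, and rewrite the inner uses to the clone.
static void SinkConstantsSketch(mlir::Region* region) {
  llvm::DenseMap<mlir::Value, mlir::Operation*> sunk;
  region->walk([&](mlir::Operation* op) {
    for (mlir::OpOperand& operand : op->getOpOperands()) {
      mlir::Operation* def = operand.get().getDefiningOp();
      // Only constants defined outside the region need to be sunk.
      if (!def || !def->hasTrait<mlir::OpTrait::ConstantLike>()) continue;
      if (region->isAncestor(def->getParentRegion())) continue;
      mlir::Operation*& clone = sunk[operand.get()];
      if (!clone) {
        // Materialize a single clone at the top of the region's entry block.
        mlir::OpBuilder b = mlir::OpBuilder::atBlockBegin(&region->front());
        clone = b.clone(*def);
      }
      operand.set(clone->getResult(0));
    }
  });
}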

tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc

@@ -273,6 +273,17 @@ StatusOr<XlaOp> MlirHloBuilder::WhileInternal(const Shape& shape,
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::ReducePrecisionInternal(
+    const Shape& shape, XlaOp operand, const int exponent_bits,
+    const int mantissa_bits) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                          shape, builder_));
+  auto op = builder_.create<mlir::mhlo::ReducePrecisionOp>(
+      loc_, ty, GetValue(operand), builder_.getI32IntegerAttr(exponent_bits),
+      builder_.getI32IntegerAttr(mantissa_bits));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
     const Shape& shape, XlaOp input, XlaOp start_indices,
     const GatherDimensionNumbers& dimension_numbers,

tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h

@@ -171,6 +171,10 @@ class MlirHloBuilder : public XlaBuilder {
                                 const XlaComputation& body,
                                 XlaOp init) override;
 
+  StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand,
+                                          const int exponent_bits,
+                                          const int mantissa_bits) override;
+
   StatusOr<XlaOp> GatherInternal(
       const Shape& shape, XlaOp input, XlaOp start_indices,
       const GatherDimensionNumbers& dimension_numbers,

tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc

@@ -175,6 +175,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::MatrixSetDiagV3Op>(),
     TypeID::get<TF::MatrixSolveOp>(),
     TypeID::get<TF::MatrixTriangularSolveOp>(),
+    TypeID::get<TF::MaxPool3DGradGradOp>(),
+    TypeID::get<TF::MaxPoolGradGradOp>(),
     TypeID::get<TF::MirrorPadOp>(),
     TypeID::get<TF::MirrorPadGradOp>(),
     TypeID::get<TF::MulOp>(),
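The fallback path is gated by this static allowlist of op TypeIDs, so adding the two MaxPoolGradGrad entries is what enables their legalization through the TF2XLA kernels. A simplified sketch of the gating pattern (function name and set contents are abbreviated for illustration; the real set is much larger):

#include "llvm/ADT/DenseSet.h"
#include "mlir/Support/TypeID.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Sketch of the allowlist check: an op type is eligible for TF2XLA fallback
// legalization only if its TypeID appears in a static set.
bool IsOpTypeAllowedForFallback(mlir::TypeID type_id) {
  static auto* ops = new llvm::SmallDenseSet<mlir::TypeID, 8>{
      mlir::TypeID::get<mlir::TF::MaxPool3DGradGradOp>(),
      mlir::TypeID::get<mlir::TF::MaxPoolGradGradOp>(),
  };
  return ops->count(type_id) != 0;
}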

tensorflow/compiler/tests/pooling_ops_3d_test.py

@@ -23,7 +23,6 @@ import numpy as np
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import nn_ops
@@ -155,8 +154,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         expected=expected_output.flatten())
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testKernelSmallerThanStride(self):
     self._VerifyValues(
         nn_ops.max_pool3d,
@@ -314,8 +311,6 @@ class Pooling3DTest(xla_test.XLATestCase):
           atol=1e-6)
       self.assertShapeEqual(actual_grad_gradients_vals, outputs)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradValidPadding1_1_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -326,8 +321,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="VALID",
         pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradValidPadding2_1_6_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -350,8 +343,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         strides=[1, 1, 1],
         padding="VALID")
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradValidPadding2_2_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -362,8 +353,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="VALID",
         pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradSamePadding1_1_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -374,8 +363,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradSamePadding2_1_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -386,8 +373,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradSamePadding2_2_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,
@@ -398,8 +383,6 @@ class Pooling3DTest(xla_test.XLATestCase):
         padding="SAME",
         pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPoolGradSamePadding3_1_3d(self):
     self._VerifyGradient(
         nn_ops.max_pool3d,

tensorflow/compiler/tests/pooling_ops_test.py

@@ -23,7 +23,6 @@ import numpy as np
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import nn_ops
@@ -543,20 +542,12 @@ class PoolGradTest(xla_test.XLATestCase):
         padding="SAME",
         pool_grad_grad_func=pool_grad_grad_func)
 
-  @test_util.disable_mlir_bridge("TODO(b/159845178): Implement support for "
-                                 "MaxPoolGradGrad op in MLIR-based bridge")
   def testMaxPool(self):
     self._TestPooling(
         nn_ops.max_pool,
         gen_nn_ops.max_pool_grad,
         pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad)
 
-  # TODO(b/159845178): Remove this once MLIR bridge supports MaxPoolGradGrad
-  # (then `testMaxPool` test will be sufficient)
-  def testMaxPoolNoGradGrad(self):
-    self._TestPooling(
-        nn_ops.max_pool, gen_nn_ops.max_pool_grad, pool_grad_grad_func=None)
-
   def testAvgPool(self):
     # Wrapper around AvgPoolGrad that ignores extra arguments needed by
     # MaxPoolGrad.

tensorflow/compiler/xla/client/xla_builder.cc

@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -3002,19 +3003,27 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
 XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                   const int mantissa_bits) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape,
                         ShapeInference::InferReducePrecisionShape(
                             *operand_shape, exponent_bits, mantissa_bits));
-    *instr.mutable_shape() = shape.ToProto();
-    instr.set_exponent_bits(exponent_bits);
-    instr.set_mantissa_bits(mantissa_bits);
-    return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
-                          {operand});
+    return ReducePrecisionInternal(shape, operand, exponent_bits,
+                                   mantissa_bits);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
+                                                    XlaOp operand,
+                                                    const int exponent_bits,
+                                                    const int mantissa_bits) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_exponent_bits(exponent_bits);
+  instr.set_mantissa_bits(mantissa_bits);
+  return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
+                        {operand});
+}
+
 void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     // Send HLO takes two operands: a data operand and a token. Generate the
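This refactor keeps shape inference in the shared ReducePrecision entry point and delegates instruction emission to the virtual ReducePrecisionInternal hook, which MlirHloBuilder overrides above to emit an mhlo.reduce_precision op instead of an HLO proto. Client code is unaffected either way; a small usage sketch (builder and parameter names are illustrative, not from this commit):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Round an f32 vector to bfloat16-like precision (8 exponent bits,
// 7 mantissa bits) and back, while keeping the f32 type. The same client
// call works whether the builder lowers to HLO protos or, through
// MlirHloBuilder, to MHLO.
xla::XlaComputation BuildReducePrecisionExample() {
  xla::XlaBuilder builder("reduce_precision_example");
  xla::XlaOp x = xla::Parameter(
      &builder, /*parameter_number=*/0,
      xla::ShapeUtil::MakeShape(xla::F32, {16}), "x");
  xla::XlaOp y =
      xla::ReducePrecision(x, /*exponent_bits=*/8, /*mantissa_bits=*/7);
  return builder.Build(y).ValueOrDie();
}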

tensorflow/compiler/xla/client/xla_builder.h

@@ -819,6 +819,10 @@ class XlaBuilder {
   XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                         const int mantissa_bits);
 
+  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
+                                                  XlaOp operand,
+                                                  const int exponent_bits,
+                                                  const int mantissa_bits);
+
   XlaOp Gather(XlaOp input, XlaOp start_indices,
                const GatherDimensionNumbers& dimension_numbers,