Define shape inference function for XlaBroadcastHelper op

TensorFlow doesn't define a shape inference function for the XlaBroadcastHelper op, which is why this op and users of its results fail to compile in the MLIR bridge.
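As an illustration of what the new inference computes, here is one of the cases exercised by the tests added below: with an lhs of shape 5x2, an rhs of shape 2x3x5, and broadcast_dims = [2, 0], the lhs result is broadcast to shape 2x1x5 while the rhs result keeps its shape:

  %0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
  %lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)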

Enabled relevant tests in compiler/tests.

isCompatibleReturnTypes should be the same for all TensorFlow ops, so it would be better to define it once in a common extraClassDeclaration on TF_Op. I will look into this separately.

PiperOrigin-RevId: 352687566
Change-Id: I7c6bcee985e0d03d2c5d4dee0f8abc3fc4e5a44b
Authored by Smit Hinsu on 2021-01-19 17:17:00 -08:00, committed by TensorFlower Gardener
parent 2e64d50571
commit 88d947651d
4 changed files with 145 additions and 3 deletions


@@ -16254,7 +16254,7 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
let hasCanonicalizer = 1;
}
def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [NoSideEffect]> {
def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
let summary = "Helper operator for performing XLA-style broadcasts";
let description = [{
@@ -16276,6 +16276,13 @@ for binary operators.
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let extraClassDeclaration = [{
// InferTypeOpInterface:
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
return ArraysAreCastCompatible(l, r);
}
}];
}
def TF_XlaConvOp : TF_Op<"XlaConv", [NoSideEffect]> {


@@ -2969,6 +2969,92 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<XdivyWithSqrtDivisor>(context);
}
//===----------------------------------------------------------------------===//
// XlaBroadcastHelperOp
//===----------------------------------------------------------------------===//
LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto loc = location ? *location : mlir::UnknownLoc::get(context);
XlaBroadcastHelperOpAdaptor op(operands, attributes);
if (failed(op.verify(loc))) {
return failure();
}
Value lhs = op.lhs();
Value rhs = op.rhs();
auto set_unranked_results = [&]() {
auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
inferredReturnTypes.push_back(unranked_lhs);
auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
inferredReturnTypes.push_back(unranked_rhs);
return success();
};
RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_ty || !rhs_ty) return set_unranked_results();
int64_t lhs_rank = lhs_ty.getRank();
int64_t rhs_rank = rhs_ty.getRank();
DenseIntElementsAttr dims;
if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
return set_unranked_results();
}
if (dims.size() == 0) {
if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
return emitOptionalError(
location,
"if broadcast_dims is empty, both arguments must have equal rank or "
"at least one argument must be a scalar");
}
inferredReturnTypes.push_back(lhs_ty);
inferredReturnTypes.push_back(rhs_ty);
return success();
}
const bool broadcast_lhs = lhs_rank < rhs_rank;
RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
if (dims.size() != min_rank_ty.getRank()) {
return emitOptionalError(
location,
"broadcast_dims must have size equal to the smaller argument rank");
}
int64_t output_rank = max_rank_ty.getRank();
llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
for (auto item : llvm::enumerate(dims)) {
int64_t index = item.index();
int64_t dim = item.value().getSExtValue();
if (dim < 0 || dim > output_rank) {
return emitOptionalError(location, "out of range broadcast dim");
}
if (is_broadcasted[dim]) {
return emitOptionalError(location, "broadcast_dims has duplicates");
}
broadcast_shape[dim] = min_rank_ty.getDimSize(index);
is_broadcasted[dim] = true;
}
if (broadcast_lhs) {
inferredReturnTypes.push_back(
RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
inferredReturnTypes.push_back(rhs_ty);
} else {
inferredReturnTypes.push_back(lhs_ty);
inferredReturnTypes.push_back(
RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
}
return success();
}
//===----------------------------------------------------------------------===//
// XlaSetDynamicDimensionSizeOp
//===----------------------------------------------------------------------===//


@@ -4159,3 +4159,54 @@ func @testVarHandleOp() -> tensor<*x!tf.resource> {
} : () -> tensor<*x!tf.resource>
return %0 : tensor<*x!tf.resource>
}
// -----
func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
%0 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
// expected-error @+1 {{broadcast_dims must have size equal to the smaller argument rank}}
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<1xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
return
}
// -----
func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
%0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
// expected-error @+1 {{if broadcast_dims is empty, both arguments must have equal rank or at least one argument must be a scalar}}
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<0xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
return
}
// -----
func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () {
%0 = "tf.Const"() {value = dense<0> : tensor<2xi64>} : () -> tensor<2xi64>
// expected-error @+1 {{broadcast_dims has duplicates}}
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)
return
}
// -----
func @testXlaBroadcastHelper(%arg0: tensor<2xi32>, %arg1: tensor<i32>) -> () {
%0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2xi32>, tensor<i32>, tensor<0xi64>) -> (tensor<2xi32>, tensor<i32>)
return
}
// -----
func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () {
%0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)
return
}
// -----
func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
%0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<2xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
return
}


@@ -101,7 +101,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
args=(operand, start_indices),
expected=np.array([[5, 6, 7]]))
@test_util.disable_mlir_bridge('Dynamic result types not supported')
def testShiftRightLogical(self):
self._assertOpOutputMatchesExpected(
xla.shift_right_logical,
@@ -113,7 +112,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
@test_util.disable_mlir_bridge('Dynamic result types not supported')
def testShiftRightArithmetic(self):
self._assertOpOutputMatchesExpected(
xla.shift_right_arithmetic,