Define shape inference function fox XlaBroadcastHelper Op
TensorFlow doesn't define the shape inference function for XlaBroadcastHelper op and that is why this op and users of this op's results are failing to compile in the bridge with MLIR. Enabled relevant tests in compiler/tests. isCompatibleReturnTypes should be the same for all TensorFlow ops and it would be better to have that in a common extraClassDeclaration for TF_Op. I will look into this separately. PiperOrigin-RevId: 352687566 Change-Id: I7c6bcee985e0d03d2c5d4dee0f8abc3fc4e5a44b
This commit is contained in:
parent
2e64d50571
commit
88d947651d
@ -16254,7 +16254,7 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [NoSideEffect]> {
|
||||
def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [DeclareOpInterfaceMethods<InferTypeOpInterface>, NoSideEffect]> {
|
||||
let summary = "Helper operator for performing XLA-style broadcasts";
|
||||
|
||||
let description = [{
|
||||
@ -16276,6 +16276,13 @@ for binary operators.
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// InferTypeOpInterface:
|
||||
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r) {
|
||||
return ArraysAreCastCompatible(l, r);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_XlaConvOp : TF_Op<"XlaConv", [NoSideEffect]> {
|
||||
|
@ -2969,6 +2969,92 @@ void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<XdivyWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XlaBroadcastHelperOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult XlaBroadcastHelperOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto loc = location ? *location : mlir::UnknownLoc::get(context);
|
||||
XlaBroadcastHelperOpAdaptor op(operands, attributes);
|
||||
if (failed(op.verify(loc))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
auto set_unranked_results = [&]() {
|
||||
auto unranked_lhs = UnrankedTensorType::get(getElementTypeOrSelf(lhs));
|
||||
inferredReturnTypes.push_back(unranked_lhs);
|
||||
auto unranked_rhs = UnrankedTensorType::get(getElementTypeOrSelf(rhs));
|
||||
inferredReturnTypes.push_back(unranked_rhs);
|
||||
return success();
|
||||
};
|
||||
|
||||
RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
if (!lhs_ty || !rhs_ty) return set_unranked_results();
|
||||
|
||||
int64_t lhs_rank = lhs_ty.getRank();
|
||||
int64_t rhs_rank = rhs_ty.getRank();
|
||||
|
||||
DenseIntElementsAttr dims;
|
||||
if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) {
|
||||
return set_unranked_results();
|
||||
}
|
||||
|
||||
if (dims.size() == 0) {
|
||||
if (lhs_rank != rhs_rank && lhs_rank != 0 && rhs_rank != 0) {
|
||||
return emitOptionalError(
|
||||
location,
|
||||
"if broadcast_dims is empty, both arguments must have equal rank or "
|
||||
"at least one argument must be a scalar");
|
||||
}
|
||||
inferredReturnTypes.push_back(lhs_ty);
|
||||
inferredReturnTypes.push_back(rhs_ty);
|
||||
return success();
|
||||
}
|
||||
|
||||
const bool broadcast_lhs = lhs_rank < rhs_rank;
|
||||
RankedTensorType min_rank_ty = broadcast_lhs ? lhs_ty : rhs_ty;
|
||||
RankedTensorType max_rank_ty = broadcast_lhs ? rhs_ty : lhs_ty;
|
||||
|
||||
if (dims.size() != min_rank_ty.getRank()) {
|
||||
return emitOptionalError(
|
||||
location,
|
||||
"broadcast_dims must have size equal to the smaller argument rank");
|
||||
}
|
||||
|
||||
int64_t output_rank = max_rank_ty.getRank();
|
||||
llvm::SmallVector<int64_t, 4> broadcast_shape(output_rank, 1LL);
|
||||
llvm::SmallVector<bool, 4> is_broadcasted(output_rank, false);
|
||||
for (auto item : llvm::enumerate(dims)) {
|
||||
int64_t index = item.index();
|
||||
int64_t dim = item.value().getSExtValue();
|
||||
if (dim < 0 || dim > output_rank) {
|
||||
return emitOptionalError(location, "out of range broadcast dim");
|
||||
}
|
||||
if (is_broadcasted[dim]) {
|
||||
return emitOptionalError(location, "broadcast_dims has duplicates");
|
||||
}
|
||||
broadcast_shape[dim] = min_rank_ty.getDimSize(index);
|
||||
is_broadcasted[dim] = true;
|
||||
}
|
||||
|
||||
if (broadcast_lhs) {
|
||||
inferredReturnTypes.push_back(
|
||||
RankedTensorType::get(broadcast_shape, lhs_ty.getElementType()));
|
||||
inferredReturnTypes.push_back(rhs_ty);
|
||||
} else {
|
||||
inferredReturnTypes.push_back(lhs_ty);
|
||||
inferredReturnTypes.push_back(
|
||||
RankedTensorType::get(broadcast_shape, rhs_ty.getElementType()));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XlaSetDynamicDimensionSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4159,3 +4159,54 @@ func @testVarHandleOp() -> tensor<*x!tf.resource> {
|
||||
} : () -> tensor<*x!tf.resource>
|
||||
return %0 : tensor<*x!tf.resource>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64>
|
||||
// expected-error @+1 {{broadcast_dims must have size equal to the smaller argument rank}}
|
||||
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<1xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
// expected-error @+1 {{if broadcast_dims is empty, both arguments must have equal rank or at least one argument must be a scalar}}
|
||||
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<0xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// expected-error @+1 {{broadcast_dims has duplicates}}
|
||||
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testXlaBroadcastHelper(%arg0: tensor<2xi32>, %arg1: tensor<i32>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2xi32>, tensor<i32>, tensor<0xi64>) -> (tensor<2xi32>, tensor<i32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testXlaBroadcastHelper(%arg0: tensor<5x2xi32>, %arg1: tensor<2x3x5xi32>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<5x2xi32>, tensor<2x3x5xi32>, tensor<2xi64>) -> (tensor<2x1x5xi32>, tensor<2x3x5xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testXlaBroadcastHelper(%arg0: tensor<2x3x5xi32>, %arg1: tensor<5x2xi32>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%lhs_output, %rhs_output = "tf.XlaBroadcastHelper"(%arg0, %arg1, %0) : (tensor<2x3x5xi32>, tensor<5x2xi32>, tensor<2xi64>) -> (tensor<2x3x5xi32>, tensor<2x1x5xi32>)
|
||||
return
|
||||
}
|
||||
|
@ -101,7 +101,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(operand, start_indices),
|
||||
expected=np.array([[5, 6, 7]]))
|
||||
|
||||
@test_util.disable_mlir_bridge('Dynamic result types not supported')
|
||||
def testShiftRightLogical(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.shift_right_logical,
|
||||
@ -113,7 +112,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
|
||||
expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
|
||||
|
||||
@test_util.disable_mlir_bridge('Dynamic result types not supported')
|
||||
def testShiftRightArithmetic(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.shift_right_arithmetic,
|
||||
|
Loading…
x
Reference in New Issue
Block a user