Support Transpose op in MlirHloBuilder
Splits Transpose op in XlaBuilder in two parts so that MlirHloBuilder can override the internal method. This uses xla_ops_test.py for testing the Transpose op lowering in "xla-legalize-tf-with-tf2xla" pass. Unsupported tests in xla_ops_test.py are disabled for now. PiperOrigin-RevId: 306975198 Change-Id: I1ae68f6fcfb31ea043a169551e11e523d0c7c7ae
This commit is contained in:
parent
f3369e24bb
commit
78580b371b
@ -77,6 +77,15 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
|
||||
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||
shape, builder_));
|
||||
auto op = builder_.create<mlir::xla_hlo::TransposeOp>(
|
||||
loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
|
||||
XlaOp operand,
|
||||
int64 inferred_dimension) {
|
||||
|
@ -90,6 +90,10 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
private:
|
||||
XlaOp ConstantLiteral(const LiteralSlice& literal) override;
|
||||
|
||||
StatusOr<XlaOp> TransposeInternal(
|
||||
const Shape& shape, XlaOp operand,
|
||||
absl::Span<const int64> permutation) override;
|
||||
|
||||
StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
|
||||
int64 inferred_dimension) override;
|
||||
|
||||
|
@ -83,7 +83,8 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
|
||||
isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
|
||||
isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
|
||||
isa<TF::SquareOp>(op) || isa<TF::UnpackOp>(op);
|
||||
isa<TF::SquareOp>(op) || isa<TF::TransposeOp>(op) ||
|
||||
isa<TF::UnpackOp>(op);
|
||||
}
|
||||
|
||||
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
|
||||
|
@ -1769,6 +1769,7 @@ tf_xla_py_test(
|
||||
name = "xla_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["xla_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
@ -50,6 +51,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
equality_fn = self.assertAllClose
|
||||
equality_fn(result, expected, rtol=1e-3)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testAdd(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
@ -70,6 +72,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np.array([7, 11], dtype=dtype)),
|
||||
expected=np.array([[8, 13], [10, 15]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testBroadcast(self):
|
||||
for dtype in self.numeric_types:
|
||||
v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
|
||||
@ -78,6 +81,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(v,),
|
||||
expected=np.tile(v, (7, 42, 1, 1)))
|
||||
|
||||
@test_util.disable_mlir_bridge('Unsigned ints are not supported yet')
|
||||
def testShiftRightLogical(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.shift_right_logical,
|
||||
@ -89,6 +93,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
|
||||
expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
|
||||
|
||||
@test_util.disable_mlir_bridge('Unsigned ints are not supported yet')
|
||||
def testShiftRightArithmetic(self):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
xla.shift_right_arithmetic,
|
||||
@ -105,6 +110,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
xla_data_pb2.PrecisionConfig.HIGHEST)
|
||||
|
||||
@parameterized.parameters(*PRECISION_VALUES)
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testConv(self, precision):
|
||||
for dtype in set(self.float_types).intersection(
|
||||
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
||||
@ -143,6 +149,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
|
||||
|
||||
@parameterized.parameters(*PRECISION_VALUES)
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDotGeneral(self, precision):
|
||||
for dtype in self.float_types:
|
||||
|
||||
@ -189,6 +196,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(np.array([1, 2, 3], dtype=dtype),),
|
||||
expected=np.array([-1, -2, -3], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testPad(self):
|
||||
for dtype in self.numeric_types:
|
||||
|
||||
@ -208,6 +216,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testReduce(self):
|
||||
for dtype in set(self.numeric_types).intersection(
|
||||
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
||||
@ -258,6 +267,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
|
||||
expected=np.array([0, 45, 120, 231], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testSelectAndScatter(self):
|
||||
for dtype in set(self.numeric_types).intersection(
|
||||
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
||||
@ -299,6 +309,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDynamicSlice(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
@ -311,6 +322,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
[[673, 674], [683, 684], [693, 694]]]),
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
@ -324,6 +336,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
(r'start_indices must be a vector with length equal to input rank, '
|
||||
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
@ -340,6 +353,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDotDifferentNumberOfContractingDimensions(self):
|
||||
a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||
@ -354,6 +368,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
'dimensions for lhs and rhs. Got: 1 and 2'):
|
||||
xla.dot_general(a, b, dim_nums)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDotDifferentContractingDimensionsSizes(self):
|
||||
a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
|
||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||
@ -367,6 +382,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
'Got: 2 and 4'):
|
||||
xla.dot_general(a, b, dim_nums)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDotDifferentNumberOfBatchDimensions(self):
|
||||
a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||
@ -381,6 +397,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
'dimensions for lhs and rhs. Got: 1 and 2'):
|
||||
xla.dot_general(a, b, dim_nums)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDotDifferentBatchDimensionsSizes(self):
|
||||
a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
|
||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 2))
|
||||
@ -396,6 +413,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
'Got: 2 and 4'):
|
||||
xla.dot_general(a, b, dim_nums)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDotShapeInference(self):
|
||||
a = array_ops.placeholder(np.float32, shape=(1, 2, 3, 4))
|
||||
b = array_ops.placeholder(np.float32, shape=(4, 3, 2, 1))
|
||||
|
@ -1624,18 +1624,23 @@ XlaOp XlaBuilder::CustomCall(
|
||||
XlaOp XlaBuilder::Transpose(XlaOp operand,
|
||||
absl::Span<const int64> permutation) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
|
||||
*operand_shape, permutation));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int64 dim : permutation) {
|
||||
instr.add_dimensions(dim);
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
|
||||
return TransposeInternal(shape, operand, permutation);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::TransposeInternal(
|
||||
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int64 dim : permutation) {
|
||||
instr.add_dimensions(dim);
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
@ -565,6 +565,8 @@ class XlaBuilder {
|
||||
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
|
||||
|
||||
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
|
||||
virtual StatusOr<XlaOp> TransposeInternal(
|
||||
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);
|
||||
|
||||
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user