Support Transpose op in MlirHloBuilder
Splits Transpose op in XlaBuilder in two parts so that MlirHloBuilder can override the internal method. This uses xla_ops_test.py for testing the Transpose op lowering in "xla-legalize-tf-with-tf2xla" pass. Unsupported tests in xla_ops_test.py are disabled for now. PiperOrigin-RevId: 306975198 Change-Id: I1ae68f6fcfb31ea043a169551e11e523d0c7c7ae
This commit is contained in:
parent
f3369e24bb
commit
78580b371b
@ -77,6 +77,15 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
|
||||||
|
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||||
|
shape, builder_));
|
||||||
|
auto op = builder_.create<mlir::xla_hlo::TransposeOp>(
|
||||||
|
loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
|
||||||
|
return MakeXlaOp(op);
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
|
StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
|
||||||
XlaOp operand,
|
XlaOp operand,
|
||||||
int64 inferred_dimension) {
|
int64 inferred_dimension) {
|
||||||
|
@ -90,6 +90,10 @@ class MlirHloBuilder : public XlaBuilder {
|
|||||||
private:
|
private:
|
||||||
XlaOp ConstantLiteral(const LiteralSlice& literal) override;
|
XlaOp ConstantLiteral(const LiteralSlice& literal) override;
|
||||||
|
|
||||||
|
StatusOr<XlaOp> TransposeInternal(
|
||||||
|
const Shape& shape, XlaOp operand,
|
||||||
|
absl::Span<const int64> permutation) override;
|
||||||
|
|
||||||
StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
|
StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
|
||||||
int64 inferred_dimension) override;
|
int64 inferred_dimension) override;
|
||||||
|
|
||||||
|
@ -83,7 +83,8 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||||||
isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
|
isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
|
||||||
isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
|
isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
|
||||||
isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
|
isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
|
||||||
isa<TF::SquareOp>(op) || isa<TF::UnpackOp>(op);
|
isa<TF::SquareOp>(op) || isa<TF::TransposeOp>(op) ||
|
||||||
|
isa<TF::UnpackOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
|
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
|
||||||
|
@ -1769,6 +1769,7 @@ tf_xla_py_test(
|
|||||||
name = "xla_ops_test",
|
name = "xla_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["xla_ops_test.py"],
|
srcs = ["xla_ops_test.py"],
|
||||||
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
@ -50,6 +51,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
equality_fn = self.assertAllClose
|
equality_fn = self.assertAllClose
|
||||||
equality_fn(result, expected, rtol=1e-3)
|
equality_fn(result, expected, rtol=1e-3)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testAdd(self):
|
def testAdd(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
@ -70,6 +72,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
np.array([7, 11], dtype=dtype)),
|
np.array([7, 11], dtype=dtype)),
|
||||||
expected=np.array([[8, 13], [10, 15]], dtype=dtype))
|
expected=np.array([[8, 13], [10, 15]], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testBroadcast(self):
|
def testBroadcast(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
|
v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
|
||||||
@ -78,6 +81,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
args=(v,),
|
args=(v,),
|
||||||
expected=np.tile(v, (7, 42, 1, 1)))
|
expected=np.tile(v, (7, 42, 1, 1)))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Unsigned ints are not supported yet')
|
||||||
def testShiftRightLogical(self):
|
def testShiftRightLogical(self):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
xla.shift_right_logical,
|
xla.shift_right_logical,
|
||||||
@ -89,6 +93,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
|
args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
|
||||||
expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
|
expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Unsigned ints are not supported yet')
|
||||||
def testShiftRightArithmetic(self):
|
def testShiftRightArithmetic(self):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
xla.shift_right_arithmetic,
|
xla.shift_right_arithmetic,
|
||||||
@ -105,6 +110,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
xla_data_pb2.PrecisionConfig.HIGHEST)
|
xla_data_pb2.PrecisionConfig.HIGHEST)
|
||||||
|
|
||||||
@parameterized.parameters(*PRECISION_VALUES)
|
@parameterized.parameters(*PRECISION_VALUES)
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testConv(self, precision):
|
def testConv(self, precision):
|
||||||
for dtype in set(self.float_types).intersection(
|
for dtype in set(self.float_types).intersection(
|
||||||
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
||||||
@ -143,6 +149,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
|
expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
|
||||||
|
|
||||||
@parameterized.parameters(*PRECISION_VALUES)
|
@parameterized.parameters(*PRECISION_VALUES)
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDotGeneral(self, precision):
|
def testDotGeneral(self, precision):
|
||||||
for dtype in self.float_types:
|
for dtype in self.float_types:
|
||||||
|
|
||||||
@ -189,6 +196,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
args=(np.array([1, 2, 3], dtype=dtype),),
|
args=(np.array([1, 2, 3], dtype=dtype),),
|
||||||
expected=np.array([-1, -2, -3], dtype=dtype))
|
expected=np.array([-1, -2, -3], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testPad(self):
|
def testPad(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
|
|
||||||
@ -208,6 +216,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testReduce(self):
|
def testReduce(self):
|
||||||
for dtype in set(self.numeric_types).intersection(
|
for dtype in set(self.numeric_types).intersection(
|
||||||
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
||||||
@ -258,6 +267,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
|
args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
|
||||||
expected=np.array([0, 45, 120, 231], dtype=dtype))
|
expected=np.array([0, 45, 120, 231], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testSelectAndScatter(self):
|
def testSelectAndScatter(self):
|
||||||
for dtype in set(self.numeric_types).intersection(
|
for dtype in set(self.numeric_types).intersection(
|
||||||
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
|
||||||
@ -299,6 +309,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDynamicSlice(self):
|
def testDynamicSlice(self):
|
||||||
for dtype in self.numeric_types:
|
for dtype in self.numeric_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
@ -311,6 +322,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
[[673, 674], [683, 684], [693, 694]]]),
|
[[673, 674], [683, 684], [693, 694]]]),
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
@ -324,6 +336,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
(r'start_indices must be a vector with length equal to input rank, '
|
(r'start_indices must be a vector with length equal to input rank, '
|
||||||
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
||||||
with self.session() as session:
|
with self.session() as session:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
@ -340,6 +353,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDotDifferentNumberOfContractingDimensions(self):
|
def testDotDifferentNumberOfContractingDimensions(self):
|
||||||
a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||||
@ -354,6 +368,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
'dimensions for lhs and rhs. Got: 1 and 2'):
|
'dimensions for lhs and rhs. Got: 1 and 2'):
|
||||||
xla.dot_general(a, b, dim_nums)
|
xla.dot_general(a, b, dim_nums)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDotDifferentContractingDimensionsSizes(self):
|
def testDotDifferentContractingDimensionsSizes(self):
|
||||||
a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
|
a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
|
||||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||||
@ -367,6 +382,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
'Got: 2 and 4'):
|
'Got: 2 and 4'):
|
||||||
xla.dot_general(a, b, dim_nums)
|
xla.dot_general(a, b, dim_nums)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDotDifferentNumberOfBatchDimensions(self):
|
def testDotDifferentNumberOfBatchDimensions(self):
|
||||||
a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
|
||||||
@ -381,6 +397,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
'dimensions for lhs and rhs. Got: 1 and 2'):
|
'dimensions for lhs and rhs. Got: 1 and 2'):
|
||||||
xla.dot_general(a, b, dim_nums)
|
xla.dot_general(a, b, dim_nums)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDotDifferentBatchDimensionsSizes(self):
|
def testDotDifferentBatchDimensionsSizes(self):
|
||||||
a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
|
a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
|
||||||
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 2))
|
b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 2))
|
||||||
@ -396,6 +413,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
'Got: 2 and 4'):
|
'Got: 2 and 4'):
|
||||||
xla.dot_general(a, b, dim_nums)
|
xla.dot_general(a, b, dim_nums)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testDotShapeInference(self):
|
def testDotShapeInference(self):
|
||||||
a = array_ops.placeholder(np.float32, shape=(1, 2, 3, 4))
|
a = array_ops.placeholder(np.float32, shape=(1, 2, 3, 4))
|
||||||
b = array_ops.placeholder(np.float32, shape=(4, 3, 2, 1))
|
b = array_ops.placeholder(np.float32, shape=(4, 3, 2, 1))
|
||||||
|
@ -1624,18 +1624,23 @@ XlaOp XlaBuilder::CustomCall(
|
|||||||
XlaOp XlaBuilder::Transpose(XlaOp operand,
|
XlaOp XlaBuilder::Transpose(XlaOp operand,
|
||||||
absl::Span<const int64> permutation) {
|
absl::Span<const int64> permutation) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
|
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
|
||||||
*operand_shape, permutation));
|
*operand_shape, permutation));
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
return TransposeInternal(shape, operand, permutation);
|
||||||
for (int64 dim : permutation) {
|
|
||||||
instr.add_dimensions(dim);
|
|
||||||
}
|
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> XlaBuilder::TransposeInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
|
||||||
|
HloInstructionProto instr;
|
||||||
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
|
for (int64 dim : permutation) {
|
||||||
|
instr.add_dimensions(dim);
|
||||||
|
}
|
||||||
|
return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
|
XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
|
@ -565,6 +565,8 @@ class XlaBuilder {
|
|||||||
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
|
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
|
||||||
|
|
||||||
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
|
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
|
||||||
|
virtual StatusOr<XlaOp> TransposeInternal(
|
||||||
|
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);
|
||||||
|
|
||||||
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
|
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user