Support Transpose op in MlirHloBuilder

Splits the Transpose op in XlaBuilder into two parts so that MlirHloBuilder can override the internal method.
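
In outline, the change looks like this (a condensed sketch of the hunks below; access specifiers and unrelated members are elided):

    class XlaBuilder {
      // Public entry point: runs shape inference once, shared by all
      // builders, then delegates instruction construction to the virtual hook.
      XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);

      // Default implementation emits an HloInstructionProto with
      // HloOpcode::kTranspose.
      virtual StatusOr<XlaOp> TransposeInternal(
          const Shape& shape, XlaOp operand,
          absl::Span<const int64> permutation);
    };

    class MlirHloBuilder : public XlaBuilder {
      // Override: builds an mlir::xla_hlo::TransposeOp instead of an HLO proto.
      StatusOr<XlaOp> TransposeInternal(
          const Shape& shape, XlaOp operand,
          absl::Span<const int64> permutation) override;
    };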

This uses xla_ops_test.py to test the Transpose op lowering in the "xla-legalize-tf-with-tf2xla" pass. Tests in xla_ops_test.py that exercise ops the bridge does not support yet are disabled for now.
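
Concretely, the xla_ops_test BUILD target gains enable_mlir_bridge = True, and each unsupported test is skipped via the pattern below (condensed from the test diff; test bodies elided):

    from tensorflow.python.framework import test_util

    class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):

      # Transpose coverage stays enabled; tests of ops the bridge cannot
      # lower yet are skipped with a reason string.
      @test_util.disable_mlir_bridge('Not supported yet')
      def testAdd(self):
        ...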

PiperOrigin-RevId: 306975198
Change-Id: I1ae68f6fcfb31ea043a169551e11e523d0c7c7ae
Smit Hinsu, 2020-04-16 20:37:06 -07:00 (committed by TensorFlower Gardener)
parent f3369e24bb
commit 78580b371b
7 changed files with 47 additions and 7 deletions


@@ -77,6 +77,15 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
   });
 }
 
+StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::TransposeOp>(
+      loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
                                                 XlaOp operand,
                                                 int64 inferred_dimension) {


@@ -90,6 +90,10 @@ class MlirHloBuilder : public XlaBuilder {
  private:
   XlaOp ConstantLiteral(const LiteralSlice& literal) override;
 
+  StatusOr<XlaOp> TransposeInternal(
+      const Shape& shape, XlaOp operand,
+      absl::Span<const int64> permutation) override;
+
   StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                   int64 inferred_dimension) override;


@@ -83,7 +83,8 @@ static bool IsOpWhitelisted(Operation* op) {
          isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
          isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
          isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
-         isa<TF::SquareOp>(op) || isa<TF::UnpackOp>(op);
+         isa<TF::SquareOp>(op) || isa<TF::TransposeOp>(op) ||
+         isa<TF::UnpackOp>(op);
 }
 
 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(


@@ -1769,6 +1769,7 @@ tf_xla_py_test(
     name = "xla_ops_test",
     size = "medium",
     srcs = ["xla_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip


@@ -29,6 +29,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import googletest
@@ -50,6 +51,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
       equality_fn = self.assertAllClose
     equality_fn(result, expected, rtol=1e-3)
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testAdd(self):
     for dtype in self.numeric_types:
       self._assertOpOutputMatchesExpected(
@@ -70,6 +72,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
               np.array([7, 11], dtype=dtype)),
           expected=np.array([[8, 13], [10, 15]], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testBroadcast(self):
     for dtype in self.numeric_types:
       v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
@@ -78,6 +81,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
           args=(v,),
           expected=np.tile(v, (7, 42, 1, 1)))
 
+  @test_util.disable_mlir_bridge('Unsigned ints are not supported yet')
   def testShiftRightLogical(self):
     self._assertOpOutputMatchesExpected(
         xla.shift_right_logical,
@@ -89,6 +93,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
         expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32))
 
+  @test_util.disable_mlir_bridge('Unsigned ints are not supported yet')
   def testShiftRightArithmetic(self):
     self._assertOpOutputMatchesExpected(
         xla.shift_right_arithmetic,
@@ -105,6 +110,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
           xla_data_pb2.PrecisionConfig.HIGHEST)
 
   @parameterized.parameters(*PRECISION_VALUES)
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testConv(self, precision):
     for dtype in set(self.float_types).intersection(
         set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
@@ -143,6 +149,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
           expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype))
 
   @parameterized.parameters(*PRECISION_VALUES)
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDotGeneral(self, precision):
     for dtype in self.float_types:
@@ -189,6 +196,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(np.array([1, 2, 3], dtype=dtype),),
         expected=np.array([-1, -2, -3], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testPad(self):
     for dtype in self.numeric_types:
@@ -208,6 +216,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
              [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
             dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testReduce(self):
     for dtype in set(self.numeric_types).intersection(
         set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
@@ -258,6 +267,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),),
         expected=np.array([0, 45, 120, 231], dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testSelectAndScatter(self):
     for dtype in set(self.numeric_types).intersection(
         set([dtypes.bfloat16.as_numpy_dtype, np.float32])):
@@ -299,6 +309,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
     self._assertOpOutputMatchesExpected(
         lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDynamicSlice(self):
     for dtype in self.numeric_types:
       self._assertOpOutputMatchesExpected(
@@ -311,6 +322,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
                     [[673, 674], [683, 684], [693, 694]]]),
               dtype=dtype))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDynamicSliceWithIncorrectStartIndicesShape(self):
     with self.session() as session:
       with self.test_scope():
@@ -324,6 +336,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         (r'start_indices must be a vector with length equal to input rank, '
          r'but input rank is 3 and start_indices has shape \[2\].*'))
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDynamicSliceWithIncorrectSizeIndicesShape(self):
     with self.session() as session:
       with self.test_scope():
@@ -340,6 +353,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
 
 class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDotDifferentNumberOfContractingDimensions(self):
     a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
     b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
@@ -354,6 +368,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
                          'dimensions for lhs and rhs. Got: 1 and 2'):
       xla.dot_general(a, b, dim_nums)
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDotDifferentContractingDimensionsSizes(self):
     a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
     b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
@@ -367,6 +382,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
                          'Got: 2 and 4'):
       xla.dot_general(a, b, dim_nums)
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDotDifferentNumberOfBatchDimensions(self):
     a = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
     b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 4))
@@ -381,6 +397,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
                          'dimensions for lhs and rhs. Got: 1 and 2'):
       xla.dot_general(a, b, dim_nums)
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDotDifferentBatchDimensionsSizes(self):
     a = array_ops.placeholder(np.float32, shape=(2, 2, 2, 2))
     b = array_ops.placeholder(np.float32, shape=(4, 4, 4, 2))
@@ -396,6 +413,7 @@ class XlaOpsShapeInferenceTest(xla_test.XLATestCase, parameterized.TestCase):
                          'Got: 2 and 4'):
       xla.dot_general(a, b, dim_nums)
 
+  @test_util.disable_mlir_bridge('Not supported yet')
   def testDotShapeInference(self):
     a = array_ops.placeholder(np.float32, shape=(1, 2, 3, 4))
     b = array_ops.placeholder(np.float32, shape=(4, 3, 2, 1))


@@ -1624,18 +1624,23 @@ XlaOp XlaBuilder::CustomCall(
 XlaOp XlaBuilder::Transpose(XlaOp operand,
                             absl::Span<const int64> permutation) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
                                          *operand_shape, permutation));
-    *instr.mutable_shape() = shape.ToProto();
-    for (int64 dim : permutation) {
-      instr.add_dimensions(dim);
-    }
-    return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
+    return TransposeInternal(shape, operand, permutation);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::TransposeInternal(
+    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  for (int64 dim : permutation) {
+    instr.add_dimensions(dim);
+  }
+  return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand});
+}
+
 XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64> dimensions) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;


@@ -565,6 +565,8 @@ class XlaBuilder {
   XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
 
   XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
+  virtual StatusOr<XlaOp> TransposeInternal(
+      const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);
 
   XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);