Lower ops with dynamic result type in fallback TF to HLO lowering

This cast converts static shaped tensor created by XLA to the original dynamic tensor type available before the lowering.

This happens for TF ops that don't have shape inference function defined but the HLO lowering generates static shaped results causing a mismatch. Introducing tensor_cast in between fixes this problem.

PiperOrigin-RevId: 309792185
Change-Id: I47200ce19165531b42adac9252828643b6bdb750
This commit is contained in:
Smit Hinsu 2020-05-04 12:02:08 -07:00 committed by TensorFlower Gardener
parent d1a6976039
commit 488a15749e
4 changed files with 24 additions and 1 deletions

View File

@ -187,6 +187,7 @@ cc_library(
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,

View File

@ -101,6 +101,12 @@ class MlirHloBuilder : public XlaBuilder {
// Returns the shape of the given op.
StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
// Creates the given op at the current location.
template <typename OpTy, typename... Args>
OpTy create(Args&&... args) {
return builder_.create<OpTy>(loc_, std::forward<Args>(args)...);
}
private:
XlaOp ConstantLiteral(const LiteralSlice& literal) override;

View File

@ -146,6 +146,16 @@ func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor
return %0 : tensor<6x5xf64>
}
// CHECK-LABEL: dynamic_result_type
func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32>
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32>
// return %[[RESULT]]
return %0 : tensor<*xf32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
@ -388,7 +389,12 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
return op->emitError(
"expects XlaExpression of kind kXlaOp in compiled output");
auto value = hlo_builder_.GetValue(expr->handle());
op->getResult(i).replaceAllUsesWith(value);
mlir::OpResult old_result = op->getResult(i);
if (value.getType() != old_result.getType()) {
value =
hlo_builder_.create<mlir::TensorCastOp>(value, old_result.getType());
}
old_result.replaceAllUsesWith(value);
}
op->erase();