Lower ops with dynamic result type in fallback TF to HLO lowering
This cast converts static shaped tensor created by XLA to the original dynamic tensor type available before the lowering. This happens for TF ops that don't have shape inference function defined but the HLO lowering generates static shaped results causing a mismatch. Introducing tensor_cast in between fixes this problem. PiperOrigin-RevId: 309792185 Change-Id: I47200ce19165531b42adac9252828643b6bdb750
This commit is contained in:
parent
d1a6976039
commit
488a15749e
@ -187,6 +187,7 @@ cc_library(
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -101,6 +101,12 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
// Returns the shape of the given op.
|
||||
StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
|
||||
|
||||
// Creates the given op at the current location.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpTy create(Args&&... args) {
|
||||
return builder_.create<OpTy>(loc_, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
private:
|
||||
XlaOp ConstantLiteral(const LiteralSlice& literal) override;
|
||||
|
||||
|
@ -146,6 +146,16 @@ func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor
|
||||
return %0 : tensor<6x5xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: dynamic_result_type
|
||||
func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
// CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32>
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32>
|
||||
|
||||
// return %[[RESULT]]
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
|
||||
// available but doesn't support this instance.
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
@ -388,7 +389,12 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
||||
return op->emitError(
|
||||
"expects XlaExpression of kind kXlaOp in compiled output");
|
||||
auto value = hlo_builder_.GetValue(expr->handle());
|
||||
op->getResult(i).replaceAllUsesWith(value);
|
||||
mlir::OpResult old_result = op->getResult(i);
|
||||
if (value.getType() != old_result.getType()) {
|
||||
value =
|
||||
hlo_builder_.create<mlir::TensorCastOp>(value, old_result.getType());
|
||||
}
|
||||
old_result.replaceAllUsesWith(value);
|
||||
}
|
||||
|
||||
op->erase();
|
||||
|
Loading…
Reference in New Issue
Block a user