Add support for lowering element-wise HLO ops with dynamic shapes to LHLO.

This uses hand-written shape derivation functions and is not yet complete. The goal is to start experimenting with them and to get a first version of the code landed.
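For illustration, a binary element-wise op on dynamically shaped tensors now lowers roughly as follows (abridged from the new hlo-legalize-to-lhlo test; SSA names are shortened and the %c0/%c1 index constants are elided):

    %dim0 = dim %arg0, 0 : memref<?x?xf32>
    %ic0 = index_cast %dim0 : index to i64
    %dim1 = dim %arg0, 1 : memref<?x?xf32>
    %ic1 = index_cast %dim1 : index to i64
    %shape = "xla_hlo.scalars_to_dimension_tensor"(%ic0, %ic1)
        : (i64, i64) -> tensor<2xi64>
    %e0 = extract_element %shape[%c0] : tensor<2xi64>
    %s0 = index_cast %e0 : i64 to index
    %e1 = extract_element %shape[%c1] : tensor<2xi64>
    %s1 = index_cast %e1 : i64 to index
    %result = alloc(%s0, %s1) : memref<?x?xf32>
    "xla_lhlo.add"(%arg0, %arg1, %result)
        : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()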

PiperOrigin-RevId: 296403560
Change-Id: I996cc5f862604ca76344076cb6699998757d4164
Stephan Herhut 2020-02-21 03:30:50 -08:00 committed by TensorFlower Gardener
parent 9b53275852
commit 45c98a790e
4 changed files with 247 additions and 5 deletions

tensorflow/compiler/mlir/xla/BUILD

@@ -142,6 +142,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "hlo_shape_derivation",
+    srcs = [],
+    hdrs = ["transforms/hlo_shape_derivation.h"],
+    deps = [
+        ":hlo",
+        ":lhlo",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
 cc_library(
     name = "lhlo_legalize_to_affine",
     srcs = ["transforms/lhlo_legalize_to_affine.cc"],
@@ -218,9 +231,9 @@ cc_library(
     srcs = ["transforms/hlo_legalize_to_lhlo.cc"],
     deps = [
         ":hlo",
+        ":hlo_shape_derivation",
         ":lhlo",
         "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",

tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir

@@ -1,4 +1,4 @@
-// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal %s -o - | FileCheck %s --dump-input=always
+// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -11,6 +11,8 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @func_op
 func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@@ -20,6 +22,8 @@ func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @func_op_long
 func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@@ -45,6 +49,8 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
 // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
 func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
   %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -58,6 +64,8 @@ func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: t
 // CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref<f32>
 // CHECK: "xla_lhlo.terminator"() : () -> ()
 
+// -----
+
 // CHECK-LABEL: func @fusion
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -77,6 +85,8 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
   "xla_lhlo.terminator"() : () -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @copy
 func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -87,6 +97,8 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @exp
 func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -97,6 +109,8 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @select
 func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
              %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -110,6 +124,8 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @compare
 func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
@@ -122,6 +138,8 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @broadcast
 func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
   %tensor_operand = tensor_load %operand : memref<5xf32>
@@ -133,6 +151,8 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @dyn_broadcast
 func @dyn_broadcast(%operand: memref<?x?xf32>) {
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
@@ -157,6 +177,8 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @iota
 func @iota(%result: memref<10xi32>) {
   %tensor_result = "xla_hlo.iota"()
@@ -166,6 +188,8 @@ func @iota(%result: memref<10xi32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @abs
 func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -176,6 +200,8 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @ceil
 func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -186,6 +212,8 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @convert
 func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -196,6 +224,8 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @cos
 func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -206,6 +236,8 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @neg
 func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -216,6 +248,8 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @sign
 func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -226,6 +260,8 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @tanh
 func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -236,6 +272,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remainder
 func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
@@ -246,3 +284,47 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
   tensor_store %tensor_result, %result : memref<2x2xf32>
   return
 }
+
+// -----
+// Dynamic shape binary element-wise operation.
+// CHECK-LABEL: func @add_dyn
+func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
+  %result = "xla_hlo.add"(%lhs, %rhs)
+      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
+// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
+// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+// CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+// CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
+
+// -----
+// Dynamic shape unary element-wise operation.
+// CHECK-LABEL: func @tanh_dyn
+func @tanh_dyn(%arg0: tensor<?x?xf32>) {
+  %result = "xla_hlo.tanh"(%arg0)
+      : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
+// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
+// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+// CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+// CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}

tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc

@@ -30,6 +30,7 @@ limitations under the License.
 #include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
@@ -127,9 +128,24 @@ class HloToLhloOpConverter : public ConversionPattern {
       ConversionPatternRewriter& rewriter) const final {
     const auto& original_results = op->getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
-    for (auto result : original_results) {
-      buffer_args.push_back(
-          InsertAllocAndDealloc(op->getLoc(), result, &rewriter));
+    for (auto result : llvm::enumerate(original_results)) {
+      RankedTensorType resultType =
+          result.value().getType().dyn_cast<RankedTensorType>();
+      if (!resultType) {
+        return matchFailure();
+      }
+      if (resultType.hasStaticShape()) {
+        buffer_args.push_back(
+            InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter));
+      } else {
+        Value shape_value = ShapeDerivation<HloOpTy>::impl::deriveShapeFromOp(
+            op, result.index(), &rewriter);
+        if (!shape_value) {
+          return matchFailure();
+        }
+        buffer_args.push_back(InsertDynamicAllocAndDealloc(
+            op->getLoc(), result.value(), shape_value, &rewriter));
+      }
     }
     rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
                               op->getAttrs());
@@ -320,6 +336,7 @@ struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
     target.addIllegalOp<mlir::TensorLoadOp>();
     target.addIllegalOp<mlir::TensorStoreOp>();
     target.addLegalOp<ModuleTerminatorOp>();
+    target.addLegalOp<ScalarsToDimensionTensorOp>();
    target.addIllegalDialect<xla_hlo::XlaHloDialect>();
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      auto inputs = op.getType().getInputs();

tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h

@@ -0,0 +1,130 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_

#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
// This file contains implementations for shape derivation functions that,
// given some operation and a result number, produce IR that computes the
// shape of the given result at runtime based on operands of the provided
// operation.
// These should be generated at some point based on annotations on the HLO
// using the new shape dialect. While this is still in the works, we hardcode
// the expected IR here to unblock progress.
// The implementation is based on templates to allow for using these derivation
// functions in templated code.

namespace impl {

struct UnknownShape {
// Default shape derivation function that simply fails with a runtime error.
static Value deriveShapeFromOp(Operation* op, int operand_position,
ConversionPatternRewriter* rewriter) {
op->emitOpError()
<< "dynamic result shapes cannot be derived for this operation";
return {};
}
};

struct SameShapeAsFirstOperand {
// Shape derivation function that computes the shape of the result based on
// the first argument. For a 2-dimensional input tensor, this produces IR of
// the form
//
// %0 = dim %arg0, 0 : memref<?x?xf32>
// %1 = index_cast %0 : index to i64
// %2 = dim %arg0, 1 : memref<?x?xf32>
// %3 = index_cast %2 : index to i64
// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
// : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
  static Value deriveShapeFromOp(Operation* op, int result_position,
ConversionPatternRewriter* rewriter) {
Value operand = op->getOperand(0);
ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
if (!operand_type) {
op->emitOpError() << "first operand has no shaped type";
return {};
}
auto loc = op->getLoc();
SmallVector<Value, 4> shape_values;
shape_values.reserve(operand_type.getRank());
auto shape_scalar_type = rewriter->getIntegerType(64);
for (auto element : llvm::enumerate(operand_type.getShape())) {
if (element.value() == ShapedType::kDynamicSize) {
Value dim = rewriter->create<DimOp>(loc, operand, element.index());
shape_values.push_back(
rewriter->create<IndexCastOp>(loc, dim, shape_scalar_type));
} else {
shape_values.push_back(rewriter->create<ConstantOp>(
loc, rewriter->getI64IntegerAttr(element.value())));
}
}
return rewriter->create<ScalarsToDimensionTensorOp>(
loc, RankedTensorType::get({operand_type.getRank()}, shape_scalar_type),
shape_values);
}
};

}  // namespace impl

// Default template to cover HLO operations whose shape derivation is unknown.
template <typename HloOpTy>
struct ShapeDerivation {
using impl = impl::UnknownShape;
};

// Element-wise operations that have the shape of their first operand.
#define SAME_SHAPE_AS_FIRST_OPERAND(Op) \
template <> \
struct ShapeDerivation<Op> { \
using impl = impl::SameShapeAsFirstOperand; \
};
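
// For reference, each line below expands (via the macro above) to an explicit
// specialization; e.g. SAME_SHAPE_AS_FIRST_OPERAND(AddOp) becomes:
//
//   template <>
//   struct ShapeDerivation<AddOp> {
//     using impl = impl::SameShapeAsFirstOperand;
//   };
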
SAME_SHAPE_AS_FIRST_OPERAND(AbsOp)
SAME_SHAPE_AS_FIRST_OPERAND(AddOp)
SAME_SHAPE_AS_FIRST_OPERAND(AndOp)
SAME_SHAPE_AS_FIRST_OPERAND(CeilOp)
SAME_SHAPE_AS_FIRST_OPERAND(CosOp)
SAME_SHAPE_AS_FIRST_OPERAND(DivOp)
SAME_SHAPE_AS_FIRST_OPERAND(ExpOp)
SAME_SHAPE_AS_FIRST_OPERAND(MaxOp)
SAME_SHAPE_AS_FIRST_OPERAND(MinOp)
SAME_SHAPE_AS_FIRST_OPERAND(MulOp)
SAME_SHAPE_AS_FIRST_OPERAND(NegOp)
SAME_SHAPE_AS_FIRST_OPERAND(RemOp)
SAME_SHAPE_AS_FIRST_OPERAND(SubOp)
SAME_SHAPE_AS_FIRST_OPERAND(TanhOp)
#undef SAME_SHAPE_AS_FIRST_OPERAND

}  // namespace xla_hlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_