Merge branch 'master' into nhasabni/fixes_for_dnnl1.0

Niranjan Hasabnis 2020-02-26 13:54:33 -08:00
commit b96c0010fa
399 changed files with 7729 additions and 2610 deletions

View File

@ -365,13 +365,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --config=rbe_linux
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
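Usage note: these RBE configs are selected on the bazel command line, e.g. `bazel build --config=rbe_linux_cuda_nvcc //tensorflow/tools/pip_package:build_pip_package` (an illustrative target; any buildable target works). The `--repo_env=TF_*` lines seed the generated `@..._config_*` repository rules, so no interactive `./configure` run is needed for remote builds.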

View File

@ -70,7 +70,7 @@ $ python
3
>>> hello = tf.constant('Hello, TensorFlow!')
>>> hello.numpy()
'Hello, TensorFlow!'
b'Hello, TensorFlow!'
```
For more examples, see the
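The expected output above reflects that in TF2 `Tensor.numpy()` returns Python `bytes` for string tensors, not `str`. A minimal sketch of decoding it (assumes TensorFlow 2.x):

```python
import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')
raw = hello.numpy()           # b'Hello, TensorFlow!' -- bytes, not str
print(raw.decode('utf-8'))    # Hello, TensorFlow!
```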

View File

@ -418,21 +418,11 @@ void TensorHandleSilentCopy(bool async,
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
if (!async) {
// The input handles should never change since they have been mirrored.
ASSERT_EQ(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
} else {
if (cpu_op) {
ASSERT_EQ(op->GetInput(0), arg0);
// The GPU handle should be replaced with a CPU copy
ASSERT_NE(op->GetInput(1), arg1);
} else {
// The CPU handle should be replaced with a GPU copy
ASSERT_NE(op->GetInput(0), arg0);
ASSERT_EQ(op->GetInput(1), arg1);
}
}
// The input handles should never change since they have been mirrored.
EXPECT_EQ(op->GetInput(0), arg0);
EXPECT_EQ(op->GetInput(1), arg1);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);

View File

@ -14,6 +14,10 @@ package_group(
includes = [
"//tensorflow/compiler/tf2xla:internal",
],
packages = [
"//tensorflow/compiler/tests/...",
"//tensorflow/python/...",
],
)
package_group(

View File

@ -46,7 +46,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project

View File

@ -42,7 +42,7 @@ limitations under the License.
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
@ -165,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
// `isSigned` is set to false for other types.
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
bool is_signed = true) {
if (!is_signed && type.isInteger(8)) {
if (!is_signed && type.isSignlessInteger(8)) {
return tflite::TensorType_UINT8;
}
if (!is_signed) {
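For context on the `isSignlessInteger` predicates used here and in the files below: MLIR's `IntegerType` carries an explicit signedness (signless, signed, or unsigned), and `Type::isInteger(width)` matches all three variants, so code that specifically means the standard signless `i8`/`i16`/`i32` types uses `isSignlessInteger(width)` instead.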

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
@ -275,7 +275,7 @@ Attribute ConstFoldBinaryOp(
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
float_calculate, is_commutative);
if (elemType.isa<IntegerType>())
if (elemType.isSignlessInteger())
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
int_calculate, is_commutative);
@ -1560,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
limit_tensor.getType().getRank() == 0 &&
delta_tensor.getType().getRank() == 0);
Type elem_type = getType().cast<ShapedType>().getElementType();
if (elem_type.isa<IntegerType>()) {
if (elem_type.isSignlessInteger()) {
auto start_attr = start_tensor.getValue<IntegerAttr>({});
auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
@ -1662,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
// Do not try to fold elements attr of a quant type because
// DenseElementsAttr does not support it.
if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
return nullptr;
assert(perm_tensor.getType().getRank() == 1);

View File

@ -3100,7 +3100,9 @@ def LstmMandatoryInputsConstraint : PredOpTrait<
"mandatory operands element types should match",
// TODO(ashwinm): Replace the indices with input tensor names when that
// support is available.
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>>;
Or<[
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>,
Neg<TypeIsPred<"input", F32>>]>>;
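The `Or`/`Neg` pair restricts the element-type match to float LSTMs: when `input` is not `F32`, the constraint is vacuously satisfied, which admits quantized variants whose operands mix `QI8`, `QI16`, and `QI32` element types (see the operand list below).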
def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
"the optional peephole weights should all be specified or none",
@ -3126,7 +3128,7 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait<
def LstmResultConstraint : PredOpTrait<
"the input and result tensor elemental types must be same",
TCresVTEtIsSameAsOp<0, 0>>;
TFL_TCresVTEtIsSameAsOp<0, 0>>;
// This is the basic kernel type LSTM op.
// TODO(b/142417845): Refactor this part to return its tflite node name as
@ -3195,47 +3197,47 @@ Ba et al. “Layer Normalization”
}];
let arguments = (
ins TFL_TensorOf<[F32]>:$input,
ins TFL_TensorOf<[F32, QI8]>:$input,
// Weights
TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
TFL_TensorOf<[F32, I8]>:$input_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$input_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$input_to_output_weights,
TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights,
TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights,
TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights,
TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights,
// Recurrent weights
TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights,
TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights,
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights,
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights,
// Cell weights
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights,
// Optional input
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights,
// Optional input
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights,
// Bias
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
TFL_TensorOf<[F32]>:$forget_gate_bias,
TFL_TensorOf<[F32]>:$cell_bias,
TFL_TensorOf<[F32]>:$output_gate_bias,
TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
TFL_TensorOf<[F32, QI32]>:$forget_gate_bias,
TFL_TensorOf<[F32, QI32]>:$cell_bias,
TFL_TensorOf<[F32, QI32]>:$output_gate_bias,
// Projection weight and bias
TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights,
// Optional input
TFL_TensorOfOrNone<[F32]>:$projection_bias,
TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias,
// Stateful activation and cell states.
TFL_StatefulTensor:$input_activation_state,
TFL_StatefulTensor:$input_cell_state,
// Layer norm coefficients
TFL_TensorOfOrNone<[F32]>:$input_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32]>:$forget_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32]>:$cell_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32]>:$output_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI16]>:$input_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI16]>:$forget_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI16]>:$cell_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI16]>:$output_layer_norm_coefficients,
// Attributes
TFL_AFAttr:$fused_activation_function,

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/AffineMap.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
@ -191,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
auto ele_type = operand.getType().cast<TensorType>().getElementType();
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
inputs.push_back(op_inst.input());
} else if (ele_type.isa<IntegerType>()) {
} else if (ele_type.isSignlessInteger()) {
// If the operand is an integer tensor, then it doesn't require the
// DQ op in the pattern.
inputs.push_back(operand);
@ -225,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
auto user = llvm::cast<Q>(*result.user_begin());
outputs_replaced.insert({user.output(), enumerated_result.index()});
output_types.push_back(user.getType());
} else if (result_ele_type.template isa<IntegerType>()) {
} else if (result_ele_type.isSignlessInteger()) {
// If the result is an integer tensor, then it doesn't require the
// D op in the pattern.
outputs_replaced.insert({result, enumerated_result.index()});

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project

View File

@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
return %2 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
// CHECK-NEXT: return [[RESULT]]
}
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
return %4 : tensor<1x128x128x1xf32>
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
// CHECK-NEXT: return [[RESULT]]
}
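These tests exercise the rewrite's underlying identity: `SpaceToBatchND` followed by a `VALID`-padded convolution and `BatchToSpaceND` computes a dilated (atrous) convolution. A quick numerical check, as a sketch assuming TF 2.x (`tf.nn.atrous_conv2d` emits exactly this decomposition):

```python
import tensorflow as tf

x = tf.random.normal([1, 9, 9, 1])   # NHWC input
f = tf.random.normal([3, 3, 1, 1])   # HWIO filter

# A single dilated convolution...
y1 = tf.nn.conv2d(x, f, strides=1, padding='SAME', dilations=2)

# ...matches the SpaceToBatchND -> VALID Conv2D -> BatchToSpaceND pattern,
# which tf.nn.atrous_conv2d emits and the pass collapses back.
y2 = tf.nn.atrous_conv2d(x, f, rate=2, padding='SAME')

print(tf.reduce_max(tf.abs(y1 - y2)).numpy())  # ~0.0
```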

View File

@ -593,6 +593,21 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testLstmQuantizedType
func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK: %[[RES0:.*]] = constant unit
// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
// CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK: return %[[RES1]]
}
// -----
// CHECK-LABEL: testLstm

View File

@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
// -----
module {
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
@ -165,7 +165,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@ -189,7 +189,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
// -----
module {
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
@ -200,7 +200,7 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
@ -224,3 +224,84 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
// CHECK: }
}
// -----
module {
func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32>
// CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_21:%.*]] = constant unit
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
// CHECK: return [[VAL_23:%.*]] : tensor<8x?x10xf32>
// CHECK: }
}
// -----
module {
func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32>
// CHECK: [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_23:%.*]] = constant unit
// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: return [[VAL_26]] : tensor<8x8x10xf32>
// CHECK: }
}
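The two `go_backwards` modules check that the converter inserts a `tf.ReverseV2` over the time dimension (axis 0 once the input is in time-major layout) before the `tfl.lstm`. In Keras terms this corresponds to (a sketch, assuming TF 2.x):

```python
import tensorflow as tf

x = tf.random.normal([8, 1, 4])  # [time, batch, features], time-major layout

# go_backwards=True runs the recurrence over the time-reversed sequence;
# after conversion this appears as the ReverseV2 on axis 0 checked above.
lstm = tf.keras.layers.LSTM(10, time_major=True, go_backwards=True)
y = lstm(x)  # shape [1, 10]
```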

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
@ -80,6 +81,17 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
template <typename Conv2dOpTy>
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
Conv2dOpTy op, PatternRewriter& rewriter) const {
// Make sure Conv2D has 'VALID' padding.
if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
return Pattern::matchFailure();
}
// Make sure dilations are all ones if set.
const ArrayAttr& dilations =
op.template getAttrOfType<ArrayAttr>("dilations");
if (dilations && !TFIntListIsAllOnes(dilations)) {
return Pattern::matchFailure();
}
// Check if the ConvOp is preceded by an `Expand` op and succeeded by a
// `Squeeze` op.
Operation* prev_op = op.getOperation()->getPrevNode();
@ -90,6 +102,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
TF::ExpandDimsOp expand_op;
TF::SqueezeOp squeeze_op;
int64_t expand_axis;
// Expand + Squeeze op.
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
@ -99,6 +112,22 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
// Make sure that the axis in `expand_op` is constant.
if (auto const_op =
llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
expand_axis =
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
.getSExtValue();
} else {
return Pattern::matchFailure();
}
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
auto squeeze_dims = squeeze_op.squeeze_dims();
if (squeeze_dims.size() != 1 ||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
return Pattern::matchFailure();
}
// Update previous/next op pointer.
prev_op = prev_op->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
@ -108,10 +137,14 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
// SpaceToBatchND op.
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
// TODO(b/149936532): Check `padding` input, currently ignored.
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
// Pad op.
TF::PadOp pad_op;
// TODO(b/149936532): Currently we just ignore the PadOp. However, note that
// in real scenarios this may not always be correct: a user can put a PadOp here
// with non-trivial consequences.
if (llvm::isa<TF::PadOp>(next_op)) {
pad_op = llvm::cast<TF::PadOp>(next_op);
next_op = next_op->getNextNode();
@ -119,6 +152,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
}
// BatchToSpaceND + BiasAdd.
// TODO(b/149936532): Check the `crops` input, currently ignored.
TF::BatchToSpaceNDOp bts_op;
TF::BiasAddOp biasadd_op;
bool final_op_is_bts = true;
@ -146,14 +180,10 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
op.setAttr("dilations", dilations_attr.getValue());
// Here we need to set the correct padding for Conv op. In TF, the conv op
// inserted after 'SpaceToBatch' always has 'VALID' padding. This might
// become a problem here if the original Conv op has 'SAME' padding. When
// the original conv has 'SAME' padding, TF will set a non-zero padding for
// the 'SpaceToBatch' op, so we rely on this information to check if we need
// to change the padding from 'VALID' to 'SAME' (a.k.a when we see non-zero
// values in `stb_op.paddings`, we change the current Conv's padding to
// 'SAME').
// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
// TODO(b/149936532): This assumption only holds when the input width & height
// are multiples of the dilation width & height. We should fix it in order to
// support other use cases.
auto stb_paddings = stb_op.paddings();
ElementsAttr stb_paddings_attr;
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
@ -175,7 +205,8 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
input_shape.end());
expand_shape.push_back(1);
expand_shape.insert(expand_shape.begin() + expand_axis, 1);
auto expand_result_type = RankedTensorType::get(
expand_shape, getElementTypeOrSelf(stb_op.input()));
expand_op.getResult().setType(expand_result_type);
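`expand_shape.insert(expand_shape.begin() + expand_axis, 1)` places the unit dimension at the constant `expand_axis` extracted earlier, matching `tf.ExpandDims` semantics, rather than unconditionally appending it as the last dimension.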
@ -208,7 +239,7 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
ElementsAttr stb_bs_attr, bts_bs_attr;
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
!matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
// Returns failure status if block shape is not a constant.
// Returns failure status if block_shape is not a constant.
return {};
}
// Check that the block_shape of `stb_op` and `bts_op` are equal.
@ -217,9 +248,8 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
}
// TODO(haoliang): support 1-D dilated conv.
// Set dilation factor.
if (stb_bs_attr.getNumElements() < 2) return {};
int dilation_h_factor =
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
int dilation_w_factor =

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -15,7 +15,7 @@ limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
// TFLite legalization patterns
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
@ -341,7 +341,7 @@ def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>;
class I32VectorElementsAttr<int len> : ElementsAttrBase<
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
"$_self.cast<DenseIntElementsAttr>().getType()."
"getElementType().isInteger(32)">,
"getElementType().isSignlessInteger(32)">,
"32-bit int elements attribute of shape [" # len # "]"> {
let storageType = [{ DenseIntElementsAttr }];


@ -255,7 +255,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
ShapedType shape_type = shape.getType().cast<ShapedType>();
// The tfl reshape's #2 operand needs to be an i32 tensor type, so we have to cast.
if (!shape_type.getElementType().isInteger(32)) {
if (!shape_type.getElementType().isSignlessInteger(32)) {
auto new_shape = shape_type.getShape();
IntegerType new_ele_type = rewriter.getIntegerType(32);
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
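For context on the recurring `isInteger` to `isSignlessInteger` migration in this commit: MLIR recently split `IntegerType` into signless, signed, and unsigned variants, and these call sites specifically want the signless form. A minimal sketch of the distinction, assuming MLIR's `IntegerType` API from this period:
```
#include <cassert>

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"

void SignlessIntegerExample(mlir::MLIRContext* context) {
  // Signless 'i32' (the default) vs. an explicitly signed 'si32'.
  mlir::Type i32 = mlir::IntegerType::get(32, context);
  mlir::Type si32 =
      mlir::IntegerType::get(32, mlir::IntegerType::Signed, context);
  assert(i32.isSignlessInteger(32));    // Matches: signless 32-bit.
  assert(!si32.isSignlessInteger(32));  // Does not match: explicitly signed.
}
```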


@ -15,7 +15,7 @@ limitations under the License.
// Converts TF While to TFL While with single call in body and cond.
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project


@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project


@ -32,7 +32,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
@ -335,8 +335,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
ConversionPatternRewriter &rewriter) const override {
Type dtype = op.element_dtype();
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
dtype.isInteger(1) || dtype.isSignlessInteger(8) ||
dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) ||
dtype.isSignlessInteger(64))) {
op.emitError(
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "


@ -31,7 +31,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
// This is the optimization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"


@ -16,7 +16,7 @@ limitations under the License.
// This is the quantization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Both Quantize and Dequantize ops have side effects, so we have to define


@ -23,7 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
// This is the quantization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using quantization parameter from %1.


@ -18,7 +18,7 @@ limitations under the License.
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
//===----------------------------------------------------------------------===//


@ -20,7 +20,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project


@ -17,7 +17,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project


@ -38,7 +38,7 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
if (attr.getType().getNumElements() != 1 ||
!attr.getType().getElementType().isa<IntegerType>()) {
!attr.getType().getElementType().isSignlessInteger()) {
return {};
}
SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);


@ -19,7 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
namespace mlir {
namespace TFL {


@ -21,7 +21,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
@ -95,6 +95,14 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
return Transpose(builder, value_to_transpose, perm, type, location);
}
Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
RankedTensorType type, mlir::Location location) {
auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
// The result type will be the same as the input.
return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
axis_op);
}
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
return value.getType().cast<RankedTensorType>().getShape();
}
@ -615,6 +623,16 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
}
// Handle go_backwards:
// In Keras semantics, the LSTM reverses the input sequence if go_backwards is set.
auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards");
if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
// We assume input is already in {time, batch, size} layout.
final_inputs =
Reverse(builder, final_inputs, 0, final_input_type, func_op.getLoc());
}
int batch = final_input_type.getDimSize(1);
int time = final_input_type.getDimSize(0);


@ -24,7 +24,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project


@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {


@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
namespace mlir {
namespace TFL {


@ -19,7 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
namespace mlir {


@ -90,7 +90,7 @@ gentbl(
td_file = "ir/tf_saved_model_ops.td",
td_srcs = [
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
],
)
@ -114,7 +114,7 @@ gentbl(
td_file = "ir/tf_executor_ops.td",
td_srcs = [
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
],
)
@ -138,7 +138,7 @@ gentbl(
td_file = "ir/tf_device_ops.td",
td_srcs = [
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
],
)
@ -281,6 +281,7 @@ cc_library(
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/graph_pruning.cc",
"transforms/launch_to_device_attribute.cc",
"transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc",


@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -3111,6 +3111,70 @@ cublas.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> {
let summary = [{
Copy a tensor setting everything outside a central band in each innermost matrix
to zero.
}];
let description = [{
The `band` part is computed as follows:
Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
tensor with the same shape where
`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
The indicator function
`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
(num_upper < 0 || (n-m) <= num_upper)`.
For example:
```
# if 'input' is [[ 0, 1, 2, 3]
[-1, 0, 1, 2]
[-2, -1, 0, 1]
[-3, -2, -1, 0]],
tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
[-1, 0, 1, 2]
[ 0, -1, 0, 1]
[ 0, 0, -1, 0]],
tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
[-1, 0, 1, 0]
[-2, -1, 0, 1]
[ 0, -2, -1, 0]]
```
Useful special cases:
```
tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
tf.matrix_band_part(input, 0, 0) ==> Diagonal.
```
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$num_lower,
TF_I32OrI64Tensor:$num_upper
);
let results = (outs
TF_Tensor:$band
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tindex = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
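For intuition, a standalone sketch of the `in_band` predicate from the description above (plain C++, not part of the dialect definition):
```
#include <cstdint>

// in_band(m, n): keep an element if it lies within num_lower subdiagonals and
// num_upper superdiagonals; a negative bound means "unbounded".
bool InBand(int64_t m, int64_t n, int64_t num_lower, int64_t num_upper) {
  return (num_lower < 0 || (m - n) <= num_lower) &&
         (num_upper < 0 || (n - m) <= num_upper);
}

// Example from above: with num_lower = 1 and num_upper = -1,
// InBand(2, 0, 1, -1) is false (m - n = 2 > 1), so input[2][0] is zeroed.
```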
def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> {
let summary = [{
Returns a batched diagonal tensor with given batched diagonal values.


@ -85,7 +85,7 @@ class TF_TensorFlowType <string name, string description> :
// Any tensor element type allowed in TensorFlow ops
def TF_ElementType : Type<Or<[AnyFloat.predicate,
AnyInteger.predicate,
AnySignlessInteger.predicate,
AnyComplex.predicate,
TF_TFDialectType.predicate]>,
"tf.dtype">;


@ -35,7 +35,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
@ -1506,6 +1506,29 @@ void LogicalNotOp::getCanonicalizationPatterns(
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
}
//===----------------------------------------------------------------------===//
// MatrixBandPartOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(MatrixBandPartOp op) {
if (!HasRankAtLeast(op.input(), 2)) {
return op.emitOpError()
<< "requires `input` to have rank of at least 2, but found "
<< op.input().getType();
}
if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
return op.emitOpError()
<< "requires `num_lower` to have 0 dimensions, but found "
<< op.num_lower().getType();
}
if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
return op.emitOpError()
<< "requires `num_upper` to have 0 dimensions, but found "
<< op.num_upper().getType();
}
return success();
}
//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//
@ -2104,7 +2127,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
}
Type element_type = result_ranked_type.getElementType();
if (!element_type.isInteger(32) && !element_type.isInteger(64))
if (!element_type.isSignlessInteger(32) &&
!element_type.isSignlessInteger(64))
return op->emitOpError("requires int32 or int64 return type for result")
<< variadic_idx_str;


@ -91,7 +91,7 @@ class TensorFlowType : public Type {
// Returns true if the specified type is a valid TensorFlow element type.
static inline bool IsValidTFElementType(Type type) {
return type.isa<ComplexType>() || type.isa<FloatType>() ||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
type.isSignlessInteger() || type.isa<TensorFlowType>();
}
// Returns true if this is a valid TensorFlow tensor type.


@ -0,0 +1,123 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute | FileCheck %s --dump-input=fail
// Tests that a single TensorFlow op is hoisted out and has the correct device
// assigned by its parent `tf_device.launch`.
// CHECK-LABEL: func @single_op_launch
func @single_op_launch() {
tf_executor.graph {
%0:5 = tf_executor.island {
%a = "tf.opA"() : () -> tensor<i1>
%launch:2 = "tf_device.launch"() ( {
%b:2 = "tf.opB"(%a) : (tensor<i1>) -> (tensor<i32>, tensor<f32>)
tf_device.return %b#1, %b#0 : tensor<f32>, tensor<i32>
}) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
%c = "tf.opC"() : () -> tensor<i1>
tf_executor.yield %a, %launch#0, %launch#1, %c : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
}
tf_executor.fetch
}
return
}
// CHECK: %[[A:.*]] = "tf.opA"
// CHECK: %[[B:.*]]:2 = "tf.opB"(%[[A]])
// CHECK-SAME: device = "CPU:0"
// CHECK: %[[C:.*]] = "tf.opC"
// CHECK-NOT: "tf_device.launch"
// CHECK: tf_executor.yield %[[A]], %[[B]]#1, %[[B]]#0, %[[C]]
// Tests that multiple TensorFlow ops are hoisted out and that all have the
// correct device assigned by the parent `tf_device.launch`.
// CHECK-LABEL: func @multi_op_launch
func @multi_op_launch() {
tf_executor.graph {
%0:5 = tf_executor.island {
%a = "tf.opA"() : () -> tensor<i1>
%launch:2 = "tf_device.launch"() ( {
%b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32>
%c = "tf.opC"(%b) : (tensor<i32>) -> tensor<f32>
tf_device.return %c, %b : tensor<f32>, tensor<i32>
}) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
%d = "tf.opD"() : () -> tensor<i1>
tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
}
tf_executor.fetch
}
return
}
// CHECK: %[[A:.*]] = "tf.opA"
// CHECK: %[[B:.*]] = "tf.opB"(%[[A]])
// CHECK-SAME: device = "CPU:0"
// CHECK: %[[C:.*]] = "tf.opC"(%[[B]])
// CHECK-SAME: device = "CPU:0"
// CHECK: %[[D:.*]] = "tf.opD"
// CHECK-NOT: "tf_device.launch"
// CHECK: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]]
// -----
// Tests ops are hoisted out and devices are set only if the `tf_device.launch`
// contains only TensorFlow dialect ops.
func @non_tf_dialect_op_launch() {
tf_executor.graph {
%0:5 = tf_executor.island {
%a = "tf.opA"() : () -> tensor<i1>
// expected-error@+1 {{'tf_device.launch' op must contain only 'tf' dialect ops}}
%launch:2 = "tf_device.launch"() ( {
%b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32>
%c = "unknown.opC"(%b) : (tensor<i32>) -> tensor<f32>
tf_device.return %c, %b : tensor<f32>, tensor<i32>
}) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
%d = "tf.opD"() : () -> tensor<i1>
tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
}
tf_executor.fetch
}
return
}
// -----
// Tests TensorFlow op with conflicting `device` attribute compared to parent
// `tf_device.launch`.
func @conflicting_device() {
tf_executor.graph {
%0 = tf_executor.island {
// expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has conflicting 'device' attribute, got 'GPU:0' but expected 'CPU:0'}}
"tf_device.launch"() ( {
"tf.opA"() {device = "GPU:0"} : () -> ()
tf_device.return
}) {device = "CPU:0"} : () -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}
// -----
// Tests TensorFlow op with bad `device` attribute already set.
func @bad_tf_device_attr() {
tf_executor.graph {
%0 = tf_executor.island {
// expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has bad 'device' attribute}}
"tf_device.launch"() ( {
"tf.opA"() {device = 0 : i32} : () -> ()
tf_device.return
}) {device = "CPU:0"} : () -> ()
tf_executor.yield
}
tf_executor.fetch
}
return
}


@ -854,6 +854,78 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
// -----
// Test valid tf.MatrixBandPart
// CHECK-LABEL: func @testValidMatrixBandPartOp
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
return %0 : tensor<64x64xbf16>
}
// -----
// Test valid tf.MatrixBandPart
// CHECK-LABEL: func @testValidMatrixBandPartOp3D
func @testValidMatrixBandPartOp3D(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64x64xbf16> {
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64x64xbf16>
return %0 : tensor<64x64x64xbf16>
}
// -----
// Test valid tf.MatrixBandPart
// CHECK-LABEL: func @testValidMatrixBandPartOpUnranked
func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
return %0 : tensor<*xbf16>
}
// -----
// Test invalid tf.MatrixBandPart
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
return %0 : tensor<64x64xbf16>
}
// -----
// Test invalid tf.MatrixBandPart
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
return %0 : tensor<*xbf16>
}
// -----
// Test invalid tf.MatrixBandPart
func @testInvalidMatrixBandPartOp(%arg0: tensor<i64>, %arg1: tensor<64x64xi64>, %arg2: tensor<i64>) -> tensor<i64> {
// expected-error @+1 {{op requires `input` to have rank of at least 2, but found 'tensor<i64>'}}
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<i64>, tensor<64x64xi64>, tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
// -----
// Test invalid tf.MatrixBandPart
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<32xi64>, %arg2: tensor<i64>) -> tensor<64x64xi64> {
// expected-error @+1 {{op requires `num_lower` to have 0 dimensions, but found 'tensor<32xi64>'}}
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<32xi64>, tensor<i64>) -> tensor<64x64xi64>
return %0 : tensor<64x64xi64>
}
// -----
// Test invalid tf.MatrixBandPart
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<i64>, %arg2: tensor<32xi64>) -> tensor<64x64xi64> {
// expected-error @+1 {{op requires `num_upper` to have 0 dimensions, but found 'tensor<32xi64>'}}
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<i64>, tensor<32xi64>) -> tensor<64x64xi64>
return %0 : tensor<64x64xi64>
}
// -----
//===--------------------------------------------------------------------===//
// tf.{|Stateful}PartitionedCall
//===--------------------------------------------------------------------===//


@ -0,0 +1,220 @@
// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail
//===----------------------------------------------------------------------===//
// Immutability.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// Test case: This test exercises marking a global tensor as immutable after it
// propagates via a set of chained calls: f -> f_callee -> f_callee_callee.
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-NOT: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
return %val : tensor<f32>
}
}
// -----
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// Test case:
// This test exercises trying to mark a global tensor immutable when the same
// func is called by multiple callers with different global tensors.
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-NOT: is_mutable
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-NOT: is_mutable
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v2", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
func @f2(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v2}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f2"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
return %val : tensor<f32>
}
}
// -----
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// Test case: This test exercises immutability when there is no explicit use
// via ReadVariableOp.
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-NOT: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
%val_2 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val_2 : tensor<f32>
}
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%cst_1 = constant dense<2.0> : tensor<f32>
return %cst_1 : tensor<f32>
}
}
// -----
//===----------------------------------------------------------------------===//
// Test case: Mutation detection propagates across function calls.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32>
}
}
// -----
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// Test case: The inter-procedural analysis with different types of
// TF call ops
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["f"]} {
%val = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32>
}
}
// -----
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// Test case: The inter-procedural analysis does not recurse infinitely
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-NOT: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["exported_f"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
}
// -----
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
module attributes {tf_saved_model.semantics} {
// Test case: Inter-procedural analysis with resource usage in an
// unknown op; we assume mutating behavior and propagate it.
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
attributes {tf_saved_model.exported_names = ["exported_f"]} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32>
}
}


@ -17,7 +17,7 @@ limitations under the License.
// `tf_device.launch` with equivalent `tf_device.launch_func` operations.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Here, the element type can be any integer or float type. But, note that only


@ -17,7 +17,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Visitors.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project


@ -31,7 +31,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
// This transformation pass transforms functional control flow operations in the
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project


@ -0,0 +1,135 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass hoists a `tf_device.launch` body and assigns a `device` attribute
// to each TensorFlow dialect op in the body based on the `device` attribute on
// the `tf_device.launch`. If a TensorFlow dialect op already has a device
// attribute, it must match the `tf_device.launch` device; otherwise an error
// is emitted.
//
// For example:
// %island:5 = tf_executor.island {
// %a = "tf.opA"() : () -> tensor<i1>
// %launch:2 = "tf_device.launch"() ( {
// %b = "tf.opB"() : () -> tensor<i32>
// %c = "tf.opC"() : () -> tensor<f32>
// tf_device.return %c, %b : tensor<f32>, tensor<i32>
// }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
// %d = "tf.opD"() : () -> tensor<i1>
// tf_executor.yield %a, %launch#0, %launch#1, %d :
// tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
// }
//
// Will be transformed into:
// %island:5 = tf_executor.island {
// %a = "tf.opA"() : () -> tensor<i1>
// %b = "tf.opB"() {device = "CPU:0"} : () -> tensor<i32>
// %c = "tf.opC"() {device = "CPU:0"} : () -> tensor<f32>
// %d = "tf.opD"() : () -> tensor<i1>
// tf_executor.yield %a, %c, %b, %d :
// tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
// }
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Visitors.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
struct LaunchToDeviceAttributePass
: public FunctionPass<LaunchToDeviceAttributePass> {
void runOnFunction() override;
};
LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect,
tf_device::LaunchOp launch) {
// Forward launch inner op results to launch op results.
launch.replaceAllUsesWith(launch.GetBody().getTerminator()->getOperands());
// For all inner ops of the TensorFlow dialect, assign the launch device as a
// `device` attribute.
auto body = launch.GetBody().without_terminator();
for (Operation& op : body) {
if (op.getDialect() != tf_dialect)
return launch.emitOpError() << "must contain only 'tf' dialect ops";
auto device_attr = op.getAttr(kDeviceAttr);
if (!device_attr) {
op.setAttr(kDeviceAttr, launch.deviceAttr());
continue;
}
if (auto device_str_attr = device_attr.dyn_cast<StringAttr>()) {
if (launch.device() != device_str_attr.getValue())
return launch.emitOpError()
<< "inner 'tf' dialect op has conflicting 'device' attribute, "
"got '"
<< device_str_attr.getValue() << "' but expected '"
<< launch.device() << "'";
} else {
return launch.emitOpError()
<< "inner 'tf' dialect op has bad 'device' attribute";
}
}
// Move all inner ops of the launch to the block containing the launch.
Operation* launch_op = launch.getOperation();
launch_op->getBlock()->getOperations().splice(
launch_op->getIterator(), launch.GetBody().getOperations(), body.begin(),
body.end());
launch.erase();
return success();
}
void LaunchToDeviceAttributePass::runOnFunction() {
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
if (!tf_dialect) {
  getFunction().emitError() << "'tf' dialect is not registered";
  return signalPassFailure();
}
auto result = getFunction().walk([&](tf_device::LaunchOp launch) {
if (failed(HoistOpsAndAnnotateWithDevice(tf_dialect, launch)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (result.wasInterrupted()) return signalPassFailure();
}
} // anonymous namespace
std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass() {
return std::make_unique<LaunchToDeviceAttributePass>();
}
static PassRegistration<LaunchToDeviceAttributePass> pass(
"tf-launch-to-device-attribute",
"Hoists and annotates device launch inner ops with associated device "
"attribute");
} // namespace TFDevice
} // namespace mlir


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Here, the element type can be any integer or float type. But, note that only
@ -122,7 +122,7 @@ def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern<
//===----------------------------------------------------------------------===//
def ComplexTensor : TensorOf<[AnyComplex]>;
def RealTensor : TensorOf<[AnyInteger, AnyFloat]>;
def RealTensor : TensorOf<[AnySignlessInteger, AnyFloat]>;
def : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
@ -179,7 +179,7 @@ def LowerL2LossOp :
// Pad op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings),
def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
(TF_PadV2Op $input, $paddings,
(TF_ConstOp (GetScalarOfType<0> $input)))>;
@ -224,6 +224,6 @@ def CreateTFShapeOp : NativeCodeCall<
// TODO(hinsu): Support inputs of TensorList types.
def LowerZerosLikeOp :
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnyInteger, AnyFloat]>:$input),
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;


@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;


@ -15,16 +15,27 @@ limitations under the License.
// This pass optimizes tf_saved_model.global_tensor ops.
#include <cstddef>
#include <map>
#include <set>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace tf_saved_model {
@ -45,8 +56,142 @@ struct GlobalTensorUse {
using GlobalTensorUsesMap =
std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
static bool IsResourceType(Type type) {
if (auto tensor_type = type.dyn_cast<TensorType>()) {
return tensor_type.getElementType().isa<TF::ResourceType>();
}
return false;
}
static bool IsResource(Value value) { return IsResourceType(value.getType()); }
class ResourceAnalyzer {
public:
explicit ResourceAnalyzer(ModuleOp module) {
SymbolTable symbol_table(module);
for (auto func : module.getOps<FuncOp>()) {
AnalyzeFunc(func, symbol_table);
}
}
bool IsPotentiallyWritten(Value resource) const {
assert(IsResource(resource));
auto it = resource_infos_.find(resource);
if (it == resource_infos_.end()) {
return false;
}
return it->second.potentially_written;
}
private:
// Analyzes the specified func for resource mutating operations (namely
// TF::AssignVariableOp) and, if found, marks the associated resource as
// "potentially written". Does this recursively across the chain of funcs
// via call or control flow ops.
// TODO(ashwinm): Move to iterative traversal.
LogicalResult AnalyzeFunc(FuncOp func, const SymbolTable& symbol_table) {
// Avoid infinite recursion.
if (!discovered_.insert(func).second) {
return success();
}
func.walk([&](Operation* op) {
if (isa<TF::ReadVariableOp>(op) || isa<ReturnOp>(op)) {
return;
}
if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
SetPotentiallyWritten(assign_variable.resource());
return;
}
if (auto call = dyn_cast<CallOpInterface>(op)) {
if (auto sym = op->getAttrOfType<SymbolRefAttr>("f")) {
PropagatePotentiallyWrittenUpFromCallee(
sym.cast<FlatSymbolRefAttr>().getValue(), call.getArgOperands(),
symbol_table);
}
return;
}
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
for (auto callee : {if_op.then_branch(), if_op.else_branch()}) {
PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input(),
symbol_table);
}
return;
}
if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
for (auto callee : {while_op.cond(), while_op.body()}) {
PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input(),
symbol_table);
}
return;
}
// For all other ops, we assume they mutate all resources they use, so
// this errs on the side of being conservative. We should improve this by
// using either a property or a trait that clearly identifies ops with
// resource mutating behavior.
if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) {
return;
}
});
return success();
}
// If an op is not one of the handled ones, we assume all resource usages
// within its purview are mutating in nature.
bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
for (auto operand : op->getOperands()) {
if (IsResource(operand)) {
SetPotentiallyWritten(operand);
return true;
}
}
bool uses_resources = false;
visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
if (IsResource(operand->get())) {
SetPotentiallyWritten(operand->get());
uses_resources = true;
}
});
return uses_resources;
}
// Given a funcOp associated with the callee and operands from the
// corresponding callOp, propagate the potentially written decision to the
// callOp's operands, if the corresponding func's arguments are potentially
// written resources.
void PropagatePotentiallyWrittenUpFromCallee(
StringRef callee, Operation::operand_range propagate_to,
const SymbolTable& symbol_table) {
auto func = symbol_table.lookup<FuncOp>(callee);
AnalyzeFunc(func, symbol_table);
for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
if (!IsResource(std::get<0>(t))) {
continue;
}
if (IsPotentiallyWritten(std::get<0>(t))) {
SetPotentiallyWritten(std::get<1>(t));
}
}
}
void SetPotentiallyWritten(Value resource) {
assert(IsResource(resource));
resource_infos_[resource].potentially_written = true;
}
struct ResourceInfo {
bool potentially_written = false;
};
// Key: a resource Value.
// Value: information we know about that Value.
// Note that these Values are in general in different functions.
DenseMap<Value, ResourceInfo> resource_infos_;
// The set of funcs we have already discovered.
DenseSet<FuncOp> discovered_;
};
bool IsImmutable(GlobalTensorOp global_tensor,
ArrayRef<GlobalTensorUse> global_tensor_uses) {
ArrayRef<GlobalTensorUse> global_tensor_uses,
const ResourceAnalyzer& resource_analyzer) {
// Global tensor is already known to be immutable.
if (!global_tensor.is_mutable()) {
return false;
@ -57,17 +202,11 @@ bool IsImmutable(GlobalTensorOp global_tensor,
return false;
}
// Check the uses to see if this global tensor is only used in a way that
// is compatible with being immutable.
// Right now, this uses a very simple algorithm that only checks the top-level
// func for tf.ReadVariableOp. If the resource is passed into other functions
// or control flow, we fail to prove it is freezable even though we could.
// A global tensor is immutable if the resource analyzer deems it so.
for (auto& global_tensor_use : global_tensor_uses) {
auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
for (auto user : arg.getUsers()) {
if (!isa<TF::ReadVariableOp>(user)) {
return false;
}
if (resource_analyzer.IsPotentiallyWritten(arg)) {
return false;
}
}
return true;
@ -96,11 +235,12 @@ static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
// Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
// can prove it is safe to do so.
void MarkGlobalTensorsImmutable(
ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map) {
ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
const ResourceAnalyzer& resource_analyzer) {
for (const auto& kv : global_tensor_uses_map) {
auto global_tensor = kv.first;
const auto& global_tensor_uses = kv.second;
if (IsImmutable(global_tensor, global_tensor_uses)) {
if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
global_tensor.removeAttr("is_mutable");
}
}
@ -137,17 +277,15 @@ void EraseUnusedBoundInputs(ModuleOp module) {
}
void OptimizeGlobalTensorsPass::runOnModule() {
// This analysis could be much more elaborate, including tracking global
// tensors interprocedurally and uses in a wide variety of ops. But I don't
// know if we need that complexity.
auto module = getModule();
EraseUnusedBoundInputs(module);
// Figure out which funcs use each tf_saved_model.global_tensor.
ResourceAnalyzer resource_analyzer(module);
GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
MarkGlobalTensorsImmutable(module, global_tensor_uses);
MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
EraseUnusedGlobalTensors(module, global_tensor_uses);
}


@ -186,6 +186,10 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass();
// same data across replicas.
std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass();
// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
// attribute to each TensorFlow dialect op in the body based on the `device`
// attribute on the `tf_device.launch`.
std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass();
} // namespace TFDevice
namespace TFTPU {


@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project


@ -22,7 +22,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project


@ -25,7 +25,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project


@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project


@ -19,7 +19,7 @@ limitations under the License.
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project


@ -22,7 +22,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project


@ -21,7 +21,7 @@ limitations under the License.
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project


@ -28,7 +28,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project


@ -42,7 +42,7 @@ limitations under the License.
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/Verifier.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"


@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project


@ -17,7 +17,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/OpDefinition.h" // TF:llvm-project


@ -23,7 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project


@ -19,7 +19,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project


@ -1130,7 +1130,7 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
// Illegal attributes.
ShapedType attr_ty = start_indices.getType();
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
!attr_ty.getElementType().isInteger(64) ||
!attr_ty.getElementType().isSignlessInteger(64) ||
limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
return ty;

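These hunks belong to an MLIR-wide migration from isInteger/isIntOrFloat to their signless variants, which accept only integers that carry no explicit sign. A minimal standalone sketch of the behavioral difference, using a toy type model of my own rather than MLIR's real Type class:

```cpp
#include <iostream>

// Toy model of MLIR integer-type signedness; the Signedness enum and IntTy
// struct are mine, standing in for MLIR's real Type hierarchy.
enum class Signedness { Signless, Signed, Unsigned };

struct IntTy {
  unsigned width;
  Signedness sign;
};

// Analogue of the old check: any integer of the given width matches.
bool isInteger(const IntTy& t, unsigned width) { return t.width == width; }

// Analogue of the new check used in the slice validation above: only a
// signless integer of the given width matches.
bool isSignlessInteger(const IntTy& t, unsigned width) {
  return t.width == width && t.sign == Signedness::Signless;
}

int main() {
  IntTy i64{64, Signedness::Signless};  // MLIR "i64"
  IntTy si64{64, Signedness::Signed};   // MLIR "si64"
  std::cout << isInteger(si64, 64) << "\n";          // 1: old check accepts
  std::cout << isSignlessInteger(si64, 64) << "\n";  // 0: new check rejects
  std::cout << isSignlessInteger(i64, 64) << "\n";   // 1: still accepted
  return 0;
}
```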

@ -52,7 +52,7 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>;
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger, AnyComplex]>;
def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def HLO_ComplexTensor : TensorOf<[AnyComplex]>;
@ -64,13 +64,13 @@ def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
// an index type (as it stores indices) but that is currently disallowed in
// MLIR.
def HLO_DimensionTensor : ShapedContainerType<
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
[AnySignlessInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
"a 1D tensor of dimensions">;
// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
AnyFloat, AnyInteger, AnyComplex]>;
AnyFloat, AnySignlessInteger, AnyComplex]>;
//===----------------------------------------------------------------------===//
// XLA combined type definitions.
@ -784,7 +784,7 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor",
compute shape arguments to dynamic operations.
}];
let arguments = (ins Variadic<AnyInteger>);
let arguments = (ins Variadic<AnySignlessInteger>);
let results = (outs HLO_DimensionTensor);
// Cannot be exported to legacy formats.


@ -40,7 +40,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
// Handle integer elements.
Attribute elementAttr;
if (valElementType.isa<IntegerType>())
if (valElementType.isSignlessInteger())
elementAttr = b->getIntegerAttr(valElementType, constant);
else if (valElementType.isa<FloatType>())
elementAttr = b->getFloatAttr(valElementType, constant);


@ -42,7 +42,7 @@ def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger]>;
def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;


@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
@ -1221,7 +1221,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
<< "requires arg " << padding_arg_index
<< " to be a scalar for use as a dynamic parameter";
if (!mlir::getElementTypeOrSelf(padding_arg_type).isa<IntegerType>())
if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger())
return entry_func.emitError()
<< "requires arg " << padding_arg_index
<< " to be of an int type for use as a dynamic parameter";


@ -210,7 +210,6 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) {
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) {
@ -219,9 +218,10 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) {
: (memref<f32>, memref<7x10x6xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:[a-zA-Z0-9_]*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[CONST:.*]] = load %{{.*}} : memref<f32>
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
// -----


@ -16,7 +16,7 @@ limitations under the License.
// This file implements logic for lowering HLO dialect to LHLO dialect.
#include "absl/memory/memory.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project


@ -18,7 +18,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -24,7 +24,7 @@ limitations under the License.
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/Dialect/Traits.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
@ -119,7 +119,7 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
Type GetSumAccumulationType(Type input_type) {
MLIRContext *ctx = input_type.getContext();
if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
if (input_type.isInteger(8) || input_type.isInteger(16))
if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
return IntegerType::get(32, ctx);
return input_type;
}
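GetSumAccumulationType encodes a simple promotion rule: sums over narrow element types accumulate in a wider type (bf16/f16 in f32, i8/i16 in i32) to limit rounding and overflow, and every other type accumulates in itself. A standalone sketch of that rule, with strings standing in for MLIR types:

```cpp
#include <cassert>
#include <string>

// Model of the promotion rule shown in the hunk above; strings stand in
// for MLIR types, so this is an illustration rather than the real code.
std::string GetSumAccumulationType(const std::string& input_type) {
  if (input_type == "bf16" || input_type == "f16") return "f32";
  if (input_type == "i8" || input_type == "i16") return "i32";
  return input_type;  // everything else accumulates in its own type
}

int main() {
  assert(GetSumAccumulationType("f16") == "f32");
  assert(GetSumAccumulationType("i8") == "i32");
  assert(GetSumAccumulationType("f32") == "f32");  // unchanged
  assert(GetSumAccumulationType("i32") == "i32");  // unchanged
  return 0;
}
```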
@ -1274,7 +1274,7 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
PatternRewriter &rewriter) const override {
Type element_type =
op.input().getType().cast<TensorType>().getElementType();
if (!element_type.isIntOrFloat()) return matchFailure();
if (!element_type.isSignlessIntOrFloat()) return matchFailure();
Location loc = op.getLoc();
ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
@ -2248,7 +2248,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
Type input_element_type = input_type.getElementType();
// TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
// tf.ArgMax doesn't support complex data types, this check can be removed.
if (!input_element_type.isIntOrFloat()) return this->matchFailure();
if (!input_element_type.isSignlessIntOrFloat()) return this->matchFailure();
Location loc = op.getLoc();
Value init_value =


@ -28,7 +28,7 @@ limitations under the License.
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project


@ -16,7 +16,7 @@ limitations under the License.
// This is the legalization pattern definition file for TF to XLA.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
@ -504,8 +504,8 @@ def : Pat<(TF_SignOp $x),
)>;
def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
"getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
"getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
"getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && "
"getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && "
"getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
"getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
"element types must be integers or floats of same width">;

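The BothElementTypesSameWidthIntOrFloat constraint now also insists that both element types be signless. A toy model of the predicate, using my own descriptor struct rather than TableGen's generated code:

```cpp
#include <iostream>

// Toy element-type descriptor; stands in for MLIR element types.
struct ElemTy {
  bool is_float;     // float vs. integer
  bool is_signless;  // only meaningful for integers
  unsigned width;    // bit width
};

bool isSignlessIntOrFloat(const ElemTy& t) {
  return t.is_float || t.is_signless;
}

// Model of the constraint: both operands must be signless-int-or-float
// and share the same bit width.
bool bothSameWidthIntOrFloat(const ElemTy& a, const ElemTy& b) {
  return isSignlessIntOrFloat(a) && isSignlessIntOrFloat(b) &&
         a.width == b.width;
}

int main() {
  ElemTy f32{true, false, 32}, i32{false, true, 32}, si32{false, false, 32};
  std::cout << bothSameWidthIntOrFloat(f32, i32) << "\n";   // 1
  std::cout << bothSameWidthIntOrFloat(f32, si32) << "\n";  // 0: si32 rejected
  return 0;
}
```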

@ -16,7 +16,7 @@ limitations under the License.
// This file implements logic for lowering XLA dialect to Standard dialect.
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
@ -45,8 +45,8 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
// Broadcasting not supported by this rewrite.
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
if (!lhs_type.getElementType().isa<IntegerType>() ||
!rhs_type.getElementType().isa<IntegerType>())
if (!lhs_type.getElementType().isSignlessInteger() ||
!rhs_type.getElementType().isSignlessInteger())
return matchFailure();
auto comparison_direction = op.comparison_direction();
@ -113,7 +113,8 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
PatternRewriter &rewriter) const override {
auto output_type = op.getType().cast<ShapedType>();
// TODO(prakalps): Handle FP and ComplexType iota ops.
if (!output_type.getElementType().isa<IntegerType>()) return matchFailure();
if (!output_type.getElementType().isSignlessInteger())
return matchFailure();
auto output_size = output_type.getNumElements();
auto dimension = op.iota_dimension().getSExtValue();
auto max_dim_size = output_type.getDimSize(dimension);

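ConvertIotaOp lowers iota to a constant whose entries count upward along iota_dimension, which is why the element type must be a signless integer the lowering can materialize. A standalone sketch of the values an iota produces for a 2-D result (my own model, not the converter itself):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Model of iota semantics for a 2-D result: element (i, j) receives its
// index along `iota_dimension`. Illustration only; the real lowering
// packs these values into a dense constant attribute.
std::vector<std::vector<int64_t>> Iota2D(int64_t rows, int64_t cols,
                                         int iota_dimension) {
  std::vector<std::vector<int64_t>> out(rows, std::vector<int64_t>(cols));
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t j = 0; j < cols; ++j)
      out[i][j] = (iota_dimension == 0) ? i : j;
  return out;
}

int main() {
  for (const auto& row : Iota2D(2, 3, /*iota_dimension=*/1)) {
    for (int64_t v : row) std::cout << v << ' ';  // prints: 0 1 2 / 0 1 2
    std::cout << '\n';
  }
  return 0;
}
```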

@ -16,7 +16,7 @@ limitations under the License.
// This is the legalization pattern definition file for XLA to StandardOps.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
//===----------------------------------------------------------------------===//


@ -17,7 +17,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project


@ -22,7 +22,7 @@ limitations under the License.
#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project
#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project


@ -17,7 +17,7 @@ limitations under the License.
// equivalent real value operations.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
//===----------------------------------------------------------------------===//


@ -17,7 +17,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project


@ -18,7 +18,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
@ -206,7 +206,7 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
assert(predicate.hasValue() && "expected valid comparison direction");
@ -288,8 +288,8 @@ template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
const Type& sourceType = args.front().getType();
const Type& targetType = result_types.front();
Type sourceType = args.front().getType();
Type targetType = result_types.front();
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
@ -307,7 +307,7 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
// No conversion is needed for the same width floats
return args.front();
}
if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
IntegerType src = sourceType.cast<IntegerType>();
IntegerType res = targetType.cast<IntegerType>();
if (src.getWidth() > res.getWidth()) {

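The integer-to-integer branch of this ConvertOp lowering dispatches on bit width; the hunk shows only the narrowing case, and the widening case is elided by the diff. A sketch of the dispatch rule as I read it — the sign-extension branch is an assumption, not something visible in the hunk:

```cpp
#include <cassert>
#include <string>

// Sketch of the width dispatch in the integer-to-integer branch above.
// Narrowing (truncation) appears in the hunk; widening is elided, and
// sign-extension is my assumption about what it does.
std::string PickIntConversion(unsigned src_width, unsigned dst_width) {
  if (src_width > dst_width) return "trunci";  // shown: truncate to narrow
  if (src_width < dst_width) return "sexti";   // assumed: sign-extend to widen
  return "none";                               // same width: forward the value
}

int main() {
  assert(PickIntConversion(64, 32) == "trunci");
  assert(PickIntConversion(8, 32) == "sexti");
  assert(PickIntConversion(32, 32) == "none");
  return 0;
}
```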

@ -15,7 +15,7 @@ limitations under the License.
#include <numeric>
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project


@ -19,7 +19,7 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
@ -83,7 +83,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
emitError(loc, "lhlo to linalg conversion expects ranked args");
return ConversionPattern::matchFailure();
}
if (!argType.getElementType().isIntOrFloat()) {
if (!argType.getElementType().isSignlessIntOrFloat()) {
return ConversionPattern::matchFailure();
}
@ -171,7 +171,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto loc = lhlo_op.getLoc();
auto argType =
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.getElementType().isIntOrFloat() ||
if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
(argType.getRank() != 0)) {
return ConversionPattern::matchFailure();
}
@ -208,6 +208,9 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
auto resultType = getXLAOpResultType<isLHLO>(op);
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op))
return ConversionPattern::matchFailure();
// TODO(b/150203558) Enable once tiling/fusion works in this case.
if (isLHLO && (operandType.getRank() == 0))
return ConversionPattern::matchFailure();
ArrayAttr indexingMapsAttr =
static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter);
if (!indexingMapsAttr) return ConversionPattern::matchFailure();
@ -278,6 +281,52 @@ class BroadcastInDimConverter
}
};
// Special case for scalar broadcast in lhlo.
// TODO(b/150203558) Remove once the bug is fixed.
class ScalarBroadcastInDimConverter
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
public:
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
PatternMatchResult matchAndRewrite(
xla_lhlo::BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto operandMemrefType =
broadcastOp.operand().getType().dyn_cast<MemRefType>();
auto resultMemrefType =
broadcastOp.output().getType().dyn_cast<MemRefType>();
if (!operandMemrefType || !resultMemrefType) return matchFailure();
// Only support scalar operands.
if (operandMemrefType.getRank() != 0) return matchFailure();
auto broadcastDims = broadcastOp.broadcast_dimensions();
if (!broadcastDims.hasValue()) return matchFailure();
unsigned nloops = resultMemrefType.getRank();
SmallVector<Attribute, 1> indexingMaps{
AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))};
auto loc = broadcastOp.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>{}, broadcastOp.output(),
rewriter.getI64IntegerAttr(0), // args_in
rewriter.getI64IntegerAttr(1), // args_out
rewriter.getArrayAttr(indexingMaps),
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
// Add a block to the region.
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
block->addArguments(resultMemrefType.getElementType());
rewriter.setInsertionPointToEnd(block);
auto scalar =
rewriter.create<LoadOp>(loc, broadcastOp.operand(), llvm::None);
rewriter.create<linalg::YieldOp>(loc, scalar.getResult());
rewriter.eraseOp(broadcastOp);
return matchSuccess();
}
};
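The region built by ScalarBroadcastInDimConverter matches the FileCheck test earlier in this diff: the scalar is loaded from the rank-0 memref inside the linalg.generic body and yielded for every output element. A standalone model of the computation that body performs (illustration only):

```cpp
#include <cstdio>
#include <vector>

// Model of what the generated linalg.generic computes: every element of
// the output buffer receives the value loaded from the rank-0 operand.
void BroadcastScalar(float scalar, std::vector<float>& output) {
  for (float& elem : output) elem = scalar;  // load, then yield per element
}

int main() {
  std::vector<float> out(6, 0.0f);
  BroadcastScalar(3.5f, out);
  for (float v : out) std::printf("%.1f ", v);  // prints 3.5 six times
  std::printf("\n");
  return 0;
}
```

Per the TODO above, this pattern is a stopgap for b/150203558 and can be deleted once tiling/fusion handles rank-0 operands directly.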
template <typename OpTy, bool isLHLO = true>
class TransposeConverter
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
@ -385,7 +434,7 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
if (!resultMemrefType) return matchFailure();
auto resultElementType = resultMemrefType.getElementType();
if (!resultElementType.isIntOrFloat()) return matchFailure();
if (!resultElementType.isSignlessIntOrFloat()) return matchFailure();
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank();
@ -502,6 +551,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
ScalarBroadcastInDimConverter,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
SliceConverter
>(context);


@ -1514,6 +1514,7 @@ cuda_py_test(
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
xla_enable_strict_auto_jit = False,
xla_enabled = True,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
@ -1534,6 +1535,7 @@ cuda_py_test(
"no_rocm",
],
xla_enable_strict_auto_jit = False,
xla_enabled = True,
deps = [
":test_utils",
"//tensorflow/core:protos_all_py",
@ -1558,6 +1560,7 @@ cuda_py_test(
"no_rocm",
],
xla_enable_strict_auto_jit = False,
xla_enabled = True,
deps = [
":test_utils",
"//tensorflow/core:protos_all_py",
@ -1650,6 +1653,7 @@ cuda_py_test(
"no_rocm",
],
xla_enable_strict_auto_jit = False,
xla_enabled = True,
deps = [
":lstm",
":xla_test",


@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
wholly_compiled_f = def_function.function(f)
op_by_op_f = def_function.function(f, experimental_compile=False)
x = constant_op.constant([0.0, 2.0], name='data')
x = array_ops.identity([0.0, 2.0], name='data')
# When function is wholly compiled, all outputs will be on the
# device on which it is run.

Some files were not shown because too many files have changed in this diff.