Merge branch 'master' into nhasabni/fixes_for_dnnl1.0
This commit is contained in:
commit
b96c0010fa
15
.bazelrc
15
.bazelrc
@ -365,13 +365,14 @@ build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe
|
||||
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
|
||||
build:rbe_linux_cuda_nvcc --config=rbe_linux
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010,@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@org_tensorflow//third_party/toolchains:rbe_cuda10.1-cudnn7-ubuntu16.04-manylinux2010"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/cuda10.1-cudnn7"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/tensorrt6.0"
|
||||
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_NEED_TENSORRT=1
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_VERSION=10
|
||||
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDNN_VERSION=7
|
||||
|
@ -70,7 +70,7 @@ $ python
|
||||
3
|
||||
>>> hello = tf.constant('Hello, TensorFlow!')
|
||||
>>> hello.numpy()
|
||||
'Hello, TensorFlow!'
|
||||
b'Hello, TensorFlow!'
|
||||
```
|
||||
|
||||
For more examples, see the
|
||||
|
@ -418,21 +418,11 @@ void TensorHandleSilentCopy(bool async,
|
||||
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
if (!async) {
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
} else {
|
||||
if (cpu_op) {
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
// The GPU handle should be replaced with a CPU copy
|
||||
ASSERT_NE(op->GetInput(1), arg1);
|
||||
} else {
|
||||
// The CPU handle should be replaced with a GPU copy
|
||||
ASSERT_NE(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
}
|
||||
}
|
||||
|
||||
// The input handles should never change since they have been mirrored.
|
||||
EXPECT_EQ(op->GetInput(0), arg0);
|
||||
EXPECT_EQ(op->GetInput(1), arg1);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteTensorHandle(hgpu);
|
||||
|
@ -14,6 +14,10 @@ package_group(
|
||||
includes = [
|
||||
"//tensorflow/compiler/tf2xla:internal",
|
||||
],
|
||||
packages = [
|
||||
"//tensorflow/compiler/tests/...",
|
||||
"//tensorflow/python/...",
|
||||
],
|
||||
)
|
||||
|
||||
package_group(
|
||||
|
@ -46,7 +46,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
|
@ -42,7 +42,7 @@ limitations under the License.
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -165,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
|
||||
// `isSigned` is set to false for other types.
|
||||
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
bool is_signed = true) {
|
||||
if (!is_signed && type.isInteger(8)) {
|
||||
if (!is_signed && type.isSignlessInteger(8)) {
|
||||
return tflite::TensorType_UINT8;
|
||||
}
|
||||
if (!is_signed) {
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
@ -275,7 +275,7 @@ Attribute ConstFoldBinaryOp(
|
||||
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
|
||||
float_calculate, is_commutative);
|
||||
|
||||
if (elemType.isa<IntegerType>())
|
||||
if (elemType.isSignlessInteger())
|
||||
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
|
||||
int_calculate, is_commutative);
|
||||
|
||||
@ -1560,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
|
||||
limit_tensor.getType().getRank() == 0 &&
|
||||
delta_tensor.getType().getRank() == 0);
|
||||
Type elem_type = getType().cast<ShapedType>().getElementType();
|
||||
if (elem_type.isa<IntegerType>()) {
|
||||
if (elem_type.isSignlessInteger()) {
|
||||
auto start_attr = start_tensor.getValue<IntegerAttr>({});
|
||||
auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
|
||||
auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
|
||||
@ -1662,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
// Do not try to fold elements attr of a quant type because
|
||||
// DenseElementsAttr does not support it.
|
||||
if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
|
||||
if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
|
||||
return nullptr;
|
||||
|
||||
assert(perm_tensor.getType().getRank() == 1);
|
||||
|
@ -3100,7 +3100,9 @@ def LstmMandatoryInputsConstraint : PredOpTrait<
|
||||
"mandatory operands element types should match",
|
||||
// TODO(ashwinm): Replace the indices with input tensor names when that
|
||||
// support is available.
|
||||
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>>;
|
||||
Or<[
|
||||
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>,
|
||||
Neg<TypeIsPred<"input", F32>>]>>;
|
||||
|
||||
def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
|
||||
"the optional peephole weights should all be specified or none",
|
||||
@ -3126,7 +3128,7 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait<
|
||||
|
||||
def LstmResultConstraint : PredOpTrait<
|
||||
"the input and result tensor elemental types must be same",
|
||||
TCresVTEtIsSameAsOp<0, 0>>;
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>;
|
||||
|
||||
// This is the basic kernel type LSTM op.
|
||||
// TODO(b/142417845): Refactor this part to return its tflite node name as
|
||||
@ -3195,47 +3197,47 @@ Ba et al. “Layer Normalization”
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32]>:$input,
|
||||
ins TFL_TensorOf<[F32, QI8]>:$input,
|
||||
|
||||
// Weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights,
|
||||
|
||||
// Recurrent weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights,
|
||||
|
||||
// Cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights,
|
||||
|
||||
// Bias
|
||||
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$forget_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$cell_bias,
|
||||
TFL_TensorOf<[F32]>:$output_gate_bias,
|
||||
TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
|
||||
TFL_TensorOf<[F32, QI32]>:$forget_gate_bias,
|
||||
TFL_TensorOf<[F32, QI32]>:$cell_bias,
|
||||
TFL_TensorOf<[F32, QI32]>:$output_gate_bias,
|
||||
|
||||
// Projection weight and bias
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
|
||||
TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32]>:$projection_bias,
|
||||
TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias,
|
||||
|
||||
// Stateful activation and cell states.
|
||||
TFL_StatefulTensor:$input_activation_state,
|
||||
TFL_StatefulTensor:$input_cell_state,
|
||||
|
||||
// Layer norm coefficients
|
||||
TFL_TensorOfOrNone<[F32]>:$input_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32]>:$forget_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32]>:$cell_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32]>:$output_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI16]>:$input_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI16]>:$forget_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI16]>:$cell_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI16]>:$output_layer_norm_coefficients,
|
||||
|
||||
// Attributes
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineMap.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
@ -191,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
auto ele_type = operand.getType().cast<TensorType>().getElementType();
|
||||
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
|
||||
inputs.push_back(op_inst.input());
|
||||
} else if (ele_type.isa<IntegerType>()) {
|
||||
} else if (ele_type.isSignlessInteger()) {
|
||||
// If the operand is an integer tensor, then it doesn't require the
|
||||
// DQ op in the pattern.
|
||||
inputs.push_back(operand);
|
||||
@ -225,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
auto user = llvm::cast<Q>(*result.user_begin());
|
||||
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
||||
output_types.push_back(user.getType());
|
||||
} else if (result_ele_type.template isa<IntegerType>()) {
|
||||
} else if (result_ele_type.isSignlessInteger()) {
|
||||
// If the result is an integer tensor, then it doesn't require the
|
||||
// D op in the pattern.
|
||||
outputs_replaced.insert({result, enumerated_result.index()});
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
|
@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
|
||||
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||
// CHECK-NEXT: return [[RESULT]]
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
|
||||
return %4 : tensor<1x128x128x1xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
|
||||
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||
// CHECK-NEXT: return [[RESULT]]
|
||||
}
|
||||
|
@ -593,6 +593,21 @@ func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>,
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLstmQuantizedType
|
||||
func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 0.01>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
|
||||
%cst = constant unit
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
|
||||
}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
// CHECK: %[[RES0:.*]] = constant unit
|
||||
// CHECK: %[[RES1:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %[[RES0]], %[[RES0]], %[[RES0]], %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ( {
|
||||
// CHECK-NEXT: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<2048x!quant.uniform<i32:f32, 1.000000e-02>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 9.9999999747524271E-7>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
// CHECK: return %[[RES1]]
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLstm
|
||||
|
@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
@ -165,7 +165,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
@ -189,7 +189,7 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
|
||||
@ -200,7 +200,7 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
@ -224,3 +224,84 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
|
||||
// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x?x10xf32>
|
||||
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_21:%.*]] = constant unit
|
||||
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
|
||||
// CHECK: return [[VAL_23:%.*]] : tensor<8x?x10xf32>
|
||||
// CHECK: }
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
|
||||
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
|
||||
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
|
||||
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK: [[VAL_9:%.*]] = "tf.ReverseV2"([[VAL_7]], [[VAL_8]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_10]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_12:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_13:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_12]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_16:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_14]], [[VAL_15]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_19:%.*]]:4 = "tf.SplitV"([[VAL_13]], [[VAL_17]], [[VAL_18]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||
// CHECK: [[VAL_20:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_23:%.*]] = constant unit
|
||||
// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||
// CHECK: return [[VAL_26]] : tensor<8x8x10xf32>
|
||||
// CHECK: }
|
||||
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -80,6 +81,17 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
|
||||
template <typename Conv2dOpTy>
|
||||
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
Conv2dOpTy op, PatternRewriter& rewriter) const {
|
||||
// Make sure Conv2D has 'VALID' padding.
|
||||
if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
// Make sure dilations are all ones if set.
|
||||
const ArrayAttr& dilations =
|
||||
op.template getAttrOfType<ArrayAttr>("dilations");
|
||||
if (dilations && !TFIntListIsAllOnes(dilations)) {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
|
||||
// Check if the ConvOp is preceded by a `Expand` op and succeeded by a
|
||||
// `Squeeze` op.
|
||||
Operation* prev_op = op.getOperation()->getPrevNode();
|
||||
@ -90,6 +102,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
|
||||
TF::ExpandDimsOp expand_op;
|
||||
TF::SqueezeOp squeeze_op;
|
||||
int64_t expand_axis;
|
||||
// Expand + Squeeze op.
|
||||
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
|
||||
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
|
||||
@ -99,6 +112,22 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
|
||||
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
|
||||
|
||||
// Make sure that the axis in `expand_op` is constant.
|
||||
if (auto const_op =
|
||||
llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
|
||||
expand_axis =
|
||||
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
|
||||
.getSExtValue();
|
||||
} else {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
|
||||
auto squeeze_dims = squeeze_op.squeeze_dims();
|
||||
if (squeeze_dims.size() != 1 ||
|
||||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
|
||||
// Update previous/next op pointer.
|
||||
prev_op = prev_op->getPrevNode();
|
||||
if (!prev_op) return Pattern::matchFailure();
|
||||
@ -108,10 +137,14 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
|
||||
// SpaceToBatchND op.
|
||||
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
|
||||
// TODO(b/149936532): Check `padding` input, currently ignored.
|
||||
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
|
||||
|
||||
// Pad op.
|
||||
TF::PadOp pad_op;
|
||||
// TODO(b/149936532): Currently we just ignore the PadOp. However note that
|
||||
// in real scenarios this may not always be correct: user can put a PadOp here
|
||||
// with non-trivial consequences.
|
||||
if (llvm::isa<TF::PadOp>(next_op)) {
|
||||
pad_op = llvm::cast<TF::PadOp>(next_op);
|
||||
next_op = next_op->getNextNode();
|
||||
@ -119,6 +152,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
}
|
||||
|
||||
// BatchToSpaceND + BiasAdd.
|
||||
// TODO(b/149936532): Check the `crops` input, currently ignored.
|
||||
TF::BatchToSpaceNDOp bts_op;
|
||||
TF::BiasAddOp biasadd_op;
|
||||
bool final_op_is_bts = true;
|
||||
@ -146,14 +180,10 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
|
||||
op.setAttr("dilations", dilations_attr.getValue());
|
||||
|
||||
// Here we need to set the correct padding for Conv op. In TF, the conv op
|
||||
// inserted after 'SpaceToBatch' always has 'VALID' padding. This might
|
||||
// become a problem here if the original Conv op has 'SAME' padding. When
|
||||
// the original conv has 'SAME' padding, TF will set a non-zero padding for
|
||||
// the 'SpaceToBatch' op, so we rely on this information to check if we need
|
||||
// to change the padding from 'VALID' to 'SAME' (a.k.a when we see non-zero
|
||||
// values in `stb_op.paddings`, we change the current Conv's padding to
|
||||
// 'SAME').
|
||||
// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
|
||||
// TODO(b/149936532): This assumption only holds when the input width & height
|
||||
// is multiple of dilation width & height. We should fix it in order to
|
||||
// support other use cases.
|
||||
auto stb_paddings = stb_op.paddings();
|
||||
ElementsAttr stb_paddings_attr;
|
||||
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
|
||||
@ -175,7 +205,8 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
|
||||
SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
|
||||
input_shape.end());
|
||||
expand_shape.push_back(1);
|
||||
expand_shape.insert(expand_shape.begin() + expand_axis, 1);
|
||||
|
||||
auto expand_result_type = RankedTensorType::get(
|
||||
expand_shape, getElementTypeOrSelf(stb_op.input()));
|
||||
expand_op.getResult().setType(expand_result_type);
|
||||
@ -208,7 +239,7 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
||||
ElementsAttr stb_bs_attr, bts_bs_attr;
|
||||
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
|
||||
!matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
|
||||
// Returns failure status if block shape is not a constant.
|
||||
// Returns failure status if block_shape is not a constant.
|
||||
return {};
|
||||
}
|
||||
// Check that the block_shape of `stb_op` and `bts_op` are equal.
|
||||
@ -217,9 +248,8 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
||||
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
|
||||
}
|
||||
|
||||
// TODO(haoliang): support 1-D dilated conv.
|
||||
// Set dilation factor.
|
||||
if (stb_bs_attr.getNumElements() < 2) return {};
|
||||
|
||||
int dilation_h_factor =
|
||||
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||
int dilation_w_factor =
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// TFLite legalization patterns
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
@ -341,7 +341,7 @@ def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>;
|
||||
class I32VectorElementsAttr<int len> : ElementsAttrBase<
|
||||
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
|
||||
"$_self.cast<DenseIntElementsAttr>().getType()."
|
||||
"getElementType().isInteger(32)">,
|
||||
"getElementType().isSignlessInteger(32)">,
|
||||
"32-bit int elements attribute of shape [" # len # "]"> {
|
||||
|
||||
let storageType = [{ DenseIntElementsAttr }];
|
||||
|
@ -255,7 +255,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
||||
|
||||
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
||||
if (!shape_type.getElementType().isInteger(32)) {
|
||||
if (!shape_type.getElementType().isSignlessInteger(32)) {
|
||||
auto new_shape = shape_type.getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
// Converts TF While to TFL While with single call in body and cond.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
@ -335,8 +335,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type dtype = op.element_dtype();
|
||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
||||
dtype.isInteger(32) || dtype.isInteger(64))) {
|
||||
dtype.isInteger(1) || dtype.isSignlessInteger(8) ||
|
||||
dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) ||
|
||||
dtype.isSignlessInteger(64))) {
|
||||
op.emitError(
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the optimization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the quantization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
// Both Quantize and Dequantize ops have side effects, so we have to define
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the quantization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
// Quantize attribute $0 by using quantization parameter from %1.
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
|
@ -38,7 +38,7 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
|
||||
|
||||
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
|
||||
if (attr.getType().getNumElements() != 1 ||
|
||||
!attr.getType().getElementType().isa<IntegerType>()) {
|
||||
!attr.getType().getElementType().isSignlessInteger()) {
|
||||
return {};
|
||||
}
|
||||
SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
@ -95,6 +95,14 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
||||
return Transpose(builder, value_to_transpose, perm, type, location);
|
||||
}
|
||||
|
||||
// Builds a TF::ReverseV2Op for `value_to_reverse` and returns the op created
// by `builder`.
//
// builder:          used to create both the axis constant and the ReverseV2 op.
// value_to_reverse: first operand of the ReverseV2 op.
// axis:             forwarded to CreateI32SplatConst to form the second
//                   operand (the axis to reverse along).
// type:             passed as the result type of the ReverseV2 op.
// location:         location attached to the created ops.
Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
|
||||
RankedTensorType type, mlir::Location location) {
|
||||
// NOTE(review): CreateI32SplatConst is defined elsewhere; presumably this
// yields an i32 constant of shape {1} holding `axis` -- confirm.
auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
|
||||
// The result type will be the same as the input.
// NOTE(review): this relies on the caller passing the type of
// `value_to_reverse` as `type`; nothing here checks that.
|
||||
return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
|
||||
axis_op);
|
||||
}
|
||||
|
||||
// Returns the shape of `value`, whose type is cast to RankedTensorType.
//
// The conversion uses cast<> rather than dyn_cast<>, so there is no null
// result to check here; callers are expected to pass a value whose type
// really is a ranked tensor.
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
|
||||
return value.getType().cast<RankedTensorType>().getShape();
|
||||
}
|
||||
@ -615,6 +623,16 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
||||
final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
// Handle go_backwards:
|
||||
// Under Keras semantics, an LSTM with go_backwards set processes the input
// sequence in reverse order.
|
||||
auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards");
|
||||
|
||||
if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
|
||||
// We assume input is already in {time, batch, size} layout.
|
||||
final_inputs =
|
||||
Reverse(builder, final_inputs, 0, final_input_type, func_op.getLoc());
|
||||
}
|
||||
|
||||
int batch = final_input_type.getDimSize(1);
|
||||
int time = final_input_type.getDimSize(0);
|
||||
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
|
@ -90,7 +90,7 @@ gentbl(
|
||||
td_file = "ir/tf_saved_model_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -114,7 +114,7 @@ gentbl(
|
||||
td_file = "ir/tf_executor_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -138,7 +138,7 @@ gentbl(
|
||||
td_file = "ir/tf_device_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -281,6 +281,7 @@ cc_library(
|
||||
"transforms/generated_canonicalize.inc",
|
||||
"transforms/generated_optimize.inc",
|
||||
"transforms/graph_pruning.cc",
|
||||
"transforms/launch_to_device_attribute.cc",
|
||||
"transforms/layout_optimization.cc",
|
||||
"transforms/mark_function_visibility.cc",
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -3111,6 +3111,70 @@ cublas.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> {
|
||||
let summary = [{
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix
|
||||
to zero.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The `band` part is computed as follows:
|
||||
Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
|
||||
tensor with the same shape where
|
||||
|
||||
`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
|
||||
|
||||
The indicator function
|
||||
|
||||
`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
|
||||
(num_upper < 0 || (n-m) <= num_upper)`.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# if 'input' is [[ 0, 1, 2, 3]
|
||||
[-1, 0, 1, 2]
|
||||
[-2, -1, 0, 1]
|
||||
[-3, -2, -1, 0]],
|
||||
|
||||
tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
|
||||
[-1, 0, 1, 2]
|
||||
[ 0, -1, 0, 1]
|
||||
[ 0, 0, -1, 0]],
|
||||
|
||||
tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
|
||||
[-1, 0, 1, 0]
|
||||
[-2, -1, 0, 1]
|
||||
[ 0, -2, -1, 0]]
|
||||
```
|
||||
|
||||
Useful special cases:
|
||||
|
||||
```
|
||||
tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
|
||||
tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
|
||||
tf.matrix_band_part(input, 0, 0) ==> Diagonal.
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
TF_I32OrI64Tensor:$num_lower,
|
||||
TF_I32OrI64Tensor:$num_upper
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$band
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tindex = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns a batched diagonal tensor with a given batched diagonal values.
|
||||
|
@ -85,7 +85,7 @@ class TF_TensorFlowType <string name, string description> :
|
||||
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
def TF_ElementType : Type<Or<[AnyFloat.predicate,
|
||||
AnyInteger.predicate,
|
||||
AnySignlessInteger.predicate,
|
||||
AnyComplex.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
|
@ -35,7 +35,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
@ -1506,6 +1506,29 @@ void LogicalNotOp::getCanonicalizationPatterns(
|
||||
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixBandPartOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifier for MatrixBandPartOp.
//
// Runs three checks in order and stops at the first failure, emitting an op
// error that includes the offending operand's type:
//   1. `input` must satisfy HasRankAtLeast(..., 2).
//   2. `num_lower` must satisfy IsOfRankOrUnranked(..., 0).
//   3. `num_upper` must satisfy IsOfRankOrUnranked(..., 0).
// Returns success() when all three pass.
//
// NOTE(review): HasRankAtLeast and IsOfRankOrUnranked are defined elsewhere;
// presumably both accept unranked tensors -- confirm.
static LogicalResult Verify(MatrixBandPartOp op) {
|
||||
// The op works on the innermost matrices of `input`, so fewer than two
// dimensions is rejected.
if (!HasRankAtLeast(op.input(), 2)) {
|
||||
return op.emitOpError()
|
||||
<< "requires `input` to have rank of at least 2, but found "
|
||||
<< op.input().getType();
|
||||
}
|
||||
// `num_lower` must be rank 0 (or unranked).
if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
|
||||
return op.emitOpError()
|
||||
<< "requires `num_lower` to have 0 dimensions, but found "
|
||||
<< op.num_lower().getType();
|
||||
}
|
||||
// Same rank requirement for `num_upper`.
if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
|
||||
return op.emitOpError()
|
||||
<< "requires `num_upper` to have 0 dimensions, but found "
|
||||
<< op.num_upper().getType();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2104,7 +2127,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
|
||||
}
|
||||
|
||||
Type element_type = result_ranked_type.getElementType();
|
||||
if (!element_type.isInteger(32) && !element_type.isInteger(64))
|
||||
if (!element_type.isSignlessInteger(32) &&
|
||||
!element_type.isSignlessInteger(64))
|
||||
return op->emitOpError("requires int32 or int64 return type for result")
|
||||
<< variadic_idx_str;
|
||||
|
||||
|
@ -91,7 +91,7 @@ class TensorFlowType : public Type {
|
||||
// Returns true if the specified type is a valid TensorFlow element type.
|
||||
static inline bool IsValidTFElementType(Type type) {
|
||||
return type.isa<ComplexType>() || type.isa<FloatType>() ||
|
||||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
|
||||
type.isSignlessInteger() || type.isa<TensorFlowType>();
|
||||
}
|
||||
|
||||
// Returns true if this is a valid TensorFlow tensor type.
|
||||
|
@ -0,0 +1,123 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute | FileCheck %s --dump-input=fail
|
||||
|
||||
|
||||
// Tests single TensorFlow op is hoisted out and has the correct device assigned
|
||||
// by parent `tf_device.launch`.
|
||||
// CHECK-LABEL: func @single_op_launch
|
||||
func @single_op_launch() {
|
||||
tf_executor.graph {
|
||||
%0:5 = tf_executor.island {
|
||||
%a = "tf.opA"() : () -> tensor<i1>
|
||||
%launch:2 = "tf_device.launch"() ( {
|
||||
%b:2 = "tf.opB"(%a) : (tensor<i1>) -> (tensor<i32>, tensor<f32>)
|
||||
tf_device.return %b#1, %b#0 : tensor<f32>, tensor<i32>
|
||||
}) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
|
||||
%c = "tf.opC"() : () -> tensor<i1>
|
||||
tf_executor.yield %a, %launch#0, %launch#1, %c : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[A:.*]] = "tf.opA"
|
||||
// CHECK: %[[B:.*]]:2 = "tf.opB"(%[[A]])
|
||||
// CHECK-SAME: device = "CPU:0"
|
||||
// CHECK: %[[C:.*]] = "tf.opC"
|
||||
// CHECK-NOT: "tf_device.launch"
|
||||
// CHECK: tf_executor.yield %[[A]], %[[B]]#1, %[[B]]#0, %[[C]]
|
||||
|
||||
|
||||
// Tests multiple TensorFlow ops are hoisted out and all have the correct device
|
||||
// assigned by parent `tf_device.launch`.
|
||||
// CHECK-LABEL: func @multi_op_launch
|
||||
func @multi_op_launch() {
|
||||
tf_executor.graph {
|
||||
%0:5 = tf_executor.island {
|
||||
%a = "tf.opA"() : () -> tensor<i1>
|
||||
%launch:2 = "tf_device.launch"() ( {
|
||||
%b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32>
|
||||
%c = "tf.opC"(%b) : (tensor<i32>) -> tensor<f32>
|
||||
tf_device.return %c, %b : tensor<f32>, tensor<i32>
|
||||
}) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
|
||||
%d = "tf.opD"() : () -> tensor<i1>
|
||||
tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[A:.*]] = "tf.opA"
|
||||
// CHECK: %[[B:.*]] = "tf.opB"(%[[A]])
|
||||
// CHECK-SAME: device = "CPU:0"
|
||||
// CHECK: %[[C:.*]] = "tf.opC"(%[[B]])
|
||||
// CHECK-SAME: device = "CPU:0"
|
||||
// CHECK: %[[D:.*]] = "tf.opD"
|
||||
// CHECK-NOT: "tf_device.launch"
|
||||
// CHECK: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]]
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Tests ops are hoisted out and devices are set only if the `tf_device.launch`
|
||||
// contains TensorFlow ops.
|
||||
func @non_tf_dialect_op_launch() {
|
||||
tf_executor.graph {
|
||||
%0:5 = tf_executor.island {
|
||||
%a = "tf.opA"() : () -> tensor<i1>
|
||||
// expected-error@+1 {{'tf_device.launch' op must contain only 'tf' dialect ops}}
|
||||
%launch:2 = "tf_device.launch"() ( {
|
||||
%b = "tf.opB"(%a) : (tensor<i1>) -> tensor<i32>
|
||||
%c = "unknown.opC"(%b) : (tensor<i32>) -> tensor<f32>
|
||||
tf_device.return %c, %b : tensor<f32>, tensor<i32>
|
||||
}) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
|
||||
%d = "tf.opD"() : () -> tensor<i1>
|
||||
tf_executor.yield %a, %launch#0, %launch#1, %d : tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Tests TensorFlow op with conflicting `device` attribute compared to parent
|
||||
// `tf_device.launch`.
|
||||
func @conflicting_device() {
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
// expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has conflicting 'device' attribute, got 'GPU:0' but expected 'CPU:0'}}
|
||||
"tf_device.launch"() ( {
|
||||
"tf.opA"() {device = "GPU:0"} : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "CPU:0"} : () -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Tests TensorFlow op with bad `device` attribute already set.
|
||||
func @bad_tf_device_attr() {
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
// expected-error@+1 {{'tf_device.launch' op inner 'tf' dialect op has bad 'device' attribute}}
|
||||
"tf_device.launch"() ( {
|
||||
"tf.opA"() {device = 0 : i32} : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "CPU:0"} : () -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
@ -854,6 +854,78 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOp
|
||||
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOp3D
|
||||
func @testValidMatrixBandPartOp3D(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64x64xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64x64xbf16>
|
||||
return %0 : tensor<64x64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOpUnranked
|
||||
func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||
return %0 : tensor<*xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||
return %0 : tensor<*xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<i64>, %arg1: tensor<64x64xi64>, %arg2: tensor<i64>) -> tensor<i64> {
|
||||
// expected-error @+1 {{op requires `input` to have rank of at least 2, but found 'tensor<i64>'}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<i64>, tensor<64x64xi64>, tensor<i64>) -> tensor<i64>
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<32xi64>, %arg2: tensor<i64>) -> tensor<64x64xi64> {
|
||||
// expected-error @+1 {{op requires `num_lower` to have 0 dimensions, but found 'tensor<32xi64>'}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<32xi64>, tensor<i64>) -> tensor<64x64xi64>
|
||||
return %0 : tensor<64x64xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<i64>, %arg2: tensor<32xi64>) -> tensor<64x64xi64> {
|
||||
// expected-error @+1 {{op requires `num_upper` to have 0 dimensions, but found 'tensor<32xi64>'}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<i64>, tensor<32xi64>) -> tensor<64x64xi64>
|
||||
return %0 : tensor<64x64xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.{|Stateful}PartitionedCall
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -0,0 +1,220 @@
|
||||
// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Immutability.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: This test exercises marking a global tensor as immutable after it propagates
|
||||
// via set of chained calls -> f -> f_callee -> f_callee_callee
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case:
|
||||
// This test exercises trying to mark immutable when same func is called by multiple callers
|
||||
// with different global tensors.
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v2", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
func @f2(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v2}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f2"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_common} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: This test exercises immutability without explicit use
|
||||
// via ReadVariableOp
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
%val_2 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val_2 : tensor<f32>
|
||||
}
|
||||
|
||||
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%cst_1 = constant dense<2.0> : tensor<f32>
|
||||
return %cst_1 : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test case: Test mutation detection propagates across function calls
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
return %c0 : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: The inter-procedural analysis with different types of
|
||||
// TF call ops
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
return %c0 : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: The inter-procedural analysis does not recurse infinitely
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["exported_f"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module attributes {tf_saved_model.semantics}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Inter-procedural analysis with resource usage in an
|
||||
// unknown op, we assume mutating behavior and propagate that.
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @exported_f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["exported_f"]} {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
"tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
return %c0 : tensor<f32>
|
||||
}
|
||||
}
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
// `tf_device.launch` with equivalent `tf_device.launch_func` operations.
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
// Here, the element type can be any integer or float type. But, note that only
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Visitors.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This transformation pass transforms functional control flow operations in the
|
||||
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
|
@ -0,0 +1,135 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This pass hoists a `tf_device.launch` body and assigns a `device` attribute
|
||||
// to each TensorFlow dialect op in the body based on the `device` attribute on
|
||||
// the `tf_device.launch`. If a TensorFlow dialect op already has a `device`
|
||||
// attribute, it must be a string attribute equal to the `tf_device.launch`
|
||||
// device; otherwise the pass emits an error and fails.
|
||||
//
|
||||
// For example:
|
||||
// %island:5 = tf_executor.island {
|
||||
// %a = "tf.opA"() : () -> tensor<i1>
|
||||
// %launch:2 = "tf_device.launch"() ( {
|
||||
// %b = "tf.opB"() : () -> tensor<i32>
|
||||
// %c = "tf.opC"() : () -> tensor<f32>
|
||||
// tf_device.return %c, %b : tensor<f32>, tensor<i32>
|
||||
// }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
|
||||
// %d = "tf.opD"() : () -> tensor<i1>
|
||||
// tf_executor.yield %a, %launch#0, %launch#1, %d :
|
||||
// tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
|
||||
// }
|
||||
//
|
||||
// Will be transformed into:
|
||||
// %island:5 = tf_executor.island {
|
||||
// %a = "tf.opA"() : () -> tensor<i1>
|
||||
// %b = "tf.opB"() {device = "CPU:0"} : () -> tensor<i32>
|
||||
// %c = "tf.opC"() {device = "CPU:0"} : () -> tensor<f32>
|
||||
// %d = "tf.opD"() : () -> tensor<i1>
|
||||
// tf_executor.yield %a, %c, %b, %d :
|
||||
// tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
|
||||
// }
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Dialect.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/Visitors.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFDevice {
|
||||
namespace {
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
|
||||
// Function pass that, for every `tf_device.launch` in the function, hoists the
// launch body into the parent block and annotates the hoisted ops with the
// launch's `device` attribute (see HoistOpsAndAnnotateWithDevice).
struct LaunchToDeviceAttributePass
|
||||
: public FunctionPass<LaunchToDeviceAttributePass> {
|
||||
// Entry point invoked by the pass manager for each function.
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Hoists the ops in the body of `launch` into the block containing `launch`,
// assigns the launch device to every hoisted op as its `device` attribute, and
// erases `launch`.
//
// Fails (emitting an error on `launch`) if the body contains an op that is not
// in `tf_dialect`, or an op whose existing `device` attribute is not a string
// or differs from the launch device. All checks run before any IR is modified,
// so a failure leaves the launch op and its uses untouched.
LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect,
                                            tf_device::LaunchOp launch) {
  auto body = launch.GetBody().without_terminator();

  // Phase 1: validate every inner op without mutating anything.
  for (Operation& op : body) {
    if (op.getDialect() != tf_dialect)
      return launch.emitOpError() << "must contain only 'tf' dialect ops";

    auto device_attr = op.getAttr(kDeviceAttr);
    // Ops without a device attribute are annotated in phase 2.
    if (!device_attr) continue;

    auto device_str_attr = device_attr.dyn_cast<StringAttr>();
    if (!device_str_attr)
      return launch.emitOpError()
             << "inner 'tf' dialect op has bad 'device' attribute";

    if (launch.device() != device_str_attr.getValue())
      return launch.emitOpError()
             << "inner 'tf' dialect op has conflicting 'device' attribute, "
                "got '"
             << device_str_attr.getValue() << "' but expected '"
             << launch.device() << "'";
  }

  // Phase 2: assign the launch device to inner ops that have no `device`
  // attribute yet (ops that already have one were verified to match).
  for (Operation& op : body) {
    if (!op.getAttr(kDeviceAttr)) op.setAttr(kDeviceAttr, launch.deviceAttr());
  }

  // Forward launch inner op results to launch op results.
  launch.replaceAllUsesWith(launch.GetBody().getTerminator()->getOperands());

  // Move all inner ops of the launch to the block containing the launch,
  // placing them immediately before the launch op.
  Operation* launch_op = launch.getOperation();
  launch_op->getBlock()->getOperations().splice(
      launch_op->getIterator(), launch.GetBody().getOperations(), body.begin(),
      body.end());

  launch.erase();

  return success();
}
|
||||
|
||||
void LaunchToDeviceAttributePass::runOnFunction() {
|
||||
const Dialect* tf_dialect = getContext().getRegisteredDialect("tf");
|
||||
if (!tf_dialect) {
|
||||
signalPassFailure();
|
||||
getFunction().emitError() << "'tf' dialect is not registered";
|
||||
}
|
||||
|
||||
auto result = getFunction().walk([&](tf_device::LaunchOp launch) {
|
||||
if (failed(HoistOpsAndAnnotateWithDevice(tf_dialect, launch)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (result.wasInterrupted()) return signalPassFailure();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Factory returning a new, heap-allocated LaunchToDeviceAttributePass owned by
// the caller.
std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass() {
|
||||
return std::make_unique<LaunchToDeviceAttributePass>();
|
||||
}
|
||||
|
||||
static PassRegistration<LaunchToDeviceAttributePass> pass(
|
||||
"tf-launch-to-device-attribute",
|
||||
"Hoists and annotates device launch inner ops with associated device "
|
||||
"attribute");
|
||||
|
||||
} // namespace TFDevice
|
||||
} // namespace mlir
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
// Here, the element type can be any integer or float type. But, note that only
|
||||
@ -122,7 +122,7 @@ def LowerSparseSoftmaxCrossEntropyWithLogitsOp : Pattern<
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ComplexTensor : TensorOf<[AnyComplex]>;
|
||||
def RealTensor : TensorOf<[AnyInteger, AnyFloat]>;
|
||||
def RealTensor : TensorOf<[AnySignlessInteger, AnyFloat]>;
|
||||
|
||||
def : Pat<(TF_SquareOp $val), (TF_MulOp $val, $val)>;
|
||||
|
||||
@ -179,7 +179,7 @@ def LowerL2LossOp :
|
||||
// Pad op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings),
|
||||
def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings),
|
||||
(TF_PadV2Op $input, $paddings,
|
||||
(TF_ConstOp (GetScalarOfType<0> $input)))>;
|
||||
|
||||
@ -224,6 +224,6 @@ def CreateTFShapeOp : NativeCodeCall<
|
||||
|
||||
// TODO(hinsu): Support inputs of TensorList types.
|
||||
def LowerZerosLikeOp :
|
||||
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnyInteger, AnyFloat]>:$input),
|
||||
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
|
||||
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
|
||||
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <iostream>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
|
||||
|
@ -15,16 +15,27 @@ limitations under the License.
|
||||
|
||||
// This pass optimizes tf_saved_model.global_tensor ops.
|
||||
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
@ -45,8 +56,142 @@ struct GlobalTensorUse {
|
||||
using GlobalTensorUsesMap =
|
||||
std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
|
||||
|
||||
// Returns true iff `type` is a tensor type whose element type is a
// TF::ResourceType; any non-tensor type yields false.
static bool IsResourceType(Type type) {
  auto as_tensor = type.dyn_cast<TensorType>();
  if (!as_tensor) return false;
  return as_tensor.getElementType().isa<TF::ResourceType>();
}
|
||||
|
||||
// Returns true iff the type of `value` satisfies IsResourceType.
static bool IsResource(Value value) {
  Type value_type = value.getType();
  return IsResourceType(value_type);
}
|
||||
|
||||
class ResourceAnalyzer {
|
||||
public:
|
||||
explicit ResourceAnalyzer(ModuleOp module) {
|
||||
SymbolTable symbol_table(module);
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
AnalyzeFunc(func, symbol_table);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsPotentiallyWritten(Value resource) const {
|
||||
assert(IsResource(resource));
|
||||
auto it = resource_infos_.find(resource);
|
||||
if (it == resource_infos_.end()) {
|
||||
return false;
|
||||
}
|
||||
return it->second.potentially_written;
|
||||
}
|
||||
|
||||
private:
|
||||
// Analyze the specified func for resource mutating operations, namely
|
||||
// TF::AssignVariableOp, if so, set the resource associated as "potentially
|
||||
// written". Do this recursively across the chain of funcs via call or control
|
||||
// flow ops.
|
||||
// TODO(ashwinm): Move to iterative traversal.
|
||||
LogicalResult AnalyzeFunc(FuncOp func, const SymbolTable& symbol_table) {
|
||||
// Avoid infinite recursion.
|
||||
if (!discovered_.insert(func).second) {
|
||||
return success();
|
||||
}
|
||||
|
||||
func.walk([&](Operation* op) {
|
||||
if (isa<TF::ReadVariableOp>(op) || isa<ReturnOp>(op)) {
|
||||
return;
|
||||
}
|
||||
if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
|
||||
SetPotentiallyWritten(assign_variable.resource());
|
||||
return;
|
||||
}
|
||||
if (auto call = dyn_cast<CallOpInterface>(op)) {
|
||||
if (auto sym = op->getAttrOfType<SymbolRefAttr>("f")) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(
|
||||
sym.cast<FlatSymbolRefAttr>().getValue(), call.getArgOperands(),
|
||||
symbol_table);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto if_op = dyn_cast<TF::IfOp>(op)) {
|
||||
for (auto callee : {if_op.then_branch(), if_op.else_branch()}) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input(),
|
||||
symbol_table);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
|
||||
for (auto callee : {while_op.cond(), while_op.body()}) {
|
||||
PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input(),
|
||||
symbol_table);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// For all other ops, we assume it mutates all resources it uses, so
|
||||
// this errs on the side of being conservative. We should improve
|
||||
// this by using either a property or a trait that clearly
|
||||
// identifies ops with resource mutating behavior.
|
||||
if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) {
|
||||
return;
|
||||
}
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
// If an op is not one of the handled ones, we assume all resource usages
|
||||
// within its purview are mutating in nature.
|
||||
bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
if (IsResource(operand)) {
|
||||
SetPotentiallyWritten(operand);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
bool uses_resources = false;
|
||||
visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
|
||||
if (IsResource(operand->get())) {
|
||||
SetPotentiallyWritten(operand->get());
|
||||
uses_resources = true;
|
||||
}
|
||||
});
|
||||
return uses_resources;
|
||||
}
|
||||
|
||||
// Given a funcOp associated with the callee and operands from the
|
||||
// corresponding callOp, propagate the potentially written decision to the
|
||||
// callOp's operands, if the corresponding func's arguments are potentially
|
||||
// written resources.
|
||||
void PropagatePotentiallyWrittenUpFromCallee(
|
||||
StringRef callee, Operation::operand_range propagate_to,
|
||||
const SymbolTable& symbol_table) {
|
||||
auto func = symbol_table.lookup<FuncOp>(callee);
|
||||
AnalyzeFunc(func, symbol_table);
|
||||
for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
|
||||
if (!IsResource(std::get<0>(t))) {
|
||||
continue;
|
||||
}
|
||||
if (IsPotentiallyWritten(std::get<0>(t))) {
|
||||
SetPotentiallyWritten(std::get<1>(t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetPotentiallyWritten(Value resource) {
|
||||
assert(IsResource(resource));
|
||||
resource_infos_[resource].potentially_written = true;
|
||||
}
|
||||
struct ResourceInfo {
|
||||
bool potentially_written = false;
|
||||
};
|
||||
// Key: Resource Value's
|
||||
// Value: Information we know about that Value.
|
||||
// Note that these Value's are in general in different functions.
|
||||
DenseMap<Value, ResourceInfo> resource_infos_;
|
||||
// The set of func's we already discovered.
|
||||
DenseSet<FuncOp> discovered_;
|
||||
};
|
||||
|
||||
bool IsImmutable(GlobalTensorOp global_tensor,
|
||||
ArrayRef<GlobalTensorUse> global_tensor_uses) {
|
||||
ArrayRef<GlobalTensorUse> global_tensor_uses,
|
||||
const ResourceAnalyzer& resource_analyzer) {
|
||||
// Global tensor is already known to be immutable.
|
||||
if (!global_tensor.is_mutable()) {
|
||||
return false;
|
||||
@ -57,17 +202,11 @@ bool IsImmutable(GlobalTensorOp global_tensor,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check the uses to see if this global tensor is only used in a way that
|
||||
// is compatible with being immutable.
|
||||
// Right now, this uses a very simple algorithm that only checks the top-level
|
||||
// func for tf.ReadVariableOp. If the resource is passed into other functions
|
||||
// or control flow, we fail to prove it is freezable even though we could.
|
||||
// A global tensor is immutable if the resource analyzer deems it so.
|
||||
for (auto& global_tensor_use : global_tensor_uses) {
|
||||
auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (!isa<TF::ReadVariableOp>(user)) {
|
||||
return false;
|
||||
}
|
||||
if (resource_analyzer.IsPotentiallyWritten(arg)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
@ -96,11 +235,12 @@ static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
|
||||
// Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
|
||||
// can prove it is safe to do so.
|
||||
void MarkGlobalTensorsImmutable(
|
||||
ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map) {
|
||||
ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
|
||||
const ResourceAnalyzer& resource_analyzer) {
|
||||
for (const auto& kv : global_tensor_uses_map) {
|
||||
auto global_tensor = kv.first;
|
||||
const auto& global_tensor_uses = kv.second;
|
||||
if (IsImmutable(global_tensor, global_tensor_uses)) {
|
||||
if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
|
||||
global_tensor.removeAttr("is_mutable");
|
||||
}
|
||||
}
|
||||
@ -137,17 +277,15 @@ void EraseUnusedBoundInputs(ModuleOp module) {
|
||||
}
|
||||
|
||||
void OptimizeGlobalTensorsPass::runOnModule() {
|
||||
// This analysis could be much more elaborate, including tracking global
|
||||
// tensors interprocedurally and uses in a wide variety of ops. But I don't
|
||||
// know if we need that complexity.
|
||||
auto module = getModule();
|
||||
|
||||
EraseUnusedBoundInputs(module);
|
||||
|
||||
// Figure out which func's use each tf_saved_model.global_tensor.
|
||||
ResourceAnalyzer resource_analyzer(module);
|
||||
|
||||
GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
|
||||
|
||||
MarkGlobalTensorsImmutable(module, global_tensor_uses);
|
||||
MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
|
||||
|
||||
EraseUnusedGlobalTensors(module, global_tensor_uses);
|
||||
}
|
||||
|
||||
|
@ -186,6 +186,10 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateParallelExecuteToIslandsPass();
|
||||
// same data across replicas.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateAnnotateParameterReplicationPass();
|
||||
|
||||
// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
|
||||
// attribute to each TensorFlow dialect op in the body based on the `device`
|
||||
// attribute on the `tf_device.launch`.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateLaunchToDeviceAttributePass();
|
||||
} // namespace TFDevice
|
||||
|
||||
namespace TFTPU {
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -42,7 +42,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Analysis/Verifier.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // TF:llvm-project
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
|
||||
|
@ -1130,7 +1130,7 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
|
||||
// Illegal attributes.
|
||||
ShapedType attr_ty = start_indices.getType();
|
||||
if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank ||
|
||||
!attr_ty.getElementType().isInteger(64) ||
|
||||
!attr_ty.getElementType().isSignlessInteger(64) ||
|
||||
limit_indices.getType() != attr_ty || strides.getType() != attr_ty)
|
||||
return ty;
|
||||
|
||||
|
@ -52,7 +52,7 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>;
|
||||
|
||||
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
|
||||
|
||||
def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger, AnyComplex]>;
|
||||
def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||
|
||||
def HLO_ComplexTensor : TensorOf<[AnyComplex]>;
|
||||
|
||||
@ -64,13 +64,13 @@ def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
|
||||
// an index type (as it stores indices) but that is currently disallowed in
|
||||
// MLIR.
|
||||
def HLO_DimensionTensor : ShapedContainerType<
|
||||
[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
[AnySignlessInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
"a 1D tensor of dimensions">;
|
||||
|
||||
// In general, static shaped tensor constraints should be avoided unless
|
||||
// it is for a legacy op which is only correct with static shapes.
|
||||
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
|
||||
AnyFloat, AnyInteger, AnyComplex]>;
|
||||
AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA combined type definitions.
|
||||
@ -784,7 +784,7 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor",
|
||||
compute shape arguments to dynamic operations.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyInteger>);
|
||||
let arguments = (ins Variadic<AnySignlessInteger>);
|
||||
let results = (outs HLO_DimensionTensor);
|
||||
|
||||
// Cannot be exported to legacy formats.
|
||||
|
@ -40,7 +40,7 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
||||
|
||||
// Handle integer elements.
|
||||
Attribute elementAttr;
|
||||
if (valElementType.isa<IntegerType>())
|
||||
if (valElementType.isSignlessInteger())
|
||||
elementAttr = b->getIntegerAttr(valElementType, constant);
|
||||
else if (valElementType.isa<FloatType>())
|
||||
elementAttr = b->getFloatAttr(valElementType, constant);
|
||||
|
@ -42,7 +42,7 @@ def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||
// Any integer or floating-point tensor types
|
||||
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||
|
||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger]>;
|
||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger]>;
|
||||
|
||||
def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;
|
||||
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -1221,7 +1221,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
|
||||
<< "requires arg " << padding_arg_index
|
||||
<< " to be a scalar for use as a dynamic parameter";
|
||||
|
||||
if (!mlir::getElementTypeOrSelf(padding_arg_type).isa<IntegerType>())
|
||||
if (!mlir::getElementTypeOrSelf(padding_arg_type).isSignlessInteger())
|
||||
return entry_func.emitError()
|
||||
<< "requires arg " << padding_arg_index
|
||||
<< " to be of an int type for use as a dynamic parameter";
|
||||
|
@ -210,7 +210,6 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[OPERAND_MPA:.*]] = affine_map<(d0, d1, d2) -> (0)>
|
||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-LABEL: func @broadcast_scalar
|
||||
func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) {
|
||||
@ -219,9 +218,10 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<7x10x6xf32>) {
|
||||
: (memref<f32>, memref<7x10x6xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND:[a-zA-Z0-9_]*]]: f32, %[[RESULT:.*]]: f32):
|
||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
|
||||
// CHECK-NEXT: ^bb0(%[[RESULT:.*]]: f32):
|
||||
// CHECK-NEXT: %[[CONST:.*]] = load %{{.*}} : memref<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This file implements logic for lowering HLO dialect to LHLO dialect.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
@ -119,7 +119,7 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
|
||||
Type GetSumAccumulationType(Type input_type) {
|
||||
MLIRContext *ctx = input_type.getContext();
|
||||
if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
|
||||
if (input_type.isInteger(8) || input_type.isInteger(16))
|
||||
if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
|
||||
return IntegerType::get(32, ctx);
|
||||
return input_type;
|
||||
}
|
||||
@ -1274,7 +1274,7 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
|
||||
PatternRewriter &rewriter) const override {
|
||||
Type element_type =
|
||||
op.input().getType().cast<TensorType>().getElementType();
|
||||
if (!element_type.isIntOrFloat()) return matchFailure();
|
||||
if (!element_type.isSignlessIntOrFloat()) return matchFailure();
|
||||
Location loc = op.getLoc();
|
||||
ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
|
||||
|
||||
@ -2248,7 +2248,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
|
||||
Type input_element_type = input_type.getElementType();
|
||||
// TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
|
||||
// tf.ArgMax doesn't support complex data types, this check can be removed.
|
||||
if (!input_element_type.isIntOrFloat()) return this->matchFailure();
|
||||
if (!input_element_type.isSignlessIntOrFloat()) return this->matchFailure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value init_value =
|
||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the legalization pattern definition file for TF to XLA.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
@ -504,8 +504,8 @@ def : Pat<(TF_SignOp $x),
|
||||
)>;
|
||||
|
||||
def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
|
||||
"getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
|
||||
"getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
|
||||
"getElementTypeOrSelf($0.getType()).isSignlessIntOrFloat() && "
|
||||
"getElementTypeOrSelf($1.getType()).isSignlessIntOrFloat() && "
|
||||
"getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
|
||||
"getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
|
||||
"element types must be integers or floats of same width">;
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This file implements logic for lowering XLA dialect to Standard dialect.
|
||||
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
@ -45,8 +45,8 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
||||
// Broadcasting not supported by this rewrite.
|
||||
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
|
||||
|
||||
if (!lhs_type.getElementType().isa<IntegerType>() ||
|
||||
!rhs_type.getElementType().isa<IntegerType>())
|
||||
if (!lhs_type.getElementType().isSignlessInteger() ||
|
||||
!rhs_type.getElementType().isSignlessInteger())
|
||||
return matchFailure();
|
||||
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
@ -113,7 +113,8 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto output_type = op.getType().cast<ShapedType>();
|
||||
// TODO(prakalps): Handle FP and ComplexType iota ops.
|
||||
if (!output_type.getElementType().isa<IntegerType>()) return matchFailure();
|
||||
if (!output_type.getElementType().isSignlessInteger())
|
||||
return matchFailure();
|
||||
auto output_size = output_type.getNumElements();
|
||||
auto dimension = op.iota_dimension().getSExtValue();
|
||||
auto max_dim_size = output_type.getDimSize(dimension);
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the legalization pattern definition file for XLA to StandardOps.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/LoopOps/LoopOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
// equivalent real value operations.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||
|
||||
@ -206,7 +206,7 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = lhs.getType();
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
if (element_type.isSignlessInteger()) {
|
||||
Optional<CmpIPredicate> predicate =
|
||||
getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
@ -288,8 +288,8 @@ template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
const Type& sourceType = args.front().getType();
|
||||
const Type& targetType = result_types.front();
|
||||
Type sourceType = args.front().getType();
|
||||
Type targetType = result_types.front();
|
||||
|
||||
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
|
||||
return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
|
||||
@ -307,7 +307,7 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
// No conversion is needed for the same width floats
|
||||
return args.front();
|
||||
}
|
||||
if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
|
||||
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
|
||||
IntegerType src = sourceType.cast<IntegerType>();
|
||||
IntegerType res = targetType.cast<IntegerType>();
|
||||
if (src.getWidth() > res.getWidth()) {
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
@ -83,7 +83,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||
emitError(loc, "lhlo to linalg conversion expects ranked args");
|
||||
return ConversionPattern::matchFailure();
|
||||
}
|
||||
if (!argType.getElementType().isIntOrFloat()) {
|
||||
if (!argType.getElementType().isSignlessIntOrFloat()) {
|
||||
return ConversionPattern::matchFailure();
|
||||
}
|
||||
|
||||
@ -171,7 +171,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||
auto loc = lhlo_op.getLoc();
|
||||
auto argType =
|
||||
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||
if (!argType || !argType.getElementType().isIntOrFloat() ||
|
||||
if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
|
||||
(argType.getRank() != 0)) {
|
||||
return ConversionPattern::matchFailure();
|
||||
}
|
||||
@ -208,6 +208,9 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
||||
auto resultType = getXLAOpResultType<isLHLO>(op);
|
||||
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op))
|
||||
return ConversionPattern::matchFailure();
|
||||
// TODO(b/150203558) Enable once tiling/fusion works in this case.
|
||||
if (isLHLO && (operandType.getRank() == 0))
|
||||
return ConversionPattern::matchFailure();
|
||||
ArrayAttr indexingMapsAttr =
|
||||
static_cast<const Derived&>(*this).getIndexingMapsAttr(op, &rewriter);
|
||||
if (!indexingMapsAttr) return ConversionPattern::matchFailure();
|
||||
@ -278,6 +281,52 @@ class BroadcastInDimConverter
|
||||
}
|
||||
};
|
||||
|
||||
// Special case for scalar broadcast in lhlo.
|
||||
// TODO(b/150203558) Remove once the bug is fixed.
|
||||
class ScalarBroadcastInDimConverter
|
||||
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
xla_lhlo::BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto operandMemrefType =
|
||||
broadcastOp.operand().getType().dyn_cast<MemRefType>();
|
||||
// Only support scalar operands.
|
||||
if (operandMemrefType.getRank() != 0) return matchFailure();
|
||||
auto resultMemrefType =
|
||||
broadcastOp.output().getType().dyn_cast<MemRefType>();
|
||||
if (!operandMemrefType || !resultMemrefType) return matchFailure();
|
||||
auto broadcastDims = broadcastOp.broadcast_dimensions();
|
||||
if (!broadcastDims.hasValue()) return matchFailure();
|
||||
|
||||
unsigned nloops = resultMemrefType.getRank();
|
||||
SmallVector<Attribute, 1> indexingMaps{
|
||||
AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops))};
|
||||
auto loc = broadcastOp.getLoc();
|
||||
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, ArrayRef<Type>{}, broadcastOp.output(),
|
||||
rewriter.getI64IntegerAttr(0), // args_in
|
||||
rewriter.getI64IntegerAttr(1), // args_out
|
||||
rewriter.getArrayAttr(indexingMaps),
|
||||
GetNParallelLoopsAttrs(nloops, &rewriter),
|
||||
/*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);
|
||||
|
||||
// Add a block to the region.
|
||||
auto* region = &linalgOp.region();
|
||||
auto* block = rewriter.createBlock(region, region->end());
|
||||
block->addArguments(resultMemrefType.getElementType());
|
||||
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
auto scalar =
|
||||
rewriter.create<LoadOp>(loc, broadcastOp.operand(), llvm::None);
|
||||
rewriter.create<linalg::YieldOp>(loc, scalar.getResult());
|
||||
rewriter.eraseOp(broadcastOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class TransposeConverter
|
||||
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||
@ -385,7 +434,7 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
||||
if (!resultMemrefType) return matchFailure();
|
||||
|
||||
auto resultElementType = resultMemrefType.getElementType();
|
||||
if (!resultElementType.isIntOrFloat()) return matchFailure();
|
||||
if (!resultElementType.isSignlessIntOrFloat()) return matchFailure();
|
||||
|
||||
// Construct the indexing maps needed for linalg.generic ops.
|
||||
unsigned nloops = resultMemrefType.getRank();
|
||||
@ -502,6 +551,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||
ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
|
||||
ScalarBroadcastInDimConverter,
|
||||
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
|
||||
SliceConverter
|
||||
>(context);
|
||||
|
@ -1514,6 +1514,7 @@ cuda_py_test(
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
],
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enabled = True,
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
@ -1534,6 +1535,7 @@ cuda_py_test(
|
||||
"no_rocm",
|
||||
],
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enabled = True,
|
||||
deps = [
|
||||
":test_utils",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
@ -1558,6 +1560,7 @@ cuda_py_test(
|
||||
"no_rocm",
|
||||
],
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enabled = True,
|
||||
deps = [
|
||||
":test_utils",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
@ -1650,6 +1653,7 @@ cuda_py_test(
|
||||
"no_rocm",
|
||||
],
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enabled = True,
|
||||
deps = [
|
||||
":lstm",
|
||||
":xla_test",
|
||||
|
@ -695,7 +695,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
wholly_compiled_f = def_function.function(f)
|
||||
op_by_op_f = def_function.function(f, experimental_compile=False)
|
||||
|
||||
x = constant_op.constant([0.0, 2.0], name='data')
|
||||
x = array_ops.identity([0.0, 2.0], name='data')
|
||||
|
||||
# When function is wholly compiled, all outputs will be on the
|
||||
# device on which it is run.
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user