diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 0ae85946ced..2cf86cbde1a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -33,8 +33,9 @@ func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %a
   // CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
   %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
   // CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
-  // CHECK: [[DUMMY:%.*]] = xla_hlo.constant {value = dense<0.000000e+00> : tensor<0xf32>} : tensor<*xf32>
-  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY]]
+  // CHECK: [[DUMMY:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<0xf32>
+  // CHECK: [[DUMMY_CAST:%.*]] = tensor_cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
+  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
   return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
 }
 
@@ -821,7 +822,8 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
   %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>
 
   // CHECK: [[CST:%.+]] = xla_hlo.constant
-  // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CST]])
+  // CHECK: [[CAST:%.+]] = tensor_cast [[CST]] : tensor<4xi32> to tensor<4xi32>
+  // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]])
   // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
   %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
   return %0 : tensor<16x16x16x16xf32>
@@ -1202,8 +1204,10 @@ func @const() -> tensor<2xi32> {
 
 // CHECK-LABEL: @const_dynamic_output
 func @const_dynamic_output() -> tensor<*xi32> {
-  // CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32>
+  // CHECK: [[CONST:%.*]] = xla_hlo.constant dense<0> : tensor<2xi32>
+  // CHECK: [[CAST:%.*]] = tensor_cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
   %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
+  // CHECK: return [[CAST]]
   return %0: tensor<*xi32>
 }
 
@@ -2339,7 +2343,8 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
 // CHECK-LABEL: slice_constant_start
 func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
   // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -2360,7 +2365,8 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
 // CHECK-LABEL: slice_i32_consts
 func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
+  // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi32> to tensor<1xi32>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64>
   // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -2376,7 +2382,8 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
 // CHECK-LABEL: slice_constant_start_negative_one_size
 func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
+  // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
   // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -2393,7 +2400,8 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
 // CHECK-LABEL: slice_constant_start_dynamic_shape
 func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
   // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
+  // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<2xi64> to tensor<2xi64>
+  // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64>
   // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
   // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
   // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@@ -3127,7 +3135,8 @@ func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
 // CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
 func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
   // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4>
-  // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM]])
+  // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]]
+  // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]])
   // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00>
   // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]]
   // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]]
@@ -3150,10 +3159,10 @@ func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i3
 
 // CHECK-LABEL: func @linspace_invalid_num
 func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
-  // CHECK: xla_hlo.constant {value = dense<[]> : tensor<0xi32>} : tensor<?xi32>
+  // CHECK: xla_hlo.constant dense<[]> : tensor<0xi32>
   // CHECK: "tf.LinSpace"
-  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<?xi32>
-  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<?xi32>) -> tensor<?xf32>
+  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
+  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32>
   return %1 : tensor<?xf32>
 }
 
@@ -3976,16 +3985,18 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
 
 // CHECK-LABEL: @variable_shape32
 func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> {
   // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32>
+  // CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
   %0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>)
-  // CHECK: return [[CST]]
+  // CHECK: return [[CST_CAST]]
   return %0: tensor<3xi32>
 }
 
 // CHECK-LABEL: @variable_shape64
 func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> {
   // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64>
+  // CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
   %0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>)
-  // CHECK: return [[CST]]
+  // CHECK: return [[CST_CAST]]
   return %0: tensor<3xi64>
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 38538212268..cceb7427789 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1432,15 +1432,18 @@ class ConvertFusedBatchNormV3Op
                              : 0;
       auto const_attr_type = RankedTensorType::get(
           {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
-      auto dummy_const = rewriter.create<ConstOp>(
-          op.getLoc(), reserve_space_3_type,
-          DenseElementsAttr::get(const_attr_type, 0.0));
+
+      Value dummy_const = rewriter.create<ConstOp>(
+          op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0));
+      if (const_attr_type != reserve_space_3_type)
+        dummy_const = rewriter.create<TensorCastOp>(
+            op.getLoc(), reserve_space_3_type, dummy_const);
       rewriter.replaceOp(op, {/*y=*/y_out, /*batch_mean=*/op.mean(),
                               /*batch_variance=*/op.variance(),
                               /*reserve_space_1=*/op.mean(),
                               /*reserve_space_2=*/op.variance(),
-                              /*reserve_space_3=*/dummy_const.getResult()});
+                              /*reserve_space_3=*/dummy_const});
     }
     return success();
   }
 
@@ -4751,6 +4754,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
   ConversionTarget target(*context);
   target.addLegalDialect<XlaHloDialect>();
   target.addLegalOp<CallOp>();
+  target.addLegalOp<TensorCastOp>();
 
   if (!allow_partial_conversion) {
     // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 3e8989050bb..61ed0e610c1 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -434,7 +434,8 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
 // Nullary op patterns.
 //===----------------------------------------------------------------------===//
 
-def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value),
+def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
+          (TensorCastOp (HLO_ConstOp $value)),
           [(HLO_Tensor $res)]>;
 
 //===----------------------------------------------------------------------===//
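
For illustration, a minimal sketch (not part of the patch itself) of what the revised TF_ConstOp pattern emits when an op's declared result type differs from its value attribute's type; the IR mirrors the @const_dynamic_output test above:

  // Before legalization:
  %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<*xi32>

  // After legalization: the constant carries the attribute's ranked type, and a
  // standard-dialect tensor_cast recovers the op's original unranked result type.
  %cst = xla_hlo.constant dense<0> : tensor<2xi32>
  %cast = tensor_cast %cst : tensor<2xi32> to tensor<*xi32>

When the two types already agree (e.g. the tensor<4xi32> to tensor<4xi32> cast in @broadcast_to), the emitted tensor_cast is trivial and later canonicalization can fold it away; marking TensorCastOp legal on the conversion target is what keeps these casts from failing full conversion.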