Only generate xla_hlo.constant where result and attribute types match exactly.

xla_hlo.constant would no longer allow the result type to be different from the attribute type. Since some of the legalizations from TF -> XLA used this relaxed property to generate constants with dynamic/unranked types, they have been modified to use std.tensor_cast op explicitly for type conversion.

PiperOrigin-RevId: 308648605
Change-Id: I52839f9d6f5248e4447936f402886646efeb6e4b
This commit is contained in:
Prakalp Srivastava 2020-04-27 10:29:40 -07:00 committed by TensorFlower Gardener
parent 5ab3af7a7b
commit 7c977c938e
3 changed files with 35 additions and 19 deletions

View File

@ -33,8 +33,9 @@ func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %a
// CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} // CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
// CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> // CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
// CHECK: [[DUMMY:%.*]] = xla_hlo.constant {value = dense<0.000000e+00> : tensor<0xf32>} : tensor<*xf32> // CHECK: [[DUMMY:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<0xf32>
// CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY]] // CHECK: [[DUMMY_CAST:%.*]] = tensor_cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
// CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
} }
@ -821,7 +822,8 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
%cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>
// CHECK: [[CST:%.+]] = xla_hlo.constant // CHECK: [[CST:%.+]] = xla_hlo.constant
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CST]]) // CHECK: [[CAST:%.+]] = tensor_cast [[CST]] : tensor<4xi32> to tensor<4xi32>
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]])
// CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
return %0 : tensor<16x16x16x16xf32> return %0 : tensor<16x16x16x16xf32>
@ -1202,8 +1204,10 @@ func @const() -> tensor<2xi32> {
// CHECK-LABEL: @const_dynamic_output // CHECK-LABEL: @const_dynamic_output
func @const_dynamic_output() -> tensor<*xi32> { func @const_dynamic_output() -> tensor<*xi32> {
// CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32> // CHECK: [[CONST:%.*]] = xla_hlo.constant dense<0> : tensor<2xi32>
// CHECK: [[CAST:%.*]] = tensor_cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
%0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
// CHECK: return [[CAST]]
return %0: tensor<*xi32> return %0: tensor<*xi32>
} }
@ -2339,7 +2343,8 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
// CHECK-LABEL: slice_constant_start // CHECK-LABEL: slice_constant_start
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> // CHECK: %[[CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@ -2360,7 +2365,8 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK-LABEL: slice_i32_consts // CHECK-LABEL: slice_i32_consts
func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32> // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64> // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi32> to tensor<1xi32>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@ -2376,7 +2382,8 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK-LABEL: slice_constant_start_negative_one_size // CHECK-LABEL: slice_constant_start_negative_one_size
func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@ -2393,7 +2400,8 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi
// CHECK-LABEL: slice_constant_start_dynamic_shape // CHECK-LABEL: slice_constant_start_dynamic_shape
func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64> // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<2xi64> to tensor<2xi64>
// CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64>
// CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]])
// CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
@ -3127,7 +3135,8 @@ func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32> // CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> { func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4> // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4>
// CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM]]) // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]]
// CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]])
// CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00>
// CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]] // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]]
// CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]]
@ -3150,10 +3159,10 @@ func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32
// CHECK-LABEL: func @linspace_invalid_num // CHECK-LABEL: func @linspace_invalid_num
func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> { func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
// CHECK: xla_hlo.constant {value = dense<[]> : tensor<0xi32>} : tensor<i32> // CHECK: xla_hlo.constant dense<[]> : tensor<0xi32>
// CHECK: "tf.LinSpace" // CHECK: "tf.LinSpace"
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<i32> %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32> %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32>
return %1 : tensor<?xf32> return %1 : tensor<?xf32>
} }
@ -3976,16 +3985,18 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
// CHECK-LABLE: @variable_shape32 // CHECK-LABLE: @variable_shape32
func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> { func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32> // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>) %0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>)
// CHECK: return [[CST]] // CHECK: return [[CST_CAST]]
return %0: tensor<3xi32> return %0: tensor<3xi32>
} }
// CHECK-LABLE: @variable_shape64 // CHECK-LABLE: @variable_shape64
func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> { func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64> // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>) %0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>)
// CHECK: return [[CST]] // CHECK: return [[CST_CAST]]
return %0: tensor<3xi64> return %0: tensor<3xi64>
} }

View File

@ -1432,15 +1432,18 @@ class ConvertFusedBatchNormV3Op
: 0; : 0;
auto const_attr_type = RankedTensorType::get( auto const_attr_type = RankedTensorType::get(
{num_elements}, getElementTypeOrSelf(reserve_space_3_type)); {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
auto dummy_const = rewriter.create<ConstOp>(
op.getLoc(), reserve_space_3_type, Value dummy_const = rewriter.create<ConstOp>(
DenseElementsAttr::get<float>(const_attr_type, 0.0)); op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
if (const_attr_type != reserve_space_3_type)
dummy_const = rewriter.create<TensorCastOp>(
op.getLoc(), reserve_space_3_type, dummy_const);
rewriter.replaceOp(op, {/*y=*/y_out, rewriter.replaceOp(op, {/*y=*/y_out,
/*batch_mean=*/op.mean(), /*batch_mean=*/op.mean(),
/*batch_variance=*/op.variance(), /*batch_variance=*/op.variance(),
/*reserve_space_1=*/op.mean(), /*reserve_space_1=*/op.mean(),
/*reserve_space_2=*/op.variance(), /*reserve_space_2=*/op.variance(),
/*reserve_space_3=*/dummy_const.getResult()}); /*reserve_space_3=*/dummy_const});
} }
return success(); return success();
} }
@ -4751,6 +4754,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>(); target.addLegalDialect<XlaHloDialect>();
target.addLegalOp<CallOp>(); target.addLegalOp<CallOp>();
target.addLegalOp<TensorCastOp>();
if (!allow_partial_conversion) { if (!allow_partial_conversion) {
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.

View File

@ -434,7 +434,8 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
// Nullary op patterns. // Nullary op patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
(TensorCastOp (HLO_ConstOp $value)),
[(HLO_Tensor $res)]>; [(HLO_Tensor $res)]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//