Update shape inference to not insert tf.Cast ops from tf_executor.graph results.
Currently breakup islands and export does not support a mix of tf_executor.graph and the functional form in a function. Breakup islands may introduce a graph in an island, breaking export from TF MLIR to Graph. PiperOrigin-RevId: 341124926 Change-Id: Ifc654b03ad45a21fc6b4bffbe03d206e55207c76
This commit is contained in:
parent
8687106526
commit
4e3cb92e6b
@ -601,4 +601,25 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
%size = "tf.Size"(%add) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
|
||||
return %size : tensor<*xi32>
|
||||
}
|
||||
|
||||
// Test no tf.Cast ops are inserted when refining tf_executor.graph results.
|
||||
// CHECK-LABEL: func @call_in_graph({{%.+}}: tensor<i32>) -> tensor<i32>
|
||||
func @call_in_graph(%arg0: tensor<i32>) -> tensor<*xi32> {
|
||||
// CHECK-NOT: tf.Cast
|
||||
%0 = tf_executor.graph {
|
||||
%1:2 = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @call_in_graph_func} : (tensor<i32>) -> tensor<*xi32>
|
||||
tf_executor.fetch %1#0 : tensor<*xi32>
|
||||
}
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @call_in_graph_func({{%.+}}: tensor<i32>) -> tensor<i32>
|
||||
func @call_in_graph_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK-NOT: tf.Cast
|
||||
%0 = tf_executor.graph {
|
||||
%1:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
|
||||
tf_executor.fetch %1#0 : tensor<*xi32>
|
||||
}
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
}
|
||||
|
@ -115,6 +115,11 @@ bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
|
||||
// tf.Cast operation for uses that are incompatible with the new type.
|
||||
void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type,
|
||||
Value result) {
|
||||
if (isa_and_nonnull<tf_executor::GraphOp>(result.getDefiningOp())) {
|
||||
result.setType(new_type);
|
||||
return;
|
||||
}
|
||||
|
||||
// A tf.Cast operation is lazily created on the first use requires a cast.
|
||||
TF::CastOp cast_op;
|
||||
auto get_cast_op = [&]() {
|
||||
|
Loading…
Reference in New Issue
Block a user