Update shape inference to not insert tf.Cast ops from tf_executor.graph results.

Currently breakup islands and export does not support a mix of tf_executor.graph and the functional form in a function. Breakup islands may introduce a graph in an island, breaking export from TF MLIR to Graph.

PiperOrigin-RevId: 341124926
Change-Id: Ifc654b03ad45a21fc6b4bffbe03d206e55207c76
This commit is contained in:
Andy Ly 2020-11-06 14:59:30 -08:00 committed by TensorFlower Gardener
parent 8687106526
commit 4e3cb92e6b
2 changed files with 26 additions and 0 deletions

View File

@ -601,4 +601,25 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
%size = "tf.Size"(%add) {device = ""} : (tensor<*xi32>) -> tensor<*xi32>
return %size : tensor<*xi32>
}
// Test no tf.Cast ops are inserted when refining tf_executor.graph results.
// CHECK-LABEL: func @call_in_graph({{%.+}}: tensor<i32>) -> tensor<i32>
func @call_in_graph(%arg0: tensor<i32>) -> tensor<*xi32> {
// CHECK-NOT: tf.Cast
%0 = tf_executor.graph {
%1:2 = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @call_in_graph_func} : (tensor<i32>) -> tensor<*xi32>
tf_executor.fetch %1#0 : tensor<*xi32>
}
return %0 : tensor<*xi32>
}
// CHECK-LABEL: func @call_in_graph_func({{%.+}}: tensor<i32>) -> tensor<i32>
func @call_in_graph_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// CHECK-NOT: tf.Cast
%0 = tf_executor.graph {
%1:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
tf_executor.fetch %1#0 : tensor<*xi32>
}
return %0 : tensor<*xi32>
}
}

View File

@ -115,6 +115,11 @@ bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
// tf.Cast operation for uses that are incompatible with the new type.
void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type,
Value result) {
if (isa_and_nonnull<tf_executor::GraphOp>(result.getDefiningOp())) {
result.setType(new_type);
return;
}
// A tf.Cast operation is lazily created on the first use requires a cast.
TF::CastOp cast_op;
auto get_cast_op = [&]() {