[MLIR:TF/XLA] Handle resource handle shapes (subtypes) in shape inference.
Set the input_handle_shapes_and_types field in InferenceContext so that the
handle subtypes are visible to the op registry's shape functions. Parse the
output_handle_shapes_and_types as well, to preserve this information in the
MLIR type.

PiperOrigin-RevId: 287922553
Change-Id: I0f33ae5325e67bb688ec5f38b12d538b23be380b
parent e473233b84
commit 75fbbf8c7e
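Background: in the TF MLIR dialect, a resource tensor such as tensor<!tf.resource<tensor<4xf32>>> carries "subtypes" describing the tensor(s) the handle points at, and TensorFlow's shape functions see the equivalent information through InferenceContext's per-input handle shapes and types. A minimal sketch of the per-operand entry this change builds (illustrative code, not part of the commit; MakeExampleEntry is a hypothetical name):

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

// One entry per operand: null when the operand is not a resource (or carries
// no usable subtype info); otherwise the (shape, dtype) of every subtype.
using HandleData =
    std::vector<std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>;

// Entry corresponding to an operand of type tensor<!tf.resource<tensor<4xf32>>>.
std::unique_ptr<HandleData> MakeExampleEntry() {
  auto entry = std::make_unique<HandleData>();
  entry->emplace_back(tensorflow::PartialTensorShape({4}),
                      tensorflow::DT_FLOAT);
  return entry;
}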
@@ -105,14 +105,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
 }
 
 // CHECK-LABEL: func @shape_from_while_to_cond_body_functions
-func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-  %0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>) -> tensor<4xf32>
-  return %0 : tensor<4xf32>
+func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<4xf32> {
+  // CHECK: "tf.While"
+  // CHECK-SAME: (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
+  %0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>)
+  return %0#0 : tensor<4xf32>
 }
 
 // CHECK-LABEL: func @while_cond_func
-// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<i1>
-func @while_cond_func(%arg0: tensor<*xf32>) -> tensor<i1> {
+// CHECK-SAME: (%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<i1>
+func @while_cond_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<i1> {
   %0 = "tf.Const"() {value = dense<[1.000000e-04,2.000000e-04,3.000000e-04,4.000000e-04]> : tensor<4xf32>} : () -> tensor<4xf32>
   %1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: tf.Equal
@@ -124,14 +126,27 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
 }
 
 // CHECK-LABEL: func @while_body_func
-func @while_body_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+func @while_body_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>) {
   %0 = "tf.Const"() {value = dense<1.000000e-04> : tensor<f32>} : () -> tensor<f32>
   // CHECK: tf.AddV2
   // CHECK-SAME: (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   %1 = "tf.AddV2"(%arg0, %0) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: "tf.Identity"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<!tf.resource<tensor<4xf32>>>
+  %2 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
+  // CHECK: "tf.TPUReplicatedInput"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<!tf.resource<tensor<4xf32>>>
+  %ri = "tf.TPUReplicatedInput"(%2) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
+  // CHECK: "tf.ReadVariableOp"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
+  %read = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: "tf.ReadVariableOp"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
+  %read1 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: return
   // CHECK-SAME: tensor<4xf32>
-  return %1 : tensor<*xf32>
+  // CHECK-SAME: tensor<!tf.resource<tensor<4xf32>>>
+  return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
 }
 
 // CHECK-LABEL: func @invalid_function_reused_by_control_flows
|
@@ -22,6 +22,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@@ -39,11 +40,14 @@ limitations under the License.
 #include "mlir/Transforms/FoldUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.pb.h"
 
 #define DEBUG_TYPE "tf-shape-inference"
 
@@ -91,6 +95,40 @@ Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
 
   return llvm::to_vector<4>(return_op.getOperandTypes());
 }
 
+// Inserts tf.Cast operation when changing the type of a result if the user is
+// not a TF operation, as we can't guarantee that the new type will be OK.
+void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
+                                        Dialect* tf_dialect, Type old_type) {
+  OpBuilder builder(op);
+  builder.setInsertionPointAfter(op);
+  // A tf.Cast operation is lazily created on the first uses that isn't a TF
+  // operation.
+  TF::CastOp cast_op;
+  auto get_cast_op = [&]() {
+    if (!cast_op)
+      cast_op =
+          builder.create<TF::CastOp>(op->getLoc(), old_type, result,
+                                     /*truncate=*/builder.getBoolAttr(false));
+    return cast_op;
+  };
+  for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) {
+    if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
+  }
+}
+
+// Extracts a PartialTensorShape from the MLIR type.
+Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
+  if (auto ranked_type = t.dyn_cast<RankedTensorType>()) {
+    // Convert the MLIR shape indices (int64_t) to TensorFlow indices
+    // (int64).
+    ArrayRef<int64_t> shape = ranked_type.getShape();
+    SmallVector<int64, 8> tf_shape(shape.begin(), shape.end());
+    return tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()});
+  }
+  return None;
+}
+
 }  // namespace
 
 bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
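GetShapeFromMlirType leans on MLIR and TensorFlow agreeing on -1 as the encoding of an unknown dimension, so ranked shapes copy through verbatim while unranked types yield None. A small illustration (assumed values, not code from the commit):

#include "tensorflow/core/framework/tensor_shape.h"

// tensor<4x?xf32> has getShape() == {4, -1}; the same values form a valid
// PartialTensorShape, where -1 encodes an unknown dimension ("[4,?]").
tensorflow::PartialTensorShape RankedExample() {
  return tensorflow::PartialTensorShape({4, -1});
}
// For an unranked type such as tensor<*xf32> the helper returns None, and the
// caller's input_shapes entry keeps its default value, which
// PartialTensorShape treats as "unknown rank".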
@@ -98,9 +136,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   assert(tf_dialect == op->getDialect());
 
   // If no result for this op needs shape inference, we have a fast-path return.
+  // But if the type is a resource, we do not skip it because we might not have
+  // the handle shapes.
   if (llvm::all_of(op->getResultTypes(), [](Type type) {
         auto shape_type = type.dyn_cast<ShapedType>();
-        return !shape_type || shape_type.hasStaticShape();
+        return !shape_type ||
+               (shape_type.hasStaticShape() &&
+                !shape_type.getElementType().isa<TF::ResourceType>());
       })) {
     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
                             << op->getName() << "'.\n";);
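Restated as a standalone predicate, the revised fast-path test reads as follows (an illustrative inversion of the skip condition above, not code from the commit):

// An op result still needs inference when its shape is not fully static, or
// when it is a resource whose handle subtypes may be missing or refinable.
bool NeedsInference(mlir::Type type) {
  auto shape_type = type.dyn_cast<mlir::ShapedType>();
  if (!shape_type) return false;                  // not a shaped type: skip
  if (!shape_type.hasStaticShape()) return true;  // e.g. tensor<*xf32>
  // Statically shaped, but e.g. tensor<!tf.resource<tensor<4xf32>>> may still
  // gain handle subtype information, so resources are no longer skipped.
  return shape_type.getElementType().isa<mlir::TF::ResourceType>();
}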
@@ -160,6 +202,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   std::vector<tensorflow::PartialTensorShape> input_shapes(
       op->getNumOperands());
   std::vector<tensorflow::Tensor> tensors(op->getNumOperands());
+  std::vector<std::unique_ptr<std::vector<
+      std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>>
+      handle_shapes_and_types(op->getNumOperands());
   for (auto it : llvm::enumerate(op->getOperands())) {
     Value operand = it.value();
     size_t index = it.index();
@@ -179,12 +224,31 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
     }
 
     Type operand_type = operand.getType();
-    if (auto ranked_type = operand_type.dyn_cast<RankedTensorType>()) {
-      // Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64).
-      ArrayRef<int64_t> shape = ranked_type.getShape();
-      SmallVector<int64, 8> tf_shape(shape.begin(), shape.end());
-      input_shapes[index] =
-          tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()});
+    if (auto shape = GetShapeFromMlirType(operand_type)) {
+      input_shapes[index] = *shape;
     }
+    // Collect the handle shapes and types for a resource.
+    if (auto resource_type = operand_type.cast<TensorType>()
+                                 .getElementType()
+                                 .dyn_cast<TF::ResourceType>()) {
+      if (resource_type.getSubtypes().empty()) continue;
+      auto shapes_and_types = absl::make_unique<std::vector<
+          std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
+      for (auto subtype : resource_type.getSubtypes()) {
+        auto shape = GetShapeFromMlirType(subtype);
+        // handle_shapes_and_types requires all shapes to be known. So if any
+        // subtype is unknown, clear the vector.
+        if (!shape) {
+          shapes_and_types = nullptr;
+          break;
+        }
+        tensorflow::DataType dtype;
+        auto status =
+            tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
+        assert(status.ok() && "Unknown element type");
+        shapes_and_types->emplace_back(*shape, dtype);
+      }
+      handle_shapes_and_types[index] = std::move(shapes_and_types);
+    }
   }
 
@@ -193,8 +257,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   // function operates on.
   tensorflow::shape_inference::InferenceContext c(
       graph_version, *node_def, op_reg_data->op_def, input_shapes,
-      input_tensors, /*input_tensors_as_shapes=*/{},
-      /*input_handle_shapes_and_types=*/{});
+      input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
   auto status = c.Run(op_reg_data->shape_inference_fn);
   if (!status.ok()) {
     LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op
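With the handle data plumbed into the context, registered shape functions can consume it during c.Run. A simplified sketch of the consuming side, modeled on the ReadVariableOp shape-function pattern (not part of this change):

#include "tensorflow/core/framework/shape_inference.h"

tensorflow::Status ReadVariableShapeSketch(
    tensorflow::shape_inference::InferenceContext* c) {
  // The variable's value shape and dtype ride along on the resource edge.
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data && !handle_data->empty()) {
    c->set_output(0, (*handle_data)[0].shape);
  } else {
    c->set_output(0, c->UnknownShape());
  }
  return tensorflow::Status::OK();
}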
@@ -206,12 +269,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
          "inference context matches the MLIR number of results.");
 
   // Update the shape for each of the operation result if the InferenceContext
-  // has more precise shapes recorded. A builder is used to insert tf.Cast
-  // operation when changing the type of a result is the user is not a TF
-  // operation, as we can't guarantee that the new type will be OK.
+  // has more precise shapes recorded.
   bool changed = false;
-  OpBuilder builder(op);
-  builder.setInsertionPointAfter(op);
   for (int output : llvm::seq<int>(0, c.num_outputs())) {
     // Skip already statically shaped results.
     Value result = op->getResult(output);
@@ -221,30 +280,39 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
     tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
     LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : "
                             << c.DebugString(shape_handle) << "\n");
-    if (!c.RankKnown(shape_handle)) continue;
-
-    // Convert the shape from TensorFlow (int64) to MLIR (int64_t).
-    SmallVector<int64_t, 8> shape;
-    for (int dim : llvm::seq<int>(0, c.Rank(shape_handle)))
-      shape.push_back(c.Value(c.Dim(shape_handle, dim)));
-    auto new_type = RankedTensorType::get(shape, shaped_type.getElementType());
-
-    // A tf.Cast operation is lazily created on the first uses that isn't a TF
-    // operation.
-    TF::CastOp cast_op;
-    auto get_cast_op = [&]() {
-      if (!cast_op)
-        cast_op =
-            builder.create<TF::CastOp>(op->getLoc(), result.getType(), result,
-                                       /*truncate=*/builder.getBoolAttr(false));
-      return cast_op;
+    auto get_tensor_type =
+        [&c](const tensorflow::shape_inference::ShapeHandle& sh,
+             Type element_type) -> TensorType {
+      if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type);
+      // Convert the shape from TensorFlow (int64) to MLIR (int64_t).
+      SmallVector<int64_t, 8> shape;
+      for (int dim : llvm::seq<int>(0, c.Rank(sh)))
+        shape.push_back(c.Value(c.Dim(sh, dim)));
+      return RankedTensorType::get(shape, element_type);
     };
-    for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
-      if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
+    auto new_element_type = shaped_type.getElementType();
+    // Populate the handle shapes for a resource.
+    if (auto resource_type = new_element_type.dyn_cast<TF::ResourceType>()) {
+      auto handle_shapes_types = c.output_handle_shapes_and_types(output);
+      if (handle_shapes_types) {
+        llvm::SmallVector<mlir::TensorType, 1> subtypes;
+        OpBuilder b(op);
+        for (const auto& shape_n_type : *handle_shapes_types) {
+          Type element_type;
+          auto status =
+              tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type);
+          assert(status.ok() && "Unknown element type");
+          subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type));
+        }
+        new_element_type = TF::ResourceType::get(subtypes, op->getContext());
+      }
     }
 
+    auto new_type = get_tensor_type(shape_handle, new_element_type);
+    if (result.getType() == new_type) continue;
+
+    // Inserts a cast back to the original type if any user is not in the TF
+    // dialect.
+    AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
+                                       result.getType());
     // Finally we inferred the shape and replace the type for this result.
     result.setType(new_type);
     changed = true;
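On the output side the inferred handle data is folded back into the MLIR type, mirroring the parsing done on the inputs. A minimal sketch of the type being rebuilt (hypothetical helper, using the MLIR APIs of that era):

// Refined type for a resource result whose single subtype was inferred as a
// float tensor of shape [4], i.e. tensor<!tf.resource<tensor<4xf32>>>.
mlir::TensorType RefinedResourceType(mlir::MLIRContext* context) {
  auto f32 = mlir::FloatType::getF32(context);
  mlir::TensorType subtype = mlir::RankedTensorType::get({4}, f32);
  auto resource = mlir::TF::ResourceType::get({subtype}, context);
  // tensor<!tf.resource<tensor<4xf32>>> is a 0-d tensor of resource type.
  return mlir::RankedTensorType::get({}, resource);
}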