[MLIR:TF/XLA] Handle resource handle shapes (subtypes) in shape inference.

Set the input_handle_shapes_and_types field in InferenceContext so that resource handle shapes are visible to the op registry's shape functions. Parse the output_handle_shapes_and_types as well, to preserve this information in the MLIR type.

PiperOrigin-RevId: 287922553
Change-Id: I0f33ae5325e67bb688ec5f38b12d538b23be380b
Author: Yuanzhong Xu (2020-01-02 17:41:35 -08:00), committed by TensorFlower Gardener
parent e473233b84
commit 75fbbf8c7e
2 changed files with 125 additions and 42 deletions
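
Before the diff, a minimal standalone sketch of the data structure this change threads through: for each operand, an optional list of (shape, dtype) pairs describing a resource's subtypes, passed to InferenceContext as its final input_handle_shapes_and_types constructor argument. The helper name MakeExampleHandleData is hypothetical and purely illustrative, and std::make_unique stands in for the absl::make_unique used in the pass; the operand numbering mirrors the test below.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

using ShapeAndType =
    std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>;

// One slot per operand; a null entry means "no handle data known".
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
MakeExampleHandleData() {
  std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(3);
  // Operand 1 is tensor<!tf.resource<tensor<4xf32>>> in the test below:
  // one subtype with fully known shape [4] and dtype f32.
  auto entry = std::make_unique<std::vector<ShapeAndType>>();
  entry->emplace_back(tensorflow::PartialTensorShape({4}),
                      tensorflow::DT_FLOAT);
  handle_data[1] = std::move(entry);
  // Operand 0 (tensor<4xf32>, not a resource) and operand 2 (a subtype with
  // unknown shape) stay null, matching the "all shapes must be known" rule
  // the pass enforces when collecting subtypes.
  return handle_data;
}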


@@ -105,14 +105,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
 }
 
 // CHECK-LABEL: func @shape_from_while_to_cond_body_functions
-func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-  %0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>) -> tensor<4xf32>
-  return %0 : tensor<4xf32>
+func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<4xf32> {
+  // CHECK: "tf.While"
+  // CHECK-SAME: (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
+  %0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>)
+  return %0#0 : tensor<4xf32>
 }
 
 // CHECK-LABEL: func @while_cond_func
-// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<i1>
-func @while_cond_func(%arg0: tensor<*xf32>) -> tensor<i1> {
+// CHECK-SAME: (%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<i1>
+func @while_cond_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<i1> {
   %0 = "tf.Const"() {value = dense<[1.000000e-04,2.000000e-04,3.000000e-04,4.000000e-04]> : tensor<4xf32>} : () -> tensor<4xf32>
   %1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: tf.Equal
@@ -124,14 +126,27 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
 }
 
 // CHECK-LABEL: func @while_body_func
-func @while_body_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+func @while_body_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>) {
   %0 = "tf.Const"() {value = dense<1.000000e-04> : tensor<f32>} : () -> tensor<f32>
   // CHECK: tf.AddV2
   // CHECK-SAME: (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   %1 = "tf.AddV2"(%arg0, %0) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: "tf.Identity"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<!tf.resource<tensor<4xf32>>>
+  %2 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
+  // CHECK: "tf.TPUReplicatedInput"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<!tf.resource<tensor<4xf32>>>
+  %ri = "tf.TPUReplicatedInput"(%2) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
+  // CHECK: "tf.ReadVariableOp"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
+  %read = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xf32>
+  // CHECK: "tf.ReadVariableOp"
+  // CHECK-SAME: (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
+  %read1 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
   // CHECK: return
   // CHECK-SAME: tensor<4xf32>
-  return %1 : tensor<*xf32>
+  // CHECK-SAME: tensor<!tf.resource<tensor<4xf32>>>
+  return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
 }
 
 // CHECK-LABEL: func @invalid_function_reused_by_control_flows
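
The remaining hunks are in the C++ shape-inference pass. To see why populating input_handle_shapes_and_types makes the CHECKs above pass, note that shape functions registered with the op registry pull handle data off the InferenceContext. Below is a rough, hypothetical sketch of a ReadVariableOp-style shape function (simplified, not TensorFlow's actual registration) in the spirit of what the op registry runs:

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"

// Hypothetical, simplified shape function: if the resource operand carries
// handle data (now populated from the MLIR subtype), forward its shape.
tensorflow::Status ReadVariableLikeShapeFn(
    tensorflow::shape_inference::InferenceContext* c) {
  // Filled from input_handle_shapes_and_types; null when the MLIR operand
  // type had no subtype (e.g. tensor<*x!tf.resource>).
  const auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->empty()) {
    c->set_output(0, c->UnknownShape());
    return tensorflow::Status::OK();
  }
  // The first (shape, dtype) pair describes the variable's value.
  c->set_output(0, (*handle_data)[0].shape);
  return tensorflow::Status::OK();
}

With the input-side field left empty (the old behavior), handle_data is always null and the read stays unranked, which is exactly what the new CHECKs guard against.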


@@ -22,6 +22,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
@@ -39,11 +40,14 @@ limitations under the License.
 #include "mlir/Transforms/FoldUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/types.pb.h"
 
 #define DEBUG_TYPE "tf-shape-inference"
@@ -91,6 +95,40 @@ Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
   return llvm::to_vector<4>(return_op.getOperandTypes());
 }
 
+// Inserts tf.Cast operation when changing the type of a result if the user is
+// not a TF operation, as we can't guarantee that the new type will be OK.
+void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
+                                        Dialect* tf_dialect, Type old_type) {
+  OpBuilder builder(op);
+  builder.setInsertionPointAfter(op);
+  // A tf.Cast operation is lazily created on the first uses that isn't a TF
+  // operation.
+  TF::CastOp cast_op;
+  auto get_cast_op = [&]() {
+    if (!cast_op)
+      cast_op =
+          builder.create<TF::CastOp>(op->getLoc(), old_type, result,
+                                     /*truncate=*/builder.getBoolAttr(false));
+    return cast_op;
+  };
+  for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) {
+    if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
+  }
+}
+
+// Extracts a PartialTensorShape from the MLIR type.
+Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
+  if (auto ranked_type = t.dyn_cast<RankedTensorType>()) {
+    // Convert the MLIR shape indices (int64_t) to TensorFlow indices
+    // (int64).
+    ArrayRef<int64_t> shape = ranked_type.getShape();
+    SmallVector<int64, 8> tf_shape(shape.begin(), shape.end());
+    return tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()});
+  }
+  return None;
+}
+
 }  // namespace
 
 bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
@@ -98,9 +136,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   assert(tf_dialect == op->getDialect());
 
   // If no result for this op needs shape inference, we have a fast-path return.
+  // But if the type is a resource, we do not skip it because we might not have
+  // the handle shapes.
   if (llvm::all_of(op->getResultTypes(), [](Type type) {
         auto shape_type = type.dyn_cast<ShapedType>();
-        return !shape_type || shape_type.hasStaticShape();
+        return !shape_type ||
+               (shape_type.hasStaticShape() &&
+                !shape_type.getElementType().isa<TF::ResourceType>());
       })) {
     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
                             << op->getName() << "'.\n";);
@@ -160,6 +202,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   std::vector<tensorflow::PartialTensorShape> input_shapes(
       op->getNumOperands());
   std::vector<tensorflow::Tensor> tensors(op->getNumOperands());
+  std::vector<std::unique_ptr<std::vector<
+      std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>>
+      handle_shapes_and_types(op->getNumOperands());
   for (auto it : llvm::enumerate(op->getOperands())) {
     Value operand = it.value();
     size_t index = it.index();
@@ -179,12 +224,31 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
     }
 
     Type operand_type = operand.getType();
-    if (auto ranked_type = operand_type.dyn_cast<RankedTensorType>()) {
-      // Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64).
-      ArrayRef<int64_t> shape = ranked_type.getShape();
-      SmallVector<int64, 8> tf_shape(shape.begin(), shape.end());
-      input_shapes[index] =
-          tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()});
+    if (auto shape = GetShapeFromMlirType(operand_type)) {
+      input_shapes[index] = *shape;
+    }
+    // Collect the handle shapes and types for a resource.
+    if (auto resource_type = operand_type.cast<TensorType>()
+                                 .getElementType()
+                                 .dyn_cast<TF::ResourceType>()) {
+      if (resource_type.getSubtypes().empty()) continue;
+      auto shapes_and_types = absl::make_unique<std::vector<
+          std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
+      for (auto subtype : resource_type.getSubtypes()) {
+        auto shape = GetShapeFromMlirType(subtype);
+        // handle_shapes_and_types requires all shapes to be known. So if any
+        // subtype is unknown, clear the vector.
+        if (!shape) {
+          shapes_and_types = nullptr;
+          break;
+        }
+        tensorflow::DataType dtype;
+        auto status =
+            tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
+        assert(status.ok() && "Unknown element type");
+        shapes_and_types->emplace_back(*shape, dtype);
+      }
+      handle_shapes_and_types[index] = std::move(shapes_and_types);
     }
   }
 
@@ -193,8 +257,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   // function operates on.
   tensorflow::shape_inference::InferenceContext c(
       graph_version, *node_def, op_reg_data->op_def, input_shapes,
-      input_tensors, /*input_tensors_as_shapes=*/{},
-      /*input_handle_shapes_and_types=*/{});
+      input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
   auto status = c.Run(op_reg_data->shape_inference_fn);
   if (!status.ok()) {
     LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op
@@ -206,12 +269,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
       "inference context matches the MLIR number of results.");
 
   // Update the shape for each of the operation result if the InferenceContext
-  // has more precise shapes recorded. A builder is used to insert tf.Cast
-  // operation when changing the type of a result is the user is not a TF
-  // operation, as we can't guarantee that the new type will be OK.
+  // has more precise shapes recorded.
   bool changed = false;
-  OpBuilder builder(op);
-  builder.setInsertionPointAfter(op);
   for (int output : llvm::seq<int>(0, c.num_outputs())) {
     // Skip already statically shaped results.
     Value result = op->getResult(output);
@@ -221,30 +280,39 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
     tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
     LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : "
                             << c.DebugString(shape_handle) << "\n");
-    if (!c.RankKnown(shape_handle)) continue;
-
-    // Convert the shape from TensorFlow (int64) to MLIR (int64_t).
-    SmallVector<int64_t, 8> shape;
-    for (int dim : llvm::seq<int>(0, c.Rank(shape_handle)))
-      shape.push_back(c.Value(c.Dim(shape_handle, dim)));
-    auto new_type = RankedTensorType::get(shape, shaped_type.getElementType());
-
-    // A tf.Cast operation is lazily created on the first uses that isn't a TF
-    // operation.
-    TF::CastOp cast_op;
-    auto get_cast_op = [&]() {
-      if (!cast_op)
-        cast_op =
-            builder.create<TF::CastOp>(op->getLoc(), result.getType(), result,
-                                       /*truncate=*/builder.getBoolAttr(false));
-      return cast_op;
+    auto get_tensor_type =
+        [&c](const tensorflow::shape_inference::ShapeHandle& sh,
+             Type element_type) -> TensorType {
+      if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type);
+      // Convert the shape from TensorFlow (int64) to MLIR (int64_t).
+      SmallVector<int64_t, 8> shape;
+      for (int dim : llvm::seq<int>(0, c.Rank(sh)))
+        shape.push_back(c.Value(c.Dim(sh, dim)));
+      return RankedTensorType::get(shape, element_type);
     };
-    for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
-      if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
+    auto new_element_type = shaped_type.getElementType();
+    // Populate the handle shapes for a resource.
+    if (auto resource_type = new_element_type.dyn_cast<TF::ResourceType>()) {
+      auto handle_shapes_types = c.output_handle_shapes_and_types(output);
+      if (handle_shapes_types) {
+        llvm::SmallVector<mlir::TensorType, 1> subtypes;
+        OpBuilder b(op);
+        for (const auto& shape_n_type : *handle_shapes_types) {
+          Type element_type;
+          auto status =
+              tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type);
+          assert(status.ok() && "Unknown element type");
+          subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type));
+        }
+        new_element_type = TF::ResourceType::get(subtypes, op->getContext());
+      }
     }
-
+    auto new_type = get_tensor_type(shape_handle, new_element_type);
     if (result.getType() == new_type) continue;
-
+    // Inserts a cast back to the original type if any user is not in the TF
+    // dialect.
+    AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
+                                       result.getType());
     // Finally we inferred the shape and replace the type for this result.
     result.setType(new_type);
     changed = true;