[MLIR:TF/XLA] Handle resource handle shapes (subtypes) in shape inference.

Set the input_handle_shapes_and_types field in InferenceContext so that they will be visible to op registry's shape functions. Parse the output_handle_shapes_and_types as well to preserve this information in the MLIR type.

PiperOrigin-RevId: 287922553
Change-Id: I0f33ae5325e67bb688ec5f38b12d538b23be380b
This commit is contained in:
Yuanzhong Xu 2020-01-02 17:41:35 -08:00 committed by TensorFlower Gardener
parent e473233b84
commit 75fbbf8c7e
2 changed files with 125 additions and 42 deletions

View File

@ -105,14 +105,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
} }
// CHECK-LABEL: func @shape_from_while_to_cond_body_functions // CHECK-LABEL: func @shape_from_while_to_cond_body_functions
func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>) -> tensor<4xf32> { func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<4xf32> {
%0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>) -> tensor<4xf32> // CHECK "tf.While"
return %0 : tensor<4xf32> // CHECK-SAME (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>)
%0:3 = "tf.While"(%arg0, %arg1, %arg2) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>, tensor<!tf.resource<tensor<4xf32>>>, tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<4xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>)
return %0#0 : tensor<4xf32>
} }
// CHECK-LABEL: func @while_cond_func // CHECK-LABEL: func @while_cond_func
// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<i1> // CHECK-SAME: (%arg0: tensor<4xf32>, %arg1: tensor<!tf.resource<tensor<4xf32>>>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<i1>
func @while_cond_func(%arg0: tensor<*xf32>) -> tensor<i1> { func @while_cond_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<[1.000000e-04,2.000000e-04,3.000000e-04,4.000000e-04]> : tensor<4xf32>} : () -> tensor<4xf32> %0 = "tf.Const"() {value = dense<[1.000000e-04,2.000000e-04,3.000000e-04,4.000000e-04]> : tensor<4xf32>} : () -> tensor<4xf32>
%1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> %1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: tf.Equal // CHECK: tf.Equal
@ -124,14 +126,27 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
} }
// CHECK-LABEL: func @while_body_func // CHECK-LABEL: func @while_body_func
func @while_body_func(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @while_body_func(%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>, %arg2: tensor<!tf.resource<tensor<*xf32>>>) -> (tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>) {
%0 = "tf.Const"() {value = dense<1.000000e-04> : tensor<f32>} : () -> tensor<f32> %0 = "tf.Const"() {value = dense<1.000000e-04> : tensor<f32>} : () -> tensor<f32>
// CHECK: tf.AddV2 // CHECK: tf.AddV2
// CHECK-SAME: (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> // CHECK-SAME: (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
%1 = "tf.AddV2"(%arg0, %0) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32> %1 = "tf.AddV2"(%arg0, %0) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
// CHECK: "tf.Identity"
// CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<!tf.resource<tensor<4xf32>>>
%2 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
// CHECK: "tf.TPUReplicatedInput"
// CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<!tf.resource<tensor<4xf32>>>
%ri = "tf.TPUReplicatedInput"(%2) : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource>
// CHECK: "tf.ReadVariableOp"
// CHECK-SAME: (tensor<!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
%read = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xf32>
// CHECK: "tf.ReadVariableOp"
// CHECK-SAME: (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
%read1 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<*xf32>>>) -> tensor<*xf32>
// CHECK: return // CHECK: return
// CHECK-SAME: tensor<4xf32> // CHECK-SAME: tensor<4xf32>
return %1 : tensor<*xf32> // CHECK-SAME: tensor<!tf.resource<tensor<4xf32>>>
return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor<!tf.resource<tensor<*xf32>>>
} }
// CHECK-LABEL: func @invalid_function_reused_by_control_flows // CHECK-LABEL: func @invalid_function_reused_by_control_flows

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
@ -39,11 +40,14 @@ limitations under the License.
#include "mlir/Transforms/FoldUtils.h" // TF:llvm-project #include "mlir/Transforms/FoldUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#define DEBUG_TYPE "tf-shape-inference" #define DEBUG_TYPE "tf-shape-inference"
@ -91,6 +95,40 @@ Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
return llvm::to_vector<4>(return_op.getOperandTypes()); return llvm::to_vector<4>(return_op.getOperandTypes());
} }
// Inserts tf.Cast operation when changing the type of a result if the user is
// not a TF operation, as we can't guarantee that the new type will be OK.
void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result,
Dialect* tf_dialect, Type old_type) {
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
// A tf.Cast operation is lazily created on the first uses that isn't a TF
// operation.
TF::CastOp cast_op;
auto get_cast_op = [&]() {
if (!cast_op)
cast_op =
builder.create<TF::CastOp>(op->getLoc(), old_type, result,
/*truncate=*/builder.getBoolAttr(false));
return cast_op;
};
for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) {
if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
}
}
// Extracts a PartialTensorShape from the MLIR type.
Optional<tensorflow::PartialTensorShape> GetShapeFromMlirType(Type t) {
if (auto ranked_type = t.dyn_cast<RankedTensorType>()) {
// Convert the MLIR shape indices (int64_t) to TensorFlow indices
// (int64).
ArrayRef<int64_t> shape = ranked_type.getShape();
SmallVector<int64, 8> tf_shape(shape.begin(), shape.end());
return tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()});
}
return None;
}
} // namespace } // namespace
bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
@ -98,9 +136,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
assert(tf_dialect == op->getDialect()); assert(tf_dialect == op->getDialect());
// If no result for this op needs shape inference, we have a fast-path return. // If no result for this op needs shape inference, we have a fast-path return.
// But if the type is a resource, we do not skip it because we might not have
// the handle shapes.
if (llvm::all_of(op->getResultTypes(), [](Type type) { if (llvm::all_of(op->getResultTypes(), [](Type type) {
auto shape_type = type.dyn_cast<ShapedType>(); auto shape_type = type.dyn_cast<ShapedType>();
return !shape_type || shape_type.hasStaticShape(); return !shape_type ||
(shape_type.hasStaticShape() &&
!shape_type.getElementType().isa<TF::ResourceType>());
})) { })) {
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
<< op->getName() << "'.\n";); << op->getName() << "'.\n";);
@ -160,6 +202,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
std::vector<tensorflow::PartialTensorShape> input_shapes( std::vector<tensorflow::PartialTensorShape> input_shapes(
op->getNumOperands()); op->getNumOperands());
std::vector<tensorflow::Tensor> tensors(op->getNumOperands()); std::vector<tensorflow::Tensor> tensors(op->getNumOperands());
std::vector<std::unique_ptr<std::vector<
std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>>
handle_shapes_and_types(op->getNumOperands());
for (auto it : llvm::enumerate(op->getOperands())) { for (auto it : llvm::enumerate(op->getOperands())) {
Value operand = it.value(); Value operand = it.value();
size_t index = it.index(); size_t index = it.index();
@ -179,12 +224,31 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
} }
Type operand_type = operand.getType(); Type operand_type = operand.getType();
if (auto ranked_type = operand_type.dyn_cast<RankedTensorType>()) { if (auto shape = GetShapeFromMlirType(operand_type)) {
// Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64). input_shapes[index] = *shape;
ArrayRef<int64_t> shape = ranked_type.getShape(); }
SmallVector<int64, 8> tf_shape(shape.begin(), shape.end()); // Collect the handle shapes and types for a resource.
input_shapes[index] = if (auto resource_type = operand_type.cast<TensorType>()
tensorflow::PartialTensorShape({tf_shape.data(), tf_shape.size()}); .getElementType()
.dyn_cast<TF::ResourceType>()) {
if (resource_type.getSubtypes().empty()) continue;
auto shapes_and_types = absl::make_unique<std::vector<
std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>>();
for (auto subtype : resource_type.getSubtypes()) {
auto shape = GetShapeFromMlirType(subtype);
// handle_shapes_and_types requires all shapes to be known. So if any
// subtype is unknown, clear the vector.
if (!shape) {
shapes_and_types = nullptr;
break;
}
tensorflow::DataType dtype;
auto status =
tensorflow::ConvertToDataType(subtype.getElementType(), &dtype);
assert(status.ok() && "Unknown element type");
shapes_and_types->emplace_back(*shape, dtype);
}
handle_shapes_and_types[index] = std::move(shapes_and_types);
} }
} }
@ -193,8 +257,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
// function operates on. // function operates on.
tensorflow::shape_inference::InferenceContext c( tensorflow::shape_inference::InferenceContext c(
graph_version, *node_def, op_reg_data->op_def, input_shapes, graph_version, *node_def, op_reg_data->op_def, input_shapes,
input_tensors, /*input_tensors_as_shapes=*/{}, input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types);
/*input_handle_shapes_and_types=*/{});
auto status = c.Run(op_reg_data->shape_inference_fn); auto status = c.Run(op_reg_data->shape_inference_fn);
if (!status.ok()) { if (!status.ok()) {
LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op
@ -206,12 +269,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
"inference context matches the MLIR number of results."); "inference context matches the MLIR number of results.");
// Update the shape for each of the operation result if the InferenceContext // Update the shape for each of the operation result if the InferenceContext
// has more precise shapes recorded. A builder is used to insert tf.Cast // has more precise shapes recorded.
// operation when changing the type of a result is the user is not a TF
// operation, as we can't guarantee that the new type will be OK.
bool changed = false; bool changed = false;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
for (int output : llvm::seq<int>(0, c.num_outputs())) { for (int output : llvm::seq<int>(0, c.num_outputs())) {
// Skip already statically shaped results. // Skip already statically shaped results.
Value result = op->getResult(output); Value result = op->getResult(output);
@ -221,30 +280,39 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output); tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : "
<< c.DebugString(shape_handle) << "\n"); << c.DebugString(shape_handle) << "\n");
if (!c.RankKnown(shape_handle)) continue; auto get_tensor_type =
[&c](const tensorflow::shape_inference::ShapeHandle& sh,
// Convert the shape from TensorFlow (int64) to MLIR (int64_t). Type element_type) -> TensorType {
SmallVector<int64_t, 8> shape; if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type);
for (int dim : llvm::seq<int>(0, c.Rank(shape_handle))) // Convert the shape from TensorFlow (int64) to MLIR (int64_t).
shape.push_back(c.Value(c.Dim(shape_handle, dim))); SmallVector<int64_t, 8> shape;
auto new_type = RankedTensorType::get(shape, shaped_type.getElementType()); for (int dim : llvm::seq<int>(0, c.Rank(sh)))
shape.push_back(c.Value(c.Dim(sh, dim)));
// A tf.Cast operation is lazily created on the first uses that isn't a TF return RankedTensorType::get(shape, element_type);
// operation.
TF::CastOp cast_op;
auto get_cast_op = [&]() {
if (!cast_op)
cast_op =
builder.create<TF::CastOp>(op->getLoc(), result.getType(), result,
/*truncate=*/builder.getBoolAttr(false));
return cast_op;
}; };
for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) { auto new_element_type = shaped_type.getElementType();
if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op()); // Populate the handle shapes for a resource.
if (auto resource_type = new_element_type.dyn_cast<TF::ResourceType>()) {
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
if (handle_shapes_types) {
llvm::SmallVector<mlir::TensorType, 1> subtypes;
OpBuilder b(op);
for (const auto& shape_n_type : *handle_shapes_types) {
Type element_type;
auto status =
tensorflow::ConvertDataType(shape_n_type.dtype, b, &element_type);
assert(status.ok() && "Unknown element type");
subtypes.push_back(get_tensor_type(shape_n_type.shape, element_type));
}
new_element_type = TF::ResourceType::get(subtypes, op->getContext());
}
} }
auto new_type = get_tensor_type(shape_handle, new_element_type);
if (result.getType() == new_type) continue; if (result.getType() == new_type) continue;
// Inserts a cast back to the original type if any user is not in the TF
// dialect.
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
result.getType());
// Finally we inferred the shape and replace the type for this result. // Finally we inferred the shape and replace the type for this result.
result.setType(new_type); result.setType(new_type);
changed = true; changed = true;