Add ConvertTypeToTensorShape method.

PiperOrigin-RevId: 266999034
This commit is contained in:
Yanan Cao 2019-09-03 13:22:53 -07:00 committed by TensorFlower Gardener
parent 53ed37363c
commit b3f1b34934
4 changed files with 100 additions and 0 deletions

View File

@ -386,6 +386,21 @@ cc_library(
],
)
tf_cc_test(
name = "convert_tensor_test",
size = "small",
srcs = ["utils/convert_tensor_test.cc"],
deps = [
":convert_tensor",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@local_config_mlir//:IR",
],
)
cc_library(
name = "mangling_util",
srcs = ["utils/mangling_util.cc"],

View File

@ -147,6 +147,23 @@ void ConvertToTensorShapeProto(ArrayRef<int64_t> shape,
}
}
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) {
if (type.isa<mlir::UnrankedTensorType>()) {
// An empty PartialTensorShape indicates an unranked tensor.
return PartialTensorShape();
}
if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
TensorShapeProto tensor_shape_proto;
ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto);
return PartialTensorShape(tensor_shape_proto);
}
// If type is not a RankedTensor or UnrankedTensor, it must be a scalar.
// Empty TensorShape indicates a scalar.
return TensorShape();
}
// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto.
Status ConvertOpaqueElementsAttr(const ElementsAttr attr,
TensorProto* output_tensor) {

View File

@ -45,6 +45,9 @@ StatusOr<mlir::ElementsAttr> ConvertTensor(const Tensor& input_tensor,
void ConvertToTensorShapeProto(llvm::ArrayRef<int64_t> shape,
TensorShapeProto* output_shape);
// Converts an MLIR type with static tensor shape to an TensorFlow tensor shape.
PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type);
// Converts an MLIR elements attribute to an TensorFlow tensor proto.
Status ConvertToTensorProto(mlir::ElementsAttr attr,
TensorProto* output_tensor);

View File

@ -0,0 +1,65 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
mlir::MLIRContext context;
mlir::Builder b(&context);
PartialTensorShape output_shape =
ConvertTypeToTensorShape(b.getTensorType(b.getF32Type()));
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape()));
}
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
mlir::MLIRContext context;
mlir::Builder b(&context);
PartialTensorShape output_shape =
ConvertTypeToTensorShape(b.getTensorType({-1, 2, 3}, b.getF32Type()));
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({-1, 2, 3})));
}
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
mlir::MLIRContext context;
mlir::Builder b(&context);
PartialTensorShape output_shape =
ConvertTypeToTensorShape(b.getTensorType({1, 2, 3}, b.getF32Type()));
EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({1, 2, 3})));
}
TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) {
mlir::MLIRContext context;
mlir::Builder b(&context);
PartialTensorShape output_shape = ConvertTypeToTensorShape(b.getF32Type());
EXPECT_TRUE(output_shape.IsIdenticalTo(TensorShape()));
}
} // namespace
} // namespace tensorflow