From b3f1b34934c87f2843a6bd9e6e5eeac36d24a284 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Tue, 3 Sep 2019 13:22:53 -0700 Subject: [PATCH] Add ConvertTypeToTensorShape method. PiperOrigin-RevId: 266999034 --- tensorflow/compiler/mlir/tensorflow/BUILD | 15 +++++ .../mlir/tensorflow/utils/convert_tensor.cc | 17 +++++ .../mlir/tensorflow/utils/convert_tensor.h | 3 + .../tensorflow/utils/convert_tensor_test.cc | 65 +++++++++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 2f0dee89252..4b64dfcb9dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -386,6 +386,21 @@ cc_library( ], ) +tf_cc_test( + name = "convert_tensor_test", + size = "small", + srcs = ["utils/convert_tensor_test.cc"], + deps = [ + ":convert_tensor", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@local_config_mlir//:IR", + ], +) + cc_library( name = "mangling_util", srcs = ["utils/mangling_util.cc"], diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 10411bfffad..f57ecd5ae39 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -147,6 +147,23 @@ void ConvertToTensorShapeProto(ArrayRef shape, } } +PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { + if (type.isa()) { + // An empty PartialTensorShape indicates an unranked tensor. + return PartialTensorShape(); + } + + if (auto tensor_type = type.dyn_cast()) { + TensorShapeProto tensor_shape_proto; + ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto); + return PartialTensorShape(tensor_shape_proto); + } + + // If type is not a RankedTensor or UnrankedTensor, it must be a scalar. + // Empty TensorShape indicates a scalar. + return TensorShape(); +} + // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. Status ConvertOpaqueElementsAttr(const ElementsAttr attr, TensorProto* output_tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index d934f96acc7..f57cd1c872a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -45,6 +45,9 @@ StatusOr ConvertTensor(const Tensor& input_tensor, void ConvertToTensorShapeProto(llvm::ArrayRef shape, TensorShapeProto* output_shape); +// Converts an MLIR type with static tensor shape to an TensorFlow tensor shape. +PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type); + // Converts an MLIR elements attribute to an TensorFlow tensor proto. Status ConvertToTensorProto(mlir::ElementsAttr attr, TensorProto* output_tensor); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc new file mode 100644 index 00000000000..1c53adcdda3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" + +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { +namespace { + +TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + PartialTensorShape output_shape = + ConvertTypeToTensorShape(b.getTensorType(b.getF32Type())); + EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape())); +} + +TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + PartialTensorShape output_shape = + ConvertTypeToTensorShape(b.getTensorType({-1, 2, 3}, b.getF32Type())); + EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({-1, 2, 3}))); +} + +TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + PartialTensorShape output_shape = + ConvertTypeToTensorShape(b.getTensorType({1, 2, 3}, b.getF32Type())); + EXPECT_TRUE(output_shape.IsIdenticalTo(PartialTensorShape({1, 2, 3}))); +} + +TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) { + mlir::MLIRContext context; + mlir::Builder b(&context); + + PartialTensorShape output_shape = ConvertTypeToTensorShape(b.getF32Type()); + EXPECT_TRUE(output_shape.IsIdenticalTo(TensorShape())); +} + +} // namespace +} // namespace tensorflow