Add ExtensionTypeVariant class, which can be used to store any ExtensionType values (aka CompositeTensor values) in a variant scalar.

PiperOrigin-RevId: 343870805
Change-Id: I66303071c34503ca8924fee78317d59d0a108218
This commit is contained in:
Edward Loper 2020-11-23 09:53:00 -08:00 committed by TensorFlower Gardener
parent e2de99a4e3
commit 5b9437eba6
6 changed files with 299 additions and 0 deletions

View File

@ -1251,6 +1251,31 @@ cc_library(
],
)
# Library target for the ExtensionTypeVariant class: builds
# extension_type_variant.cc and exposes extension_type_variant.h.
cc_library(
name = "extension_type_variant",
srcs = ["extension_type_variant.cc"],
hdrs = ["extension_type_variant.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
# Presumably provides the generated extension_type_variant.pb.h and
# struct.pb.h headers included by the library header -- TODO confirm.
"//tensorflow/core:protos_all_cc",
# absl::Span is used in the ExtensionTypeVariant public interface.
"@com_google_absl//absl/types:span",
],
)
# Unit test target for ExtensionTypeVariant (extension_type_variant_test.cc).
tf_cc_test(
name = "extension_type_variant_test",
size = "small",
srcs = ["extension_type_variant_test.cc"],
deps = [
# The library under test.
":extension_type_variant",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
# All framework protos are self-contained, i.e. they only import other
# protos from the same package, so we can build the protos here and then
# link them from core:protos_all without circular dependencies.

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/extension_type_variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
// Out-of-line definition of the static constexpr member that is declared (with
// its initializer) inside the class. Before C++17 such a definition is needed
// whenever the member is odr-used; from C++17 on it is redundant but harmless.
constexpr const char ExtensionTypeVariant::kTypeName[];
void ExtensionTypeVariant::Encode(VariantTensorData* data) const {
data->set_type_name(TypeName());
metadata_.type_spec_proto().SerializeToString(&data->metadata_string());
for (const Tensor& tensor : flat_components_) {
data->add_tensor(tensor);
}
}
bool ExtensionTypeVariant::Decode(const VariantTensorData& data) {
if (!metadata_.mutable_type_spec_proto()->ParseFromString(
data.metadata_string())) {
return false;
}
flat_components_ = data.tensors();
return true;
}
// Builds a one-line, human-readable summary of the form
//   <ExtensionTypeVariant type_spec={...}, components=[t0, t1, ...]>
// where the TypeSpec is printed as single-line text-format proto ("none" when
// that text is empty) and each component is rendered with
// Tensor::DebugString().
string ExtensionTypeVariant::DebugString() const {
  ::tensorflow::protobuf::TextFormat::Printer single_line_printer;
  single_line_printer.SetSingleLineMode(true);

  string spec_text;
  single_line_printer.PrintToString(metadata_.type_spec_proto(), &spec_text);
  if (spec_text.empty()) {
    spec_text = "none";
  }

  string summary = "<ExtensionTypeVariant type_spec={";
  summary += spec_text;
  summary += "}, components=[";
  // The separator is empty before the first component and ", " afterwards.
  const char* separator = "";
  for (const Tensor& component : flat_components_) {
    summary += separator;
    summary += component.DebugString();
    separator = ", ";
  }
  summary += "]>";
  return summary;
}
// Registers ExtensionTypeVariant with the variant decode registry under the
// name kTypeName (macro from variant_op_registry.h). Presumably this is what
// lets an encoded variant carrying that type name be decoded back into an
// ExtensionTypeVariant -- see the macro's documentation for exact semantics.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ExtensionTypeVariant,
ExtensionTypeVariant::kTypeName);
} // namespace tensorflow

View File

@ -0,0 +1,96 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
#define TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/protobuf/extension_type_variant.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// Encoding for a `tf.ExtensionType` value, that can be saved as a Variant.
//
// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class
// used to define Python types that are supported by TensorFlow APIs. Example
// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`.
//
// `ExtensionTypeVariant` decomposes the `ExtensionType` value into two
// parts:
//
// * `components`: A list of Tensors, which encodes the value's dynamic
// data -- i.e., data that may change for different executions of a graph.
// * `type_spec_proto`: A serialized TypeSpec, which encodes the value's
// static data -- i.e., data that is the same for all executions of a graph.
//
// ExtensionTypeVariant can be stored in a Tensor with dtype=DT_VARIANT.
// Typically, extension type values are encoded with a scalar tensor containing
// a single ExtensionTypeVariant value.
class ExtensionTypeVariant {
 public:
  // Constructs a value from its TypeSpec (static data) and its flat list of
  // Tensor components (dynamic data). Both arguments are copied into this
  // object; the span only needs to stay valid for the duration of the call.
  ExtensionTypeVariant(const TypeSpecProto& type_spec_proto,
                       absl::Span<Tensor> flat_components)
      : flat_components_(flat_components.begin(), flat_components.end()) {
    *metadata_.mutable_type_spec_proto() = type_spec_proto;
  }

  // This type is default-constructible, copyable, assignable, and movable.
  ExtensionTypeVariant() = default;
  ExtensionTypeVariant(const ExtensionTypeVariant& other) = default;
  // The move constructor must be declared explicitly: because the copy
  // constructor above is user-declared, the compiler does not generate a move
  // constructor implicitly, and move-construction would silently copy.
  ExtensionTypeVariant(ExtensionTypeVariant&& other) = default;
  ExtensionTypeVariant& operator=(ExtensionTypeVariant&& other) = default;
  ExtensionTypeVariant& operator=(const ExtensionTypeVariant& other) = default;

  // Returns the list of Tensor components that encode this value's dynamic
  // data. The returned span views storage owned by this object.
  absl::Span<const Tensor> flat_components() const {
    return absl::MakeConstSpan(flat_components_);
  }

  // Returns (a copy of) the TypeSpecProto that encodes the value's static
  // data.
  TypeSpecProto type_spec_proto() const { return metadata_.type_spec_proto(); }

  // Variant methods.

  // Returns the name under which this type is registered (kTypeName).
  string TypeName() const { return kTypeName; }

  // Updates `VariantTensorData` with an encoding for this value.
  void Encode(VariantTensorData* data) const;

  // Updates this value to match the encoding in a given `VariantTensorData`.
  // Returns false if the encoded metadata cannot be parsed.
  bool Decode(const VariantTensorData& data);

  // Returns a string summary for this value.
  string DebugString() const;

  // Name of this type (used for variant serialization).
  static constexpr const char kTypeName[] = "ExtensionTypeVariant";

 private:
  // Tensor components for this value.
  std::vector<Tensor> flat_components_;

  // TypeSpec for this value.  ExtensionTypeVariantMetadata is a thin wrapper
  // around a TypeSpecProto, which is used to retain flexibility to change the
  // variant encoding.
  ExtensionTypeVariantMetadata metadata_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_

View File

@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/extension_type_variant.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// TypeSpecProto for a 2D Ragged Tensor, written in protobuf text format so it
// can be parsed with TextFormat::ParseFromString. The four tuple entries are,
// in order: shape, dtype, ragged_rank and row_splits_dtype (see the trailing
// comments inside the literal).
constexpr const char* k2DRaggedTensorSpec = R"(
type_spec_class: RAGGED_TENSOR_SPEC
type_state: {
tuple_value: {
values: [
{tensor_shape_value: {dim: [{size: -1}, {size: -1}]}}, # shape
{tensor_dtype_value: DT_INT32}, # dtype
{int64_value: 1}, # ragged_rank
{tensor_dtype_value: DT_INT64} # row_splits_dtype
]
}
}
)";
// Test helper: builds the ExtensionTypeVariant encoding of a 2D ragged tensor.
// The TypeSpec comes from k2DRaggedTensorSpec; the flat components are, in
// order, an int32 tensor holding `values` and an int64 tensor holding
// `splits`. A non-fatal test failure is recorded if the spec fails to parse.
ExtensionTypeVariant Make2DRaggedTensor(const std::vector<int32>& values,
                                        const std::vector<int64>& splits) {
  TypeSpecProto ragged_spec;
  const bool spec_parsed =
      protobuf::TextFormat::ParseFromString(k2DRaggedTensorSpec, &ragged_spec);
  EXPECT_TRUE(spec_parsed);

  std::vector<Tensor> flat = {test::AsTensor<int32>(values),
                              test::AsTensor<int64>(splits)};
  return ExtensionTypeVariant(ragged_spec, absl::MakeSpan(flat));
}
// Stores an ExtensionTypeVariant in a scalar DT_VARIANT tensor, reads it back
// with Variant::get<ExtensionTypeVariant>(), and checks that the TypeSpec and
// both flat components survive unchanged.
TEST(ExtensionTypeVariantTest, EncodeAndDecodeRagged) {
  ExtensionTypeVariant v = Make2DRaggedTensor(
      /* values = */ {5, 5, 3, 4, 1, 8},
      /* splits = */ {0, 2, 3, 6});
  Tensor t(DT_VARIANT, {});

  t.flat<Variant>()(0) = v;  // Store the value in the variant tensor.
  auto* decoded = t.flat<Variant>()(0).get<ExtensionTypeVariant>();
  // Variant::get<T>() yields a null pointer when the stored type is not T;
  // stop here rather than dereferencing it below.
  ASSERT_TRUE(decoded != nullptr);

  EXPECT_EQ(v.type_spec_proto().SerializeAsString(),
            decoded->type_spec_proto().SerializeAsString());
  // Both sides are indexed below, so the size checks must be fatal and must
  // cover the decoded value as well as the original.
  ASSERT_EQ(v.flat_components().size(), 2);
  ASSERT_EQ(decoded->flat_components().size(), 2);
  test::ExpectTensorEqual<int32>(v.flat_components()[0],
                                 decoded->flat_components()[0]);
  test::ExpectTensorEqual<int64>(v.flat_components()[1],
                                 decoded->flat_components()[1]);
}
// A default-constructed value has no TypeSpec text and no components, so its
// summary shows "none" and an empty component list.
TEST(ExtensionTypeVariantTest, DebugStringForDefaultConstructed) {
  const ExtensionTypeVariant empty_value;
  const char* const kExpected =
      "<ExtensionTypeVariant type_spec={none}, components=[]>";
  EXPECT_EQ(empty_value.DebugString(), kExpected);
}
// Checks the full summary string for a 2D ragged tensor: the single-line
// text-format TypeSpec followed by the debug strings of both components.
TEST(ExtensionTypeVariantTest, DebugStringForRagged) {
  const ExtensionTypeVariant ragged = Make2DRaggedTensor(
      /* values = */ {5, 5, 3, 4, 1},
      /* splits = */ {0, 2, 3, 5});
  const char* const kExpected =
      "<ExtensionTypeVariant type_spec={type_spec_class: "
      "RAGGED_TENSOR_SPEC type_state { tuple_value { values { "
      "tensor_shape_value { dim { size: -1 } dim { size: -1 } } } "
      "values { tensor_dtype_value: DT_INT32 } values "
      "{ int64_value: 1 } values { tensor_dtype_value: DT_INT64 } } } }, "
      "components=[Tensor<type: int32 shape: [5] values: 5 5 3...>, "
      "Tensor<type: int64 shape: [4] values: 0 2 3...>]>";
  EXPECT_EQ(ragged.DebugString(), kExpected);
}
// TypeName() reports the fixed registration name of the class.
TEST(ExtensionTypeVariantTest, TypeName) {
  const ExtensionTypeVariant value;
  const auto reported_name = value.TypeName();
  EXPECT_EQ(reported_name, "ExtensionTypeVariant");
}
} // namespace
} // namespace tensorflow

View File

@ -141,6 +141,7 @@ exports_files(
"snapshot.proto",
"service_config.proto",
"debug_event.proto",
"extension_type_variant.proto",
"meta_graph.proto",
"named_tensor.proto",
"remote_tensor_handle.proto",
@ -170,6 +171,7 @@ tf_proto_library(
"snapshot.proto",
"service_config.proto",
"debug_event.proto",
"extension_type_variant.proto",
"meta_graph.proto",
"named_tensor.proto",
"remote_tensor_handle.proto",

View File

@ -0,0 +1,14 @@
syntax = "proto3";
package tensorflow;
import "tensorflow/core/protobuf/struct.proto";
// Metadata for ExtensionTypeVariant, used when serializing as Variant.
//
// We define a new message here (rather than directly using TypeSpecProto for
// the metadata string) to retain flexibility to change the metadata encoding
// to support additional features.
message ExtensionTypeVariantMetadata {
// TypeSpec of the encoded value (TypeSpecProto is imported from
// struct.proto). Currently the only field of this wrapper.
TypeSpecProto type_spec_proto = 1;
}