Add ExtensionTypeVariant class, which can be used to store any ExtensionType values (aka CompositeTensor values) in a variant scalar.
PiperOrigin-RevId: 343870805 Change-Id: I66303071c34503ca8924fee78317d59d0a108218
This commit is contained in:
parent
e2de99a4e3
commit
5b9437eba6
tensorflow/core
@ -1251,6 +1251,31 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "extension_type_variant",
|
||||
srcs = ["extension_type_variant.cc"],
|
||||
hdrs = ["extension_type_variant.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "extension_type_variant_test",
|
||||
size = "small",
|
||||
srcs = ["extension_type_variant_test.cc"],
|
||||
deps = [
|
||||
":extension_type_variant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
# All framewrok protos are self-contained, i.e. they only import other
|
||||
# protos from the same package, so we can build the protos here and then
|
||||
# link them from core:protos_all without circular dependencies.
|
||||
|
63
tensorflow/core/framework/extension_type_variant.cc
Normal file
63
tensorflow/core/framework/extension_type_variant.cc
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/extension_type_variant.h"
|
||||
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
constexpr const char ExtensionTypeVariant::kTypeName[];
|
||||
|
||||
void ExtensionTypeVariant::Encode(VariantTensorData* data) const {
|
||||
data->set_type_name(TypeName());
|
||||
metadata_.type_spec_proto().SerializeToString(&data->metadata_string());
|
||||
for (const Tensor& tensor : flat_components_) {
|
||||
data->add_tensor(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
bool ExtensionTypeVariant::Decode(const VariantTensorData& data) {
|
||||
if (!metadata_.mutable_type_spec_proto()->ParseFromString(
|
||||
data.metadata_string())) {
|
||||
return false;
|
||||
}
|
||||
flat_components_ = data.tensors();
|
||||
return true;
|
||||
}
|
||||
|
||||
string ExtensionTypeVariant::DebugString() const {
|
||||
string type_spec;
|
||||
::tensorflow::protobuf::TextFormat::Printer printer;
|
||||
printer.SetSingleLineMode(true);
|
||||
printer.PrintToString(metadata_.type_spec_proto(), &type_spec);
|
||||
string result("<ExtensionTypeVariant type_spec={");
|
||||
result.append(type_spec.empty() ? "none" : type_spec);
|
||||
result.append("}, components=[");
|
||||
for (const auto& tensor : flat_components_) {
|
||||
if (&tensor != &flat_components_[0]) {
|
||||
result.append(", ");
|
||||
}
|
||||
result.append(tensor.DebugString());
|
||||
}
|
||||
result.append("]>");
|
||||
return result;
|
||||
}
|
||||
|
||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(ExtensionTypeVariant,
|
||||
ExtensionTypeVariant::kTypeName);
|
||||
|
||||
} // namespace tensorflow
|
96
tensorflow/core/framework/extension_type_variant.h
Normal file
96
tensorflow/core/framework/extension_type_variant.h
Normal file
@ -0,0 +1,96 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/protobuf/extension_type_variant.pb.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Encoding for a `tf.ExtensionType` value, that can be saved as a Variant.
|
||||
//
|
||||
// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class
|
||||
// used to Python types that are supported by TensorFlow APIs. Example
|
||||
// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`.
|
||||
//
|
||||
// `ExtensionTypeVariant` decomposes the `ExtensionType` value into two
|
||||
// parts:
|
||||
//
|
||||
// * `components`: A list of Tensors, which encodes the value's dynamic
|
||||
// data -- i.e., data that may change for different executions of a graph.
|
||||
// * `type_spec_proto`: A serialized TypeSpec, which encodes the value's
|
||||
// static data -- i.e., data that is the same for all executions of a graph.
|
||||
//
|
||||
// ExtensionTypeVariant can be stored in a Tensor with dtype=DT_VARIANT.
|
||||
// Typically, extension type values are encoded with a scalar tensor containing
|
||||
// a single ExtensionTypeVariant value.
|
||||
class ExtensionTypeVariant {
|
||||
public:
|
||||
ExtensionTypeVariant(const TypeSpecProto& type_spec_proto,
|
||||
absl::Span<Tensor> flat_components)
|
||||
: flat_components_(flat_components.begin(), flat_components.end()) {
|
||||
*metadata_.mutable_type_spec_proto() = type_spec_proto;
|
||||
}
|
||||
|
||||
// This type is default-constructible, copyable, assignable, and movable.
|
||||
ExtensionTypeVariant() = default;
|
||||
ExtensionTypeVariant(const ExtensionTypeVariant& other) = default;
|
||||
ExtensionTypeVariant& operator=(ExtensionTypeVariant&& other) = default;
|
||||
ExtensionTypeVariant& operator=(const ExtensionTypeVariant& other) = default;
|
||||
|
||||
// Returns the list of Tensor components that encode this value's dynamic
|
||||
// data.
|
||||
absl::Span<const Tensor> flat_components() const {
|
||||
return absl::MakeConstSpan(flat_components_);
|
||||
}
|
||||
|
||||
// Returns the serialized TypeSpec that encodes the value's static data.
|
||||
TypeSpecProto type_spec_proto() const { return metadata_.type_spec_proto(); }
|
||||
|
||||
// Variant methods.
|
||||
string TypeName() const { return kTypeName; }
|
||||
|
||||
// Updates `VariantTensorData` with an encoding for this value.
|
||||
void Encode(VariantTensorData* data) const;
|
||||
|
||||
// Updates this value to match the encoding in a given `VariantTensorData`.
|
||||
bool Decode(const VariantTensorData& data);
|
||||
|
||||
// Returns a string summary for this value.
|
||||
string DebugString() const;
|
||||
|
||||
// Name of this type (used for variant serialization).
|
||||
static constexpr const char kTypeName[] = "ExtensionTypeVariant";
|
||||
|
||||
private:
|
||||
// Tensor components for this value.
|
||||
std::vector<Tensor> flat_components_;
|
||||
|
||||
// TypeSpec for this value. ExtensionTypeVariantMetadata is a thin wrapper
|
||||
// around a TypeSpecProto, which is used to retain flexibility to change the
|
||||
// variant encoding.
|
||||
ExtensionTypeVariantMetadata metadata_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_EXTENSION_TYPE_VARIANT_H_
|
99
tensorflow/core/framework/extension_type_variant_test.cc
Normal file
99
tensorflow/core/framework/extension_type_variant_test.cc
Normal file
@ -0,0 +1,99 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/extension_type_variant.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// TypeSpecProto for a 2D Ragged Tensor.
|
||||
constexpr const char* k2DRaggedTensorSpec = R"(
|
||||
type_spec_class: RAGGED_TENSOR_SPEC
|
||||
type_state: {
|
||||
tuple_value: {
|
||||
values: [
|
||||
{tensor_shape_value: {dim: [{size: -1}, {size: -1}]}}, # shape
|
||||
{tensor_dtype_value: DT_INT32}, # dtype
|
||||
{int64_value: 1}, # ragged_rank
|
||||
{tensor_dtype_value: DT_INT64} # row_splits_dtype
|
||||
]
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
// Returns an ExtensionTypeVariant encoding for a 2D ragged tensor with
|
||||
// the specified values and row_splits.
|
||||
ExtensionTypeVariant Make2DRaggedTensor(const std::vector<int32>& values,
|
||||
const std::vector<int64>& splits) {
|
||||
TypeSpecProto type_spec;
|
||||
EXPECT_TRUE(
|
||||
protobuf::TextFormat::ParseFromString(k2DRaggedTensorSpec, &type_spec));
|
||||
std::vector<Tensor> components;
|
||||
components.push_back(test::AsTensor<int32>(values));
|
||||
components.push_back(test::AsTensor<int64>(splits));
|
||||
ExtensionTypeVariant v(type_spec, absl::MakeSpan(components));
|
||||
return v;
|
||||
}
|
||||
|
||||
TEST(ExtensionTypeVariantTest, EncodeAndDecodeRagged) {
|
||||
ExtensionTypeVariant v = Make2DRaggedTensor(
|
||||
/* values = */ {5, 5, 3, 4, 1, 8},
|
||||
/* splits = */ {0, 2, 3, 6});
|
||||
Tensor t(DT_VARIANT, {});
|
||||
|
||||
t.flat<Variant>()(0) = v; // Encode to variant.
|
||||
auto* decoded = t.flat<Variant>()(0).get<ExtensionTypeVariant>();
|
||||
|
||||
EXPECT_EQ(v.type_spec_proto().SerializeAsString(),
|
||||
decoded->type_spec_proto().SerializeAsString());
|
||||
EXPECT_EQ(v.flat_components().size(), 2);
|
||||
test::ExpectTensorEqual<int32>(v.flat_components()[0],
|
||||
decoded->flat_components()[0]);
|
||||
test::ExpectTensorEqual<int64>(v.flat_components()[1],
|
||||
decoded->flat_components()[1]);
|
||||
}
|
||||
|
||||
TEST(ExtensionTypeVariantTest, DebugStringForDefaultConstructed) {
|
||||
ExtensionTypeVariant v;
|
||||
EXPECT_EQ(v.DebugString(),
|
||||
"<ExtensionTypeVariant type_spec={none}, components=[]>");
|
||||
}
|
||||
|
||||
TEST(ExtensionTypeVariantTest, DebugStringForRagged) {
|
||||
ExtensionTypeVariant v = Make2DRaggedTensor(
|
||||
/* values = */ {5, 5, 3, 4, 1},
|
||||
/* splits = */ {0, 2, 3, 5});
|
||||
EXPECT_EQ(v.DebugString(),
|
||||
"<ExtensionTypeVariant type_spec={type_spec_class: "
|
||||
"RAGGED_TENSOR_SPEC type_state { tuple_value { values { "
|
||||
"tensor_shape_value { dim { size: -1 } dim { size: -1 } } } "
|
||||
"values { tensor_dtype_value: DT_INT32 } values "
|
||||
"{ int64_value: 1 } values { tensor_dtype_value: DT_INT64 } } } }, "
|
||||
"components=[Tensor<type: int32 shape: [5] values: 5 5 3...>, "
|
||||
"Tensor<type: int64 shape: [4] values: 0 2 3...>]>");
|
||||
}
|
||||
|
||||
TEST(ExtensionTypeVariantTest, TypeName) {
|
||||
ExtensionTypeVariant v;
|
||||
EXPECT_EQ(v.TypeName(), "ExtensionTypeVariant");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -141,6 +141,7 @@ exports_files(
|
||||
"snapshot.proto",
|
||||
"service_config.proto",
|
||||
"debug_event.proto",
|
||||
"extension_type_variant.proto",
|
||||
"meta_graph.proto",
|
||||
"named_tensor.proto",
|
||||
"remote_tensor_handle.proto",
|
||||
@ -170,6 +171,7 @@ tf_proto_library(
|
||||
"snapshot.proto",
|
||||
"service_config.proto",
|
||||
"debug_event.proto",
|
||||
"extension_type_variant.proto",
|
||||
"meta_graph.proto",
|
||||
"named_tensor.proto",
|
||||
"remote_tensor_handle.proto",
|
||||
|
14
tensorflow/core/protobuf/extension_type_variant.proto
Normal file
14
tensorflow/core/protobuf/extension_type_variant.proto
Normal file
@ -0,0 +1,14 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/protobuf/struct.proto";
|
||||
|
||||
// Metadata for ExtensionTypeVariant, used when serializing as Variant.
|
||||
//
|
||||
// We define a new message here (rather than directly using TypeSpecProto for
|
||||
// the metadata string) to retain flexibility to change the metadata encoding
|
||||
// to support additional features.
|
||||
message ExtensionTypeVariantMetadata {
|
||||
TypeSpecProto type_spec_proto = 1;
|
||||
}
|
Loading…
Reference in New Issue
Block a user