From f24faa153ad31a4b51578f8181d3aaab77a1ddeb Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Wed, 20 May 2020 20:04:33 -0700 Subject: [PATCH] Add dataset element compression ops. These allow us to implement tf.data service compression/decompression as a part of the tf.data pipeline. PiperOrigin-RevId: 312605093 Change-Id: I4a833bc89e602c8fd78abc4c1a0026c2a397449f --- .../base_api/api_def_CompressElement.pbtxt | 5 ++ .../base_api/api_def_UncompressElement.pbtxt | 5 ++ tensorflow/core/framework/common_shape_fns.cc | 19 +++++ tensorflow/core/framework/common_shape_fns.h | 3 + .../core/kernels/data/experimental/BUILD | 15 ++++ .../data/experimental/compression_ops.cc | 76 +++++++++++++++++ .../data/experimental/compression_ops.h | 49 +++++++++++ tensorflow/core/ops/dataset_ops.cc | 35 ++------ .../core/ops/experimental_dataset_ops.cc | 13 +++ .../data/experimental/kernel_tests/BUILD | 15 +++- .../kernel_tests/compression_ops_test.py | 81 +++++++++++++++++++ tensorflow/python/data/experimental/ops/BUILD | 10 +++ .../data/experimental/ops/compression_ops.py | 55 +++++++++++++ .../api/golden/v1/tensorflow.raw_ops.pbtxt | 8 ++ .../api/golden/v2/tensorflow.raw_ops.pbtxt | 8 ++ 15 files changed, 366 insertions(+), 31 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt create mode 100644 tensorflow/core/kernels/data/experimental/compression_ops.cc create mode 100644 tensorflow/core/kernels/data/experimental/compression_ops.h create mode 100644 tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py create mode 100644 tensorflow/python/data/experimental/ops/compression_ops.py diff --git a/tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt b/tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt new file mode 100644 index 00000000000..17b63e4ab2f --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "CompressElement" + visibility: HIDDEN + summary: "Compresses a dataset element." +} diff --git a/tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt b/tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt new file mode 100644 index 00000000000..e2039b674f0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "UncompressElement" + visibility: HIDDEN + summary: "Uncompresses a compressed dataset element." +} diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 113adbdd432..216002ad8e7 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -468,6 +468,25 @@ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, return Status::OK(); } +Status DatasetIteratorShape(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast(i), output_shape_handle); + } + return Status::OK(); +} + Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, const std::vector& spatial, DimensionOrConstant C, ShapeHandle* out, diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e1984abab7e..218400c2435 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -92,6 +92,9 @@ inline Status MergeBothInputsShapeFn(InferenceContext* c) { return Status::OK(); } +// Shape function for dataset iterators. +Status DatasetIteratorShape(shape_inference::InferenceContext* c); + // Returns a new shape with the specified dims arranged in the specified // format. The returned value is owned by this context. // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 85f8af878ee..f4b9240ca31 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -109,6 +109,20 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "compression_ops", + srcs = ["compression_ops.cc"], + hdrs = ["compression_ops.h"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/data:compression_utils", + "//tensorflow/core/data:dataset_proto_cc", + ], +) + tf_kernel_library( name = "csv_dataset_op", srcs = ["csv_dataset_op.cc"], @@ -681,6 +695,7 @@ tf_kernel_library( ":auto_shard_dataset_op", ":choose_fastest_branch_dataset_op", ":choose_fastest_dataset_op", + ":compression_ops", ":csv_dataset_op", ":dense_to_sparse_batch_dataset_op", ":directed_interleave_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/compression_ops.cc b/tensorflow/core/kernels/data/experimental/compression_ops.cc new file mode 100644 index 00000000000..efa7018acb6 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/compression_ops.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/experimental/compression_ops.h" + +#include "tensorflow/core/data/compression_utils.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +CompressElementOp::CompressElementOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + +void CompressElementOp::Compute(OpKernelContext* ctx) { + std::vector components; + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + components.push_back(ctx->input(i)); + } + CompressedElement compressed; + OP_REQUIRES_OK(ctx, CompressElement(components, &compressed)); + + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = std::move(compressed); +} + +UncompressElementOp::UncompressElementOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); +} + +void UncompressElementOp::Compute(OpKernelContext* ctx) { + Tensor tensor = ctx->input(0); + const Variant& variant = tensor.scalar()(); + const CompressedElement* compressed = variant.get(); + + std::vector components; + OP_REQUIRES_OK(ctx, UncompressElement(*compressed, &components)); + OP_REQUIRES(ctx, components.size() == output_types_.size(), + errors::FailedPrecondition("Expected ", output_types_.size(), + " outputs from uncompress, but got ", + components.size())); + for (int i = 0; i < components.size(); ++i) { + OP_REQUIRES( + ctx, components[i].dtype() == output_types_[i], + errors::FailedPrecondition("Expected a tensor of type ", + DataTypeString(output_types_[i]), + " but got a tensor of type ", + DataTypeString(components[i].dtype()))); + ctx->set_output(i, components[i]); + } +} + +REGISTER_KERNEL_BUILDER(Name("CompressElement").Device(DEVICE_CPU), + CompressElementOp); +REGISTER_KERNEL_BUILDER(Name("UncompressElement").Device(DEVICE_CPU), + UncompressElementOp); + +} // namespace experimental +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/compression_ops.h b/tensorflow/core/kernels/data/experimental/compression_ops.h new file mode 100644 index 00000000000..6dd89ea4e5d --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/compression_ops.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class CompressElementOp : public OpKernel { + public: + explicit CompressElementOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; +}; + +class UncompressElementOp : public OpKernel { + public: + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit UncompressElementOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 0122cbed087..6a633fb679d 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -731,42 +731,19 @@ REGISTER_OP("OneShotIterator") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); -namespace { - -Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); -} - -} // namespace - REGISTER_OP("IteratorGetNext") .Input("iterator: resource") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("IteratorGetNextSync") .Input("iterator: resource") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); // TODO(b/124308596): Instead of conservatively marking this op as stateful, // implement a mechanism to determine whether `dataset` has a side-effect @@ -778,7 +755,7 @@ REGISTER_OP("DatasetToSingleElement") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); // TODO(b/124308596): Instead of conservatively marking this op as stateful, // implement a mechanism to determine whether `dataset` has a side-effect @@ -796,7 +773,7 @@ REGISTER_OP("ReduceDataset") .Attr("output_shapes: list(shape) >= 1") .Attr("use_inter_op_parallelism: bool = true") .SetIsStateful() - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") @@ -875,7 +852,7 @@ REGISTER_OP("OptionalGetValue") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("IteratorGetNextAsOptional") .Input("iterator: resource") @@ -992,7 +969,7 @@ REGISTER_OP("MultiDeviceIteratorGetNextFromShard") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("MultiDeviceIteratorToStringHandle") .Input("multi_device_iterator: resource") diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 2c9cbe2f416..aa4bd64270a 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -132,6 +132,19 @@ REGISTER_OP("ExperimentalChooseFastestDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("CompressElement") + .Input("components: input_types") + .Output("compressed: variant") + .Attr("input_types: list(type) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("UncompressElement") + .Input("compressed: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::DatasetIteratorShape); + REGISTER_OP("CSVDataset") .Input("filenames: string") .Input("compression_type: string") diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index d5d6cb00733..1d5abb9871b 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load package( default_visibility = ["//tensorflow:internal"], @@ -87,6 +87,17 @@ tf_py_test( ], ) +tf_py_test( + name = "compression_ops_test", + srcs = ["compression_ops_test.py"], + deps = [ + "//tensorflow/python/data/experimental/ops:compression_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "copy_to_device_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py new file mode 100644 index 00000000000..a091bdca8b9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py @@ -0,0 +1,81 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compression ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import compression_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure +from tensorflow.python.framework import combinations +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +def _test_objects(): + return [ + combinations.NamedObject("int", 1), + combinations.NamedObject("string", "dog"), + combinations.NamedObject("tuple", (1, 1)), + combinations.NamedObject("int_string_tuple", (1, "dog")), + combinations.NamedObject( + "sparse", + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])), + combinations.NamedObject( + "sparse_structured", { + "a": + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], + values=[1, 2], + dense_shape=[3, 4]), + "b": (1, 2, "dog") + }) + ] + + +class CompressionOpsTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(element=_test_objects()))) + def testCompression(self, element): + element = element._obj + + compressed = compression_ops.compress(element) + uncompressed = compression_ops.uncompress( + compressed, structure.type_spec_from_value(element)) + self.assertValuesEqual(element, self.evaluate(uncompressed)) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(element=_test_objects()))) + def testDatasetCompression(self, element): + element = element._obj + + dataset = dataset_ops.Dataset.from_tensors(element) + element_spec = dataset.element_spec + + dataset = dataset.map(lambda *x: compression_ops.compress(x)) + dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec)) + self.assertDatasetProduces(dataset, [element]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index 50d095e46f6..2adf2a6362d 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -33,6 +33,15 @@ py_library( ], ) +py_library( + name = "compression_ops", + srcs = ["compression_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:experimental_dataset_ops_gen", + ], +) + py_library( name = "counter", srcs = ["counter.py"], @@ -475,6 +484,7 @@ py_library( deps = [ ":batching", ":cardinality", + ":compression_ops", ":counter", ":data_service_ops", ":distribute", diff --git a/tensorflow/python/data/experimental/ops/compression_ops.py b/tensorflow/python/data/experimental/ops/compression_ops.py new file mode 100644 index 00000000000..1ef7c8b3f01 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/compression_ops.py @@ -0,0 +1,55 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for compressing and uncompressing dataset elements.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import structure +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops + + +def compress(element): + """Compress a dataset element. + + Args: + element: A nested structure of types supported by Tensorflow. + + Returns: + A variant tensor representing the compressed element. This variant can be + passed to `uncompress` to get back the original element. + """ + element_spec = structure.type_spec_from_value(element) + tensor_list = structure.to_tensor_list(element_spec, element) + return ged_ops.compress_element(tensor_list) + + +def uncompress(element, output_spec): + """Uncompress a compressed dataset element. + + Args: + element: A scalar variant tensor to uncompress. The element should have been + created by calling `compress`. + output_spec: A nested structure of `tf.TypeSpec` representing the type(s) of + the uncompressed element. + + Returns: + The uncompressed element. + """ + flat_types = structure.get_flat_tensor_types(output_spec) + flat_shapes = structure.get_flat_tensor_shapes(output_spec) + tensor_list = ged_ops.uncompress_element( + element, output_types=flat_types, output_shapes=flat_shapes) + return structure.from_tensor_list(output_spec, tensor_list) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index f798ebf25fd..3db327300a9 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -736,6 +736,10 @@ tf_module { name: "ComplexAbs" argspec: "args=[\'x\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "CompressElement" + argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "ComputeAccidentalHits" argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " @@ -4956,6 +4960,10 @@ tf_module { name: "UnbatchGrad" argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], " } + member_method { + name: "UncompressElement" + argspec: "args=[\'compressed\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "UnicodeDecode" argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index f798ebf25fd..3db327300a9 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -736,6 +736,10 @@ tf_module { name: "ComplexAbs" argspec: "args=[\'x\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } + member_method { + name: "CompressElement" + argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "ComputeAccidentalHits" argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " @@ -4956,6 +4960,10 @@ tf_module { name: "UnbatchGrad" argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], " } + member_method { + name: "UncompressElement" + argspec: "args=[\'compressed\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "UnicodeDecode" argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"\", \'None\'], "