From 332adf10ab7e2b575896edfad2d03570c6863ba6 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Tue, 21 Jul 2020 14:36:45 -0700 Subject: [PATCH] [tf.data] Add an op that computes the statically-known batch size of a dataset where possible. PiperOrigin-RevId: 322444756 Change-Id: Ib628e9463bf34e5a607fd023d7c065389fff599c --- .../base_api/api_def_ComputeBatchSize.pbtxt | 5 + tensorflow/core/kernels/data/BUILD | 1 + tensorflow/core/kernels/data/dataset_utils.cc | 7 + tensorflow/core/kernels/data/dataset_utils.h | 8 + .../core/kernels/data/experimental/BUILD | 22 +- .../experimental/assert_next_dataset_op.cc | 15 +- .../experimental/compute_batch_size_op.cc | 191 ++++++++++++++++++ tensorflow/core/kernels/data/rewrite_utils.cc | 18 +- .../core/kernels/data/serialization_utils.cc | 23 +++ .../core/kernels/data/serialization_utils.h | 9 + .../core/ops/experimental_dataset_ops.cc | 5 + .../kernel_tests/rebatch_dataset_test.py | 54 +++++ .../data/experimental/ops/distribute.py | 52 +++++ .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 15 files changed, 392 insertions(+), 26 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_ComputeBatchSize.pbtxt create mode 100644 tensorflow/core/kernels/data/experimental/compute_batch_size_op.cc diff --git a/tensorflow/core/api_def/base_api/api_def_ComputeBatchSize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ComputeBatchSize.pbtxt new file mode 100644 index 00000000000..b92d02e256d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ComputeBatchSize.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "ComputeBatchSize" + visibility: HIDDEN + summary: "Computes the static batch size of a dataset sans partial batches." 
+} diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index efce4fb0cf5..f0a58f3cdfe 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -54,6 +54,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index 5f0068445a9..1bd29638df6 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -898,5 +899,11 @@ std::string DeterminismPolicy::String() const { } } +bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match) { + // Matches all versions of an op by appending an optional version suffix + auto expected_re = strings::StrCat(RE2::QuoteMeta(op_prefix), "(V\\d+)?"); + return RE2::FullMatch(op_to_match, expected_re); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 9a7e274714a..5c6b14a8782 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -296,6 +296,14 @@ class DummyResourceOp : public OpKernel { } }; +// Given an op prefix and an op to match, returns whether the op to match +// is a regex match for any version of the op prefix. 
For example, +// MatchesAnyVersionRE("BatchDataset", "BatchDataset") == true +// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV2") == true +// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV3") == true +// MatchesAnyVersionRE("PaddedBatchDataset", "BatchDataset") == false +bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 35446bdfbea..56220b7bd85 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -32,7 +32,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:experimental_dataset_ops_op_lib", "//tensorflow/core:framework", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:name_utils", ], ) @@ -124,6 +124,25 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "compute_batch_size_op", + srcs = ["compute_batch_size_op.cc"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:regexp_internal", + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler/optimizers/data:graph_utils", + "//tensorflow/core/kernels/data:dataset_utils", + "//tensorflow/core/kernels/data:name_utils", + "//tensorflow/core/kernels/data:serialization_utils", + ], +) + tf_kernel_library( name = "csv_dataset_op", srcs = ["csv_dataset_op.cc"], @@ -736,6 +755,7 @@ tf_kernel_library( ":choose_fastest_branch_dataset_op", ":choose_fastest_dataset_op", ":compression_ops", + ":compute_batch_size_op", ":csv_dataset_op", ":dense_to_sparse_batch_dataset_op", ":directed_interleave_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc 
b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index adda54a0cd9..cb8dc67d6dd 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/name_utils.h" -#include "tensorflow/core/platform/regexp.h" namespace tensorflow { namespace data { @@ -97,15 +97,12 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { } int n = tokens.size(); for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { - std::string transformation_escaped = - RE2::QuoteMeta(dataset()->transformations_[i]); - std::string version_suffix = "(V\\d+)?"; - std::string expected_re = - absl::StrCat(transformation_escaped, version_suffix); - if (!RE2::FullMatch(tokens[n - 2 - i], expected_re)) { + if (!MatchesAnyVersionRE(dataset()->transformations_[i], + tokens[n - 2 - i])) { return errors::InvalidArgument("Asserted transformation matching ", - expected_re, " at offset ", i, - " but encountered ", tokens[n - 2 - i], + dataset()->transformations_[i], + " at offset ", i, " but encountered ", + tokens[n - 2 - i], " transformation instead."); } } diff --git a/tensorflow/core/kernels/data/experimental/compute_batch_size_op.cc b/tensorflow/core/kernels/data/experimental/compute_batch_size_op.cc new file mode 100644 index 00000000000..1c4c5dea248 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/compute_batch_size_op.cc @@ -0,0 +1,191 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/name_utils.h" +#include "tensorflow/core/kernels/data/serialization_utils.h" +#include "tensorflow/core/platform/stringprintf.h" + +namespace tensorflow { +namespace data { +namespace experimental { +namespace { + +using grappler::graph_utils::GetScalarConstNodeValue; + +constexpr char kMapAndBatchOp[] = "MapAndBatchDataset"; +constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset"; + +constexpr std::array<const char*, 4> kBatchDatasetOps = { + "BatchDataset", + "PaddedBatchDataset", + kMapAndBatchOp, + kExperimentalMapAndBatchOp, +}; + +constexpr std::array<const char*, 2> kMultipleInputDatasetOps = { + "ConcatenateDataset", + "ZipDataset", +}; + +constexpr std::array<const char*, 14> kPassThroughOps = { + "AssertCardinalityDataset", + "CacheDataset", + "FilterDataset", + "Identity", + "ModelDataset", + "OptimizeDataset", + "ParseExampleDataset", + "PrefetchDataset", + "RepeatDataset", + "ShardDataset", + "ShuffleAndRepeatDataset", + "ShuffleDataset", + "SkipDataset", + "TakeDataset", +}; + +template <std::size_t SIZE> +bool IsDatasetNodeOfType(const NodeDef& node, + const std::array<const char*, SIZE>& arr) { + for (const auto& dataset_op : arr) { + if (MatchesAnyVersionRE(dataset_op, node.op())) return true; + } + return false; +} + +const
NodeDef* GetInputNode(const NodeDef& node, + const grappler::GraphView& graph, + int64 input_index) { + if (input_index < 0 || input_index >= node.input_size()) return nullptr; + grappler::GraphView::InputPort input_port = + graph.GetInputPort(node.name(), input_index); + return graph.GetRegularFanin(input_port).node; +} + +// TODO(rachelim): This op traverses the dataset graph using an allowlist-based +// approach. As an alternative, we could instead rewrite all batching datasets' +// drop_remainder parameter to True, then rerun the dataset graph to derive +// new output shapes using C++ shape inference. This is more robust in cases +// where datasets have shape inference implemented in C++. If this allowlist- +// based approach proves hard to maintain, consider doing the alternative. +class ComputeBatchSizeOp : public OpKernel { + public: + explicit ComputeBatchSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DatasetBase* dataset; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + + std::vector<std::pair<string, Tensor>> input_list; + GraphDef graph_def; + string dataset_node_name; + OP_REQUIRES_OK(ctx, AsGraphDefMinimal(ctx, dataset, &input_list, &graph_def, + &dataset_node_name)); + + // Create GraphView for easier traversal of graph. 
+ grappler::GraphView graph_view(&graph_def); + + const NodeDef* node = graph_view.GetNode(dataset_node_name); + OP_REQUIRES(ctx, node != nullptr, + errors::InvalidArgument("Node does not exist in graph")); + int64 batch_size = GetBatchSize(*node, graph_view); + Tensor* result; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result)); + result->scalar<int64>()() = batch_size; + } + + private: + int64 GetBatchSizeFromBatchNode(const NodeDef& node, + const grappler::GraphView& graph) { + // batch_size is input 1, except for MapAndBatch where it is third-to-last. + int64 arg_index = 1; + if (node.op() == kMapAndBatchOp || + node.op() == kExperimentalMapAndBatchOp) { + arg_index = node.input_size() - 3; + } + + auto batch_size_node = GetInputNode(node, graph, arg_index); + if (batch_size_node == nullptr) return -1; // Input could not be resolved. + int64 batch_size; + auto s = GetScalarConstNodeValue(*batch_size_node, &batch_size); + if (!s.ok()) { + VLOG(1) << "Could not compute static batch size. Found batching dataset (" + << node.name() << "), but failed to get its input batch size: " + << s.error_message(); + return -1; + } + return batch_size; + } + + // Helper function that returns the static 0th dimension of a given dataset + // node in the graph. It starts from a node in the graph and recursively + // traverses its inputs until it finds a valid BatchDataset operation, + // and returns its batch size. If the batch size cannot be determined, + // returns -1. + // + // During recursion, it handles four kinds of cases: + // 1. BatchDataset type ops: Returns the value from its batch_size input node. + // 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops, + // which are themselves all datasets, and returns the batch sizes computed + // by the inputs if they are all the same. + // 3. Core dataset ops which cannot change the size of the 0th dimension of + // dataset output elements: Recurses into the first input parameter. + // 4. All other ops: Fail, returning -1 for unknown. 
+ // TODO(rachelim): For FlatMap type mapping dataset ops, recurse into the + // function definition. + int64 GetBatchSize(const NodeDef& node, const grappler::GraphView& graph) { + if (IsDatasetNodeOfType(node, kBatchDatasetOps)) { + return GetBatchSizeFromBatchNode(node, graph); + } + if (IsDatasetNodeOfType(node, kMultipleInputDatasetOps)) { + const NodeDef* input_0 = GetInputNode(node, graph, 0); + int64 batch_size_0 = input_0 ? GetBatchSize(*input_0, graph) : -1; + for (int i = 1; i < node.input_size(); ++i) { + const NodeDef* input = GetInputNode(node, graph, i); + int64 batch_size_i = input ? GetBatchSize(*input, graph) : -1; + if (batch_size_i != batch_size_0) { + VLOG(1) << "Could not compute batch size: inputs to " << node.name() + << " (" << node.op() << ") had different batch sizes." + << " Namely, input 0 had batch size " << batch_size_0 + << " while input " << i << " had batch size " << batch_size_i + << "."; + return -1; + } + } + return batch_size_0; + } + if (IsDatasetNodeOfType(node, kPassThroughOps)) { + const NodeDef* input = GetInputNode(node, graph, 0); + return input ? GetBatchSize(*input, graph) : -1; + } + VLOG(1) << "Encountered dataset node " << node.name() << " (" << node.op() + << ") that prevented further static batch size analysis."; + + return -1; + } +}; + +REGISTER_KERNEL_BUILDER(Name("ComputeBatchSize").Device(DEVICE_CPU), + ComputeBatchSizeOp); + +} // anonymous namespace +} // namespace experimental +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc index 0ea708abbc7..dd9bfdb5143 100644 --- a/tensorflow/core/kernels/data/rewrite_utils.cc +++ b/tensorflow/core/kernels/data/rewrite_utils.cc @@ -144,25 +144,11 @@ Status ApplyRewrites(OpKernelContext* ctx, Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, std::function<RewriterConfig(void)> config_factory, bool record_fingerprint, DatasetBase** rewritten_input) { - SerializationContext::Params params; std::vector<std::pair<string, Tensor>> 
input_list; - params.input_list = &input_list; - params.external_state_policy = - SerializationContext::ExternalStatePolicy::kIgnore; - params.fail_if_unimplemented = false; - params.serialize_data_tensors = false; - params.preserve_random_seeds = false; - SerializationContext serialization_ctx(params); GraphDef graph_def; - TF_RETURN_IF_ERROR( - AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def)); - string output_node; - for (const auto& node : graph_def.node()) { - if (node.op() == "_Retval") { - output_node = node.input(0); - } - } + TF_RETURN_IF_ERROR( + AsGraphDefMinimal(ctx, input, &input_list, &graph_def, &output_node)); VLOG(3) << "Before graph rewrites: " << graph_def.DebugString(); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/serialization_utils.cc b/tensorflow/core/kernels/data/serialization_utils.cc index 5965c9b3295..628d6952c6d 100644 --- a/tensorflow/core/kernels/data/serialization_utils.cc +++ b/tensorflow/core/kernels/data/serialization_utils.cc @@ -53,6 +53,29 @@ Status FindStatefulOps(const GraphDef& graph_def, } // namespace +Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input, + std::vector<std::pair<string, Tensor>>* input_list, + GraphDef* result, string* dataset_node) { + SerializationContext::Params params; + params.input_list = input_list; + params.external_state_policy = + SerializationContext::ExternalStatePolicy::kIgnore; + params.fail_if_unimplemented = false; + params.serialize_data_tensors = false; + params.preserve_random_seeds = false; + SerializationContext serialization_ctx(params); + TF_RETURN_IF_ERROR( + AsGraphDef(ctx, input, std::move(serialization_ctx), result)); + + // Symbolic `_Retval` node indicates which node corresponds to the dataset. 
+ for (const auto& node : result->node()) { + if (node.op() == "_Retval") { + *dataset_node = node.input(0); + } + } + return Status::OK(); +} + Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, SerializationContext&& serialization_ctx, GraphDef* graph_def) { diff --git a/tensorflow/core/kernels/data/serialization_utils.h b/tensorflow/core/kernels/data/serialization_utils.h index 2e580ec7fdc..5702919b556 100644 --- a/tensorflow/core/kernels/data/serialization_utils.h +++ b/tensorflow/core/kernels/data/serialization_utils.h @@ -27,6 +27,15 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, SerializationContext&& serialization_ctx, GraphDef* graph_def); +// Returns a GraphDef representation of the given dataset using the minimal +// serialization parameters (i.e. ignoring external state, not serializing +// data tensors, not failing if there are datasets which do not have AsGraphDef +// implemented). Sets the `dataset_node` parameter to the dataset's +// node name in the resulting GraphDef. 
+Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input, + std::vector<std::pair<string, Tensor>>* input_list, + GraphDef* result, string* dataset_node); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 5e869a2f0be..2d4b2f43746 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -145,6 +145,11 @@ REGISTER_OP("UncompressElement") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::DatasetIteratorShape); +REGISTER_OP("ComputeBatchSize") + .Input("input_dataset : variant") + .Output("batch_size : int64") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("CSVDataset") .Input("filenames: string") .Input("compression_type: string") diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py index 841c25b6856..c9d0d14dead 100644 --- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py @@ -230,5 +230,59 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): _ = distribute._RebatchDataset(dataset, num_replicas=2) +class ComputeBatchSizeTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeKnown(self): + # When drop_remainder=True, batch size can be inferred from the type spec. 
+ dataset = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True) + dataset = dataset_ops.Dataset.zip((dataset, dataset)) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(4, self.evaluate(batch_size)) + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeKnownAndMismatched(self): + # Return -1 when different components have different batch sizes. + dataset = dataset_ops.Dataset.range(32) + dataset = dataset_ops.Dataset.zip((dataset.batch(4, drop_remainder=True), + dataset.batch(8, drop_remainder=True))) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(-1, self.evaluate(batch_size)) + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeUnknown(self): + dataset = dataset_ops.Dataset.range(32).batch(4) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(4, self.evaluate(batch_size)) + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeWithPassthrough(self): + dataset = dataset_ops.Dataset.range(32).batch(4) + dataset = dataset.take(5) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(4, self.evaluate(batch_size)) + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeWithPassthroughInvalid(self): + dataset = dataset_ops.Dataset.range(32).batch(4) + dataset = dataset.map(lambda x: x + 1) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(-1, self.evaluate(batch_size)) + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeWithZip(self): + dataset = dataset_ops.Dataset.range(32).batch(4) + dataset = dataset_ops.Dataset.zip((dataset, dataset)) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(4, self.evaluate(batch_size)) + + @combinations.generate(test_base.default_test_combinations()) + def testComputeBatchSizeWithZipMismatched(self): + 
dataset = dataset_ops.Dataset.range(32) + dataset = dataset_ops.Dataset.zip((dataset.batch(4), dataset.batch(8))) + batch_size = distribute.compute_batch_size(dataset) + self.assertEqual(-1, self.evaluate(batch_size)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index ae3c13ecc97..9f274201e78 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -20,6 +20,8 @@ from __future__ import print_function from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops @@ -169,4 +171,54 @@ def replicate(dataset, devices): return datasets +def compute_batch_size(dataset): + """An operation that returns the batch size of the dataset. + + This op tries to infer the batch size statically by walking up the dataset + tree from the final dataset node and returning the batch size of the first + batching dataset (such as from .batch() and .padded_batch()) that it + encounters. This differs from using the `element_spec` of a dataset in that it + does not account for partial batches. + + This operation may fail if it encounters contradictory batch sizes (for + example, if the dataset is created by zipping together two datasets with + different batch sizes), if there are no explicit batching transformations, or + if there are operations downstream from the batching transformation that may + modify its batch size. In these cases, it returns a -1. + + Args: + dataset: A `tf.data.Dataset` object. 
+ + Returns: + A `tf.int64` Tensor representing the batch size of the dataset sans partial + batches. If this cannot be inferred statically, the value of this tensor + will be -1. + """ + + def get_static_batch_dim(output_shape): + if not output_shape.rank: # Unknown rank or scalar: no batch dimension. + return None + return output_shape.dims[0].value + + batch_dims = [ + get_static_batch_dim(ts._to_legacy_output_shapes()) # pylint: disable=protected-access + for ts in nest.flatten(dataset_ops.get_structure(dataset)) + ] + + if batch_dims and all(d is not None for d in batch_dims): + + if all(d == batch_dims[0] for d in batch_dims): + # If all batch dimensions are known and equal, return that directly. + batch_dim = batch_dims[0] + else: + # If all batch dimensions are known but not all equal, return -1. + batch_dim = -1 + + return constant_op.constant( + batch_dim, dtype=dtypes.int64, name="static_batch_size") + + # If any batch dimensions are unknown, use compute_batch_size op. + return ged_ops.compute_batch_size(dataset._variant_tensor) # pylint: disable=protected-access + + _AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__ diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 0db6da3dad2..c597bc2f8f1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -792,6 +792,10 @@ tf_module { name: "ComputeAccidentalHits" argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " } + member_method { + name: "ComputeBatchSize" + argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Concat" argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 0db6da3dad2..c597bc2f8f1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -792,6 +792,10 @@ tf_module { name: "ComputeAccidentalHits" argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " } + member_method { + name: "ComputeBatchSize" + argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "Concat" argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "