[tf.data] Add an op that computes the statically-known batch size of a dataset where possible.
PiperOrigin-RevId: 322444756 Change-Id: Ib628e9463bf34e5a607fd023d7c065389fff599c
This commit is contained in:
parent
be88c5f8a7
commit
332adf10ab
tensorflow
core
api_def/base_api
kernels/data
BUILDdataset_utils.ccdataset_utils.h
experimental
rewrite_utils.ccserialization_utils.ccserialization_utils.hops
python/data/experimental
tools/api/golden
@ -0,0 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "ComputeBatchSize"
|
||||
visibility: HIDDEN
|
||||
summary: "Computes the static batch size of a dataset sans partial batches."
|
||||
}
|
@ -54,6 +54,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:regexp",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -898,5 +899,11 @@ std::string DeterminismPolicy::String() const {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true iff `op_to_match` is exactly `op_prefix`, optionally followed
// by a version suffix of the form `V<digits>` (e.g. "V2").
bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match) {
  // Escape the prefix so any regex metacharacters in it are matched literally,
  // then allow one optional trailing version marker.
  const std::string escaped_prefix = RE2::QuoteMeta(op_prefix);
  const std::string versioned_pattern =
      strings::StrCat(escaped_prefix, "(V\\d+)?");
  // FullMatch requires the whole of `op_to_match` to match the pattern.
  return RE2::FullMatch(op_to_match, versioned_pattern);
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -296,6 +296,14 @@ class DummyResourceOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
// Given an op prefix and an op to match, returns whether the op to match
|
||||
// is a regex match for any version of the op prefix. For example,
|
||||
// MatchesAnyVersionRE("BatchDataset", "BatchDataset") == true
|
||||
// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV2") == true
|
||||
// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV3") == true
|
||||
// MatchesAnyVersionRE("PaddedBatchDataset", "BatchDataset") == false
|
||||
bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -32,7 +32,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
],
|
||||
)
|
||||
@ -124,6 +124,25 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "compute_batch_size_op",
|
||||
srcs = ["compute_batch_size_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler/optimizers/data:graph_utils",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
"//tensorflow/core/kernels/data:serialization_utils",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "csv_dataset_op",
|
||||
srcs = ["csv_dataset_op.cc"],
|
||||
@ -736,6 +755,7 @@ tf_kernel_library(
|
||||
":choose_fastest_branch_dataset_op",
|
||||
":choose_fastest_dataset_op",
|
||||
":compression_ops",
|
||||
":compute_batch_size_op",
|
||||
":csv_dataset_op",
|
||||
":dense_to_sparse_batch_dataset_op",
|
||||
":directed_interleave_dataset_op",
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -97,15 +97,12 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
int n = tokens.size();
|
||||
for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
|
||||
std::string transformation_escaped =
|
||||
RE2::QuoteMeta(dataset()->transformations_[i]);
|
||||
std::string version_suffix = "(V\\d+)?";
|
||||
std::string expected_re =
|
||||
absl::StrCat(transformation_escaped, version_suffix);
|
||||
if (!RE2::FullMatch(tokens[n - 2 - i], expected_re)) {
|
||||
if (!MatchesAnyVersionRE(dataset()->transformations_[i],
|
||||
tokens[n - 2 - i])) {
|
||||
return errors::InvalidArgument("Asserted transformation matching ",
|
||||
expected_re, " at offset ", i,
|
||||
" but encountered ", tokens[n - 2 - i],
|
||||
dataset()->transformations_[i],
|
||||
" at offset ", i, " but encountered ",
|
||||
tokens[n - 2 - i],
|
||||
" transformation instead.");
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,191 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/serialization_utils.h"
|
||||
#include "tensorflow/core/platform/stringprintf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace experimental {
|
||||
namespace {
|
||||
|
||||
using grappler::graph_utils::GetScalarConstNodeValue;
|
||||
|
||||
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
|
||||
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
|
||||
|
||||
constexpr std::array<const char*, 4> kBatchDatasetOps = {
|
||||
"BatchDataset",
|
||||
"PaddedBatchDataset",
|
||||
kMapAndBatchOp,
|
||||
kExperimentalMapAndBatchOp,
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 2> kMultipleInputDatasetOps = {
|
||||
"ConcatenateDataset",
|
||||
"ZipDataset",
|
||||
};
|
||||
|
||||
constexpr std::array<const char*, 14> kPassThroughOps = {
|
||||
"AssertCardinalityDataset",
|
||||
"CacheDataset",
|
||||
"FilterDataset",
|
||||
"Identity",
|
||||
"ModelDataset",
|
||||
"OptimizeDataset",
|
||||
"ParseExampleDataset",
|
||||
"PrefetchDataset",
|
||||
"RepeatDataset",
|
||||
"ShardDataset",
|
||||
"ShuffleAndRepeatDataset",
|
||||
"ShuffleDataset",
|
||||
"SkipDataset",
|
||||
"TakeDataset",
|
||||
};
|
||||
|
||||
template <std::size_t SIZE>
|
||||
bool IsDatasetNodeOfType(const NodeDef& node,
|
||||
const std::array<const char*, SIZE>& arr) {
|
||||
for (const auto& dataset_op : arr) {
|
||||
if (MatchesAnyVersionRE(dataset_op, node.op())) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns the node that feeds input number `input_index` of `node`, looked up
// through `graph`.
//
// Returns nullptr when `input_index` does not refer to an existing input of
// `node`. This covers nodes without inputs as well as negative or too-large
// indices, which callers can produce when they compute the index from
// `node.input_size()` on a node with fewer inputs than expected.
const NodeDef* GetInputNode(const NodeDef& node,
                            const grappler::GraphView& graph,
                            int64 input_index) {
  if (input_index < 0 || input_index >= node.input_size()) return nullptr;
  grappler::GraphView::InputPort input_port =
      graph.GetInputPort(node.name(), input_index);
  return graph.GetRegularFanin(input_port).node;
}
|
||||
|
||||
// TODO(rachelim): This op traverses the dataset graph using an allowlist-based
|
||||
// approach. As an alternative, we could instead rewrite all batching datasets'
|
||||
// drop_remainder parameter to True, then rerun the dataset graph to derive
|
||||
// new output shapes using C++ shape inference. This is more robust in cases
|
||||
// where datasets have shape inference implemented in C++. If this allowlist-
|
||||
// based approach proves hard to maintain, consider doing the alternative.
|
||||
class ComputeBatchSizeOp : public OpKernel {
|
||||
public:
|
||||
explicit ComputeBatchSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
|
||||
std::vector<std::pair<string, Tensor>> input_list;
|
||||
GraphDef graph_def;
|
||||
string dataset_node_name;
|
||||
OP_REQUIRES_OK(ctx, AsGraphDefMinimal(ctx, dataset, &input_list, &graph_def,
|
||||
&dataset_node_name));
|
||||
|
||||
// Create GraphView for easier traversal of graph.
|
||||
grappler::GraphView graph_view(&graph_def);
|
||||
|
||||
const NodeDef* node = graph_view.GetNode(dataset_node_name);
|
||||
OP_REQUIRES(ctx, node != nullptr,
|
||||
errors::InvalidArgument("Node does not exist in graph"));
|
||||
int64 batch_size = GetBatchSize(*node, graph_view);
|
||||
Tensor* result;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
|
||||
result->scalar<int64>()() = batch_size;
|
||||
}
|
||||
|
||||
private:
|
||||
int64 GetBatchSizeFromBatchNode(const NodeDef& node,
|
||||
const grappler::GraphView& graph) {
|
||||
int64 arg_index;
|
||||
if (node.op() == kMapAndBatchOp ||
|
||||
node.op() == kExperimentalMapAndBatchOp) {
|
||||
arg_index = node.input_size() - 3;
|
||||
} else {
|
||||
arg_index = 1;
|
||||
}
|
||||
|
||||
auto batch_size_node = GetInputNode(node, graph, arg_index);
|
||||
int64 batch_size;
|
||||
auto s = GetScalarConstNodeValue(*batch_size_node, &batch_size);
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "Could not compute static batch size. Found batching dataset ("
|
||||
<< node.name() << "), but failed to get its input batch size: "
|
||||
<< s.error_message();
|
||||
return -1;
|
||||
}
|
||||
return batch_size;
|
||||
}
|
||||
|
||||
// Helper function that returns the static 0th dimension of a given dataset
|
||||
// node in the graph. It starts from a node in the graph and recursively
|
||||
// traverses its inputs until it finds a valid BatchDataset operation,
|
||||
// and returns its batch size. If the batch size cannot be determined,
|
||||
// returns -1.
|
||||
//
|
||||
// During recursion, it handles four kinds of cases:
|
||||
// 1. BatchDataset type ops: Returns the value from its batch_size input node.
|
||||
// 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops,
|
||||
// which are themselves all datasets, and returns the batch sizes computed
|
||||
// by the inputs if they are all the same.
|
||||
// 3. Core dataset ops which cannot change the size of the 0th dimension of
|
||||
// dataset output elements: Recurses into the first input parameter.
|
||||
// 4. All other ops: Fail, returning -1 for unknown.
|
||||
// TODO(rachelim): For FlatMap type mapping dataset ops, recurse into the
|
||||
// function definition.
|
||||
int64 GetBatchSize(const NodeDef& node, const grappler::GraphView& graph) {
|
||||
if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
|
||||
return GetBatchSizeFromBatchNode(node, graph);
|
||||
}
|
||||
if (IsDatasetNodeOfType(node, kMultipleInputDatasetOps)) {
|
||||
const NodeDef* input_0 = GetInputNode(node, graph, 0);
|
||||
int64 batch_size_0 = GetBatchSize(*input_0, graph);
|
||||
for (int i = 1; i < node.input_size(); ++i) {
|
||||
const NodeDef* input = GetInputNode(node, graph, i);
|
||||
auto batch_size_i = GetBatchSize(*input, graph);
|
||||
if (batch_size_i != batch_size_0) {
|
||||
VLOG(1) << "Could not compute batch size: inputs to " << node.name()
|
||||
<< " (" << node.op() << ") had different batch sizes."
|
||||
<< " Namely, input 0 had batch size " << batch_size_0
|
||||
<< " while input " << i << " had batch size " << batch_size_i
|
||||
<< ".";
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return batch_size_0;
|
||||
}
|
||||
if (IsDatasetNodeOfType(node, kPassThroughOps)) {
|
||||
const NodeDef* input = GetInputNode(node, graph, 0);
|
||||
return GetBatchSize(*input, graph);
|
||||
}
|
||||
VLOG(1) << "Encountered dataset node " << node.name() << " (" << node.op()
|
||||
<< ") that prevented further static batch size analysis.";
|
||||
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ComputeBatchSize").Device(DEVICE_CPU),
|
||||
ComputeBatchSizeOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace experimental
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -144,25 +144,11 @@ Status ApplyRewrites(OpKernelContext* ctx,
|
||||
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
std::function<RewriterConfig(void)> config_factory,
|
||||
bool record_fingerprint, DatasetBase** rewritten_input) {
|
||||
SerializationContext::Params params;
|
||||
std::vector<std::pair<string, Tensor>> input_list;
|
||||
params.input_list = &input_list;
|
||||
params.external_state_policy =
|
||||
SerializationContext::ExternalStatePolicy::kIgnore;
|
||||
params.fail_if_unimplemented = false;
|
||||
params.serialize_data_tensors = false;
|
||||
params.preserve_random_seeds = false;
|
||||
SerializationContext serialization_ctx(params);
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));
|
||||
|
||||
string output_node;
|
||||
for (const auto& node : graph_def.node()) {
|
||||
if (node.op() == "_Retval") {
|
||||
output_node = node.input(0);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AsGraphDefMinimal(ctx, input, &input_list, &graph_def, &output_node));
|
||||
|
||||
VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -53,6 +53,29 @@ Status FindStatefulOps(const GraphDef& graph_def,
|
||||
|
||||
} // namespace
|
||||
|
||||
// Serializes `input` into `result` with a SerializationContext configured to
// ignore external state, not fail on unimplemented serialization, not
// serialize data tensors and not preserve random seeds. Inputs collected by
// the serialization context are stored through `input_list`. On success,
// `*dataset_node` is set to the first input of the `_Retval` node; it is left
// untouched if the graph contains no such node.
Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
                         std::vector<std::pair<string, Tensor>>* input_list,
                         GraphDef* result, string* dataset_node) {
  SerializationContext::Params minimal_params;
  minimal_params.preserve_random_seeds = false;
  minimal_params.serialize_data_tensors = false;
  minimal_params.fail_if_unimplemented = false;
  minimal_params.external_state_policy =
      SerializationContext::ExternalStatePolicy::kIgnore;
  minimal_params.input_list = input_list;

  const Status serialize_status =
      AsGraphDef(ctx, input, SerializationContext(minimal_params), result);
  if (!serialize_status.ok()) {
    return serialize_status;
  }

  // Symbolic `_Retval` node indicates which node corresponds to the dataset.
  // Every node is visited, so if several `_Retval` nodes exist the last one
  // determines the output.
  for (int i = 0; i < result->node_size(); ++i) {
    const auto& candidate = result->node(i);
    if (candidate.op() != "_Retval") {
      continue;
    }
    *dataset_node = candidate.input(0);
  }
  return Status::OK();
}
|
||||
|
||||
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
|
||||
SerializationContext&& serialization_ctx,
|
||||
GraphDef* graph_def) {
|
||||
|
@ -27,6 +27,15 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
|
||||
SerializationContext&& serialization_ctx,
|
||||
GraphDef* graph_def);
|
||||
|
||||
// Returns a GraphDef representation of the given dataset using the minimal
|
||||
// serialization parameters (i.e. ignoring external state, not serializing
|
||||
// data tensors, not failing if there are datasets which do not have AsGraphDef
|
||||
// implemented). Sets the `dataset_node` parameter to the dataset's
|
||||
// node name in the resulting GraphDef.
|
||||
Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
|
||||
std::vector<std::pair<string, Tensor>>* input_list,
|
||||
GraphDef* result, string* dataset_node);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -145,6 +145,11 @@ REGISTER_OP("UncompressElement")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::DatasetIteratorShape);
|
||||
|
||||
REGISTER_OP("ComputeBatchSize")
|
||||
.Input("input_dataset : variant")
|
||||
.Output("batch_size : int64")
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("CSVDataset")
|
||||
.Input("filenames: string")
|
||||
.Input("compression_type: string")
|
||||
|
@ -230,5 +230,59 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
_ = distribute._RebatchDataset(dataset, num_replicas=2)
|
||||
|
||||
|
||||
class ComputeBatchSizeTest(test_base.DatasetTestBase, parameterized.TestCase):
  # Tests for `distribute.compute_batch_size`.

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeKnown(self):
    # With drop_remainder=True the leading dimension is part of the static
    # shape, so the batch size is available from the type spec.
    batched = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
    zipped = dataset_ops.Dataset.zip((batched, batched))
    result = distribute.compute_batch_size(zipped)
    self.assertEqual(4, self.evaluate(result))

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeKnownAndMismatched(self):
    # Components with different static batch sizes yield -1.
    source = dataset_ops.Dataset.range(32)
    by_four = source.batch(4, drop_remainder=True)
    by_eight = source.batch(8, drop_remainder=True)
    zipped = dataset_ops.Dataset.zip((by_four, by_eight))
    result = distribute.compute_batch_size(zipped)
    self.assertEqual(-1, self.evaluate(result))

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeUnknown(self):
    # Without drop_remainder the batch size is still reported as 4.
    batched = dataset_ops.Dataset.range(32).batch(4)
    result = distribute.compute_batch_size(batched)
    self.assertEqual(4, self.evaluate(result))

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeWithPassthrough(self):
    # `take` after `batch` does not prevent the batch size from being found.
    batched = dataset_ops.Dataset.range(32).batch(4)
    truncated = batched.take(5)
    result = distribute.compute_batch_size(truncated)
    self.assertEqual(4, self.evaluate(result))

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeWithPassthroughInvalid(self):
    # `map` after `batch` makes the result unknown (-1).
    batched = dataset_ops.Dataset.range(32).batch(4)
    mapped = batched.map(lambda x: x + 1)
    result = distribute.compute_batch_size(mapped)
    self.assertEqual(-1, self.evaluate(result))

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeWithZip(self):
    # Zipping datasets that agree on the batch size keeps it.
    batched = dataset_ops.Dataset.range(32).batch(4)
    zipped = dataset_ops.Dataset.zip((batched, batched))
    result = distribute.compute_batch_size(zipped)
    self.assertEqual(4, self.evaluate(result))

  @combinations.generate(test_base.default_test_combinations())
  def testComputeBatchSizeWithZipMismatched(self):
    # Zipping datasets with different batch sizes yields -1.
    source = dataset_ops.Dataset.range(32)
    zipped = dataset_ops.Dataset.zip((source.batch(4), source.batch(8)))
    result = distribute.compute_batch_size(zipped)
    self.assertEqual(-1, self.evaluate(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
|
||||
@ -169,4 +171,54 @@ def replicate(dataset, devices):
|
||||
return datasets
|
||||
|
||||
|
||||
def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  If every component of the dataset's element structure has a statically known
  leading dimension, that dimension is returned directly as a constant (or -1
  if the components disagree).

  Otherwise this op tries to infer the batch size statically by walking up the
  dataset tree from the final dataset node and returning the batch size of the
  first batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in that
  it does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations, or
  if there are operations downstream from the batching transformation that may
  modify its batch size. In these cases, it returns a -1.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
  """

  def get_static_batch_dim(output_shape):
    # Shapes of unknown rank and scalar (rank 0) shapes have no leading
    # dimension to read; indexing `dims[0]` on a rank-0 shape would raise
    # IndexError, so report "unknown" instead.
    if output_shape.rank is None or output_shape.rank == 0:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(ts._to_legacy_output_shapes())  # pylint: disable=protected-access
      for ts in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  # `batch_dims` is empty when the element structure has no components; in
  # that case there is nothing to read statically, so defer to the op rather
  # than indexing `batch_dims[0]`.
  if batch_dims and all(d is not None for d in batch_dims):

    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
|
||||
|
||||
|
||||
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
|
||||
|
@ -792,6 +792,10 @@ tf_module {
|
||||
name: "ComputeAccidentalHits"
|
||||
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ComputeBatchSize"
|
||||
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Concat"
|
||||
argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -792,6 +792,10 @@ tf_module {
|
||||
name: "ComputeAccidentalHits"
|
||||
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ComputeBatchSize"
|
||||
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Concat"
|
||||
argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user