[tf.data] Add an op that computes the statically-known batch size of a dataset where possible.

PiperOrigin-RevId: 322444756
Change-Id: Ib628e9463bf34e5a607fd023d7c065389fff599c
Rachel Lim authored on 2020-07-21 14:36:45 -07:00; committed by TensorFlower Gardener
parent be88c5f8a7
commit 332adf10ab
15 changed files with 392 additions and 26 deletions

View File

@@ -0,0 +1,5 @@
op {
graph_op_name: "ComputeBatchSize"
visibility: HIDDEN
summary: "Computes the static batch size of a dataset sans partial batches."
}
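
Because the op is registered as hidden, it is only reachable through the generated op wrappers. A minimal sketch of calling it directly via tf.raw_ops (illustrative only, assuming eager execution; the supported entry point is the Python helper added later in this commit):

import tensorflow as tf

ds = tf.data.Dataset.range(32).batch(4)
# The op consumes the dataset's variant tensor and returns an int64 scalar.
batch_size = tf.raw_ops.ComputeBatchSize(input_dataset=ds._variant_tensor)  # pylint: disable=protected-access
print(int(batch_size))  # 4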

View File

@@ -54,6 +54,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:regexp",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
],

View File

@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
@@ -898,5 +899,11 @@ std::string DeterminismPolicy::String() const {
}
}
bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match) {
// Matches all versions of an op by appending an optional version suffix
// (e.g. "V2") to the escaped op prefix.
auto expected_re = strings::StrCat(RE2::QuoteMeta(op_prefix), "(V\\d+)?");
return RE2::FullMatch(op_to_match, expected_re);
}
} // namespace data
} // namespace tensorflow

View File

@@ -296,6 +296,14 @@ class DummyResourceOp : public OpKernel {
}
};
// Given an op prefix and an op to match, returns whether the op to match
// matches any version of the op with that prefix. For example,
// MatchesAnyVersionRE("BatchDataset", "BatchDataset") == true
// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersionRE("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match);
} // namespace data
} // namespace tensorflow
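
For intuition, a rough Python analogue of the matching behavior described above (illustrative only; the shipped implementation is the RE2-based C++ helper):

import re

def matches_any_version(op_prefix, op_to_match):
  # Escape the prefix and allow an optional "V<digits>" suffix, mirroring the
  # "(V\d+)?" pattern appended by the C++ implementation.
  pattern = re.escape(op_prefix) + r"(V\d+)?"
  return re.fullmatch(pattern, op_to_match) is not None

assert matches_any_version("BatchDataset", "BatchDataset")
assert matches_any_version("BatchDataset", "BatchDatasetV2")
assert not matches_any_version("PaddedBatchDataset", "BatchDataset")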

View File

@@ -32,7 +32,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:regexp_internal",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
],
)
@@ -124,6 +124,25 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "compute_batch_size_op",
srcs = ["compute_batch_size_op.cc"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:regexp_internal",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",
"//tensorflow/core/kernels/data:serialization_utils",
],
)
tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
@@ -736,6 +755,7 @@ tf_kernel_library(
":choose_fastest_branch_dataset_op",
":choose_fastest_dataset_op",
":compression_ops",
":compute_batch_size_op",
":csv_dataset_op",
":dense_to_sparse_batch_dataset_op",
":directed_interleave_dataset_op",

View File

@@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
namespace data {
@@ -97,15 +97,12 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
}
int n = tokens.size();
for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
std::string transformation_escaped =
RE2::QuoteMeta(dataset()->transformations_[i]);
std::string version_suffix = "(V\\d+)?";
std::string expected_re =
absl::StrCat(transformation_escaped, version_suffix);
if (!RE2::FullMatch(tokens[n - 2 - i], expected_re)) {
if (!MatchesAnyVersionRE(dataset()->transformations_[i],
tokens[n - 2 - i])) {
return errors::InvalidArgument("Asserted transformation matching ",
expected_re, " at offset ", i,
" but encountered ", tokens[n - 2 - i],
dataset()->transformations_[i],
" at offset ", i, " but encountered ",
tokens[n - 2 - i],
" transformation instead.");
}
}

View File

@@ -0,0 +1,191 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/serialization_utils.h"
#include "tensorflow/core/platform/stringprintf.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
using grappler::graph_utils::GetScalarConstNodeValue;
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
constexpr std::array<const char*, 4> kBatchDatasetOps = {
"BatchDataset",
"PaddedBatchDataset",
kMapAndBatchOp,
kExperimentalMapAndBatchOp,
};
constexpr std::array<const char*, 2> kMultipleInputDatasetOps = {
"ConcatenateDataset",
"ZipDataset",
};
constexpr std::array<const char*, 14> kPassThroughOps = {
"AssertCardinalityDataset",
"CacheDataset",
"FilterDataset",
"Identity",
"ModelDataset",
"OptimizeDataset",
"ParseExampleDataset",
"PrefetchDataset",
"RepeatDataset",
"ShardDataset",
"ShuffleAndRepeatDataset",
"ShuffleDataset",
"SkipDataset",
"TakeDataset",
};
template <std::size_t SIZE>
bool IsDatasetNodeOfType(const NodeDef& node,
const std::array<const char*, SIZE>& arr) {
for (const auto& dataset_op : arr) {
if (MatchesAnyVersionRE(dataset_op, node.op())) return true;
}
return false;
}
const NodeDef* GetInputNode(const NodeDef& node,
const grappler::GraphView& graph,
int64 input_index) {
if (node.input_size() == 0) return nullptr;
grappler::GraphView::InputPort input_port =
graph.GetInputPort(node.name(), input_index);
return graph.GetRegularFanin(input_port).node;
}
// TODO(rachelim): This op traverses the dataset graph using an allowlist-based
// approach. As an alternative, we could instead rewrite all batching datasets'
// drop_remainder parameter to True, then re-run C++ shape inference over the
// dataset graph to derive the new output shapes. This is more robust for
// datasets whose shape inference is implemented in C++. If this allowlist-
// based approach proves hard to maintain, consider switching to the
// alternative.
class ComputeBatchSizeOp : public OpKernel {
public:
explicit ComputeBatchSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
std::vector<std::pair<string, Tensor>> input_list;
GraphDef graph_def;
string dataset_node_name;
OP_REQUIRES_OK(ctx, AsGraphDefMinimal(ctx, dataset, &input_list, &graph_def,
&dataset_node_name));
// Create GraphView for easier traversal of graph.
grappler::GraphView graph_view(&graph_def);
const NodeDef* node = graph_view.GetNode(dataset_node_name);
OP_REQUIRES(ctx, node != nullptr,
errors::InvalidArgument("Node does not exist in graph"));
int64 batch_size = GetBatchSize(*node, graph_view);
Tensor* result;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
result->scalar<int64>()() = batch_size;
}
private:
int64 GetBatchSizeFromBatchNode(const NodeDef& node,
const grappler::GraphView& graph) {
int64 arg_index;
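// MapAndBatchDataset's inputs are (input_dataset, other_arguments...,
// batch_size, num_parallel_calls, drop_remainder), so its batch_size is the
// third-from-last input; for BatchDataset and PaddedBatchDataset, batch_size
// is input 1.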
if (node.op() == kMapAndBatchOp ||
node.op() == kExperimentalMapAndBatchOp) {
arg_index = node.input_size() - 3;
} else {
arg_index = 1;
}
auto batch_size_node = GetInputNode(node, graph, arg_index);
int64 batch_size;
auto s = GetScalarConstNodeValue(*batch_size_node, &batch_size);
if (!s.ok()) {
VLOG(1) << "Could not compute static batch size. Found batching dataset ("
<< node.name() << "), but failed to get its input batch size: "
<< s.error_message();
return -1;
}
return batch_size;
}
// Helper function that returns the static batch size (i.e. the static 0th
// dimension of the output elements) of a given dataset node in the graph. It
// starts from the node and recursively traverses its inputs until it finds a
// valid BatchDataset-type operation, then returns that op's batch size. If
// the batch size cannot be determined, returns -1.
//
// During recursion, it handles four kinds of cases:
// 1. BatchDataset type ops: Returns the value from its batch_size input node.
// 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops,
// which are themselves all datasets, and returns the batch sizes computed
// by the inputs if they are all the same.
// 3. Core dataset ops which cannot change the size of the 0th dimension of
// dataset output elements: Recurses into the first input parameter.
// 4. All other ops: Fail, returning -1 for unknown.
// TODO(rachelim): For FlatMap type mapping dataset ops, recurse into the
// function definition.
int64 GetBatchSize(const NodeDef& node, const grappler::GraphView& graph) {
if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
return GetBatchSizeFromBatchNode(node, graph);
}
if (IsDatasetNodeOfType(node, kMultipleInputDatasetOps)) {
const NodeDef* input_0 = GetInputNode(node, graph, 0);
int64 batch_size_0 = GetBatchSize(*input_0, graph);
for (int i = 1; i < node.input_size(); ++i) {
const NodeDef* input = GetInputNode(node, graph, i);
auto batch_size_i = GetBatchSize(*input, graph);
if (batch_size_i != batch_size_0) {
VLOG(1) << "Could not compute batch size: inputs to " << node.name()
<< " (" << node.op() << ") had different batch sizes."
<< " Namely, input 0 had batch size " << batch_size_0
<< " while input " << i << " had batch size " << batch_size_i
<< ".";
return -1;
}
}
return batch_size_0;
}
if (IsDatasetNodeOfType(node, kPassThroughOps)) {
const NodeDef* input = GetInputNode(node, graph, 0);
return GetBatchSize(*input, graph);
}
VLOG(1) << "Encountered dataset node " << node.name() << " (" << node.op()
<< ") that prevented further static batch size analysis.";
return -1;
}
};
REGISTER_KERNEL_BUILDER(Name("ComputeBatchSize").Device(DEVICE_CPU),
ComputeBatchSizeOp);
} // anonymous namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow

View File

@@ -144,25 +144,11 @@ Status ApplyRewrites(OpKernelContext* ctx,
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
std::function<RewriterConfig(void)> config_factory,
bool record_fingerprint, DatasetBase** rewritten_input) {
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.external_state_policy =
SerializationContext::ExternalStatePolicy::kIgnore;
params.fail_if_unimplemented = false;
params.serialize_data_tensors = false;
params.preserve_random_seeds = false;
SerializationContext serialization_ctx(params);
GraphDef graph_def;
TF_RETURN_IF_ERROR(
AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));
string output_node;
for (const auto& node : graph_def.node()) {
if (node.op() == "_Retval") {
output_node = node.input(0);
}
}
TF_RETURN_IF_ERROR(
AsGraphDefMinimal(ctx, input, &input_list, &graph_def, &output_node));
VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(

View File

@@ -53,6 +53,29 @@ Status FindStatefulOps(const GraphDef& graph_def,
} // namespace
Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
std::vector<std::pair<string, Tensor>>* input_list,
GraphDef* result, string* dataset_node) {
SerializationContext::Params params;
params.input_list = input_list;
params.external_state_policy =
SerializationContext::ExternalStatePolicy::kIgnore;
params.fail_if_unimplemented = false;
params.serialize_data_tensors = false;
params.preserve_random_seeds = false;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
AsGraphDef(ctx, input, std::move(serialization_ctx), result));
// The symbolic `_Retval` node indicates which node in the graph corresponds
// to the dataset.
for (const auto& node : result->node()) {
if (node.op() == "_Retval") {
*dataset_node = node.input(0);
}
}
return Status::OK();
}
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def) {

View File

@@ -27,6 +27,15 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def);
// Returns a GraphDef representation of the given dataset using the minimal
// serialization parameters (i.e. ignoring external state, not serializing
// data tensors, not failing if there are datasets which do not have AsGraphDef
// implemented). Sets the `dataset_node` parameter to the dataset's
// node name in the resulting GraphDef.
Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
std::vector<std::pair<string, Tensor>>* input_list,
GraphDef* result, string* dataset_node);
} // namespace data
} // namespace tensorflow

View File

@@ -145,6 +145,11 @@ REGISTER_OP("UncompressElement")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::DatasetIteratorShape);
REGISTER_OP("ComputeBatchSize")
.Input("input_dataset : variant")
.Output("batch_size : int64")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("CSVDataset")
.Input("filenames: string")
.Input("compression_type: string")

View File

@@ -230,5 +230,59 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
_ = distribute._RebatchDataset(dataset, num_replicas=2)
class ComputeBatchSizeTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeKnown(self):
# When drop_remainder=True, batch size can be inferred from the type spec.
dataset = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
dataset = dataset_ops.Dataset.zip((dataset, dataset))
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(4, self.evaluate(batch_size))
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeKnownAndMismatched(self):
# Return -1 when different components have different batch sizes.
dataset = dataset_ops.Dataset.range(32)
dataset = dataset_ops.Dataset.zip((dataset.batch(4, drop_remainder=True),
dataset.batch(8, drop_remainder=True)))
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(-1, self.evaluate(batch_size))
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeUnknown(self):
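# Without drop_remainder=True, the batch size is not statically known from
# the element spec, so it must be inferred by traversing the dataset graph.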
dataset = dataset_ops.Dataset.range(32).batch(4)
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(4, self.evaluate(batch_size))
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeWithPassthrough(self):
dataset = dataset_ops.Dataset.range(32).batch(4)
dataset = dataset.take(5)
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(4, self.evaluate(batch_size))
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeWithPassthroughInvalid(self):
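# MapDataset is not a pass-through op, since the map function may change
# the size of the batch dimension, so the batch size cannot be inferred.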
dataset = dataset_ops.Dataset.range(32).batch(4)
dataset = dataset.map(lambda x: x + 1)
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(-1, self.evaluate(batch_size))
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeWithZip(self):
dataset = dataset_ops.Dataset.range(32).batch(4)
dataset = dataset_ops.Dataset.zip((dataset, dataset))
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(4, self.evaluate(batch_size))
@combinations.generate(test_base.default_test_combinations())
def testComputeBatchSizeWithZipMismatched(self):
dataset = dataset_ops.Dataset.range(32)
dataset = dataset_ops.Dataset.zip((dataset.batch(4), dataset.batch(8)))
batch_size = distribute.compute_batch_size(dataset)
self.assertEqual(-1, self.evaluate(batch_size))
if __name__ == "__main__":
test.main()

View File

@@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -169,4 +171,54 @@ def replicate(dataset, devices):
return datasets
def compute_batch_size(dataset):
"""An operation that returns the batch size of the dataset.
This op tries to infer the batch size statically by walking up the dataset
tree from the final dataset node and returning the batch size of the first
batching dataset (such as from .batch() and .padded_batch()) that it
encounters. This differs from using the `element_spec` of a dataset in that it
does not account for partial batches.
This operation may fail to infer the batch size if it encounters contradictory
batch sizes (for example, if the dataset is created by zipping together two
datasets with different batch sizes), if there are no explicit batching
transformations, or if there are operations downstream of the batching
transformation that may modify its batch size. In these cases, it returns -1.
Args:
dataset: A `tf.data.Dataset` object.
Returns:
A `tf.int64` Tensor representing the batch size of the dataset sans partial
batches. If this cannot be inferred statically, the value of this tensor
will be -1.
"""
def get_static_batch_dim(output_shape):
if output_shape.rank is None:
return None
return output_shape.dims[0].value
batch_dims = [
get_static_batch_dim(ts._to_legacy_output_shapes()) # pylint: disable=protected-access
for ts in nest.flatten(dataset_ops.get_structure(dataset))
]
if all(d is not None for d in batch_dims):
if all(d == batch_dims[0] for d in batch_dims):
# If all batch dimensions are known and equal, return that directly.
batch_dim = batch_dims[0]
else:
# If all batch dimensions are known but not all equal, return -1.
batch_dim = -1
return constant_op.constant(
batch_dim, dtype=dtypes.int64, name="static_batch_size")
# If any batch dimensions are unknown, fall back to the ComputeBatchSize op.
return ged_ops.compute_batch_size(dataset._variant_tensor) # pylint: disable=protected-access
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
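
A usage sketch of the new helper (illustrative only, mirroring the tests in this change): the inferred size survives pass-through transformations such as take or prefetch, and comes back as -1 when it cannot be determined.

from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(32).batch(4)  # batch size absent from element spec
dataset = dataset.prefetch(2)  # PrefetchDataset is on the pass-through allowlist
batch_size = distribute.compute_batch_size(dataset)  # scalar tf.int64 Tensor; evaluates to 4
# Adding a .map(...) between batch() and compute_batch_size() would yield -1.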

View File

@@ -792,6 +792,10 @@ tf_module {
name: "ComputeAccidentalHits"
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
}
member_method {
name: "ComputeBatchSize"
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Concat"
argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -792,6 +792,10 @@ tf_module {
name: "ComputeAccidentalHits"
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
}
member_method {
name: "ComputeBatchSize"
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Concat"
argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "