[tf.data] Add an op that computes the statically-known batch size of a dataset where possible.
PiperOrigin-RevId: 322444756
Change-Id: Ib628e9463bf34e5a607fd023d7c065389fff599c
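
For orientation, a minimal usage sketch of the helper this commit adds (hedged: it uses the internal module paths from the Python changes below; compute_batch_size is not exposed as a public symbol by this change):

from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops

# The leading dimension is None in the element spec here (no drop_remainder),
# so the static fast path cannot answer; the new op inspects the dataset graph.
dataset = dataset_ops.Dataset.range(32).batch(4)
batch_size = distribute.compute_batch_size(dataset)  # scalar tf.int64 tensor, == 4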
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "ComputeBatchSize"
+  visibility: HIDDEN
+  summary: "Computes the static batch size of a dataset sans partial batches."
+}
@@ -54,6 +54,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:regexp",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
     ],
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/regexp.h"
 #include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
@@ -898,5 +899,11 @@ std::string DeterminismPolicy::String() const {
   }
 }
 
+bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match) {
+  // Matches all versions of an op by appending an optional version suffix
+  auto expected_re = strings::StrCat(RE2::QuoteMeta(op_prefix), "(V\\d+)?");
+  return RE2::FullMatch(op_to_match, expected_re);
+}
+
 }  // namespace data
 }  // namespace tensorflow
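
For intuition, a rough Python equivalent of MatchesAnyVersionRE (a sketch for this commit description only; the real implementation above uses RE2::QuoteMeta and RE2::FullMatch):

import re

def matches_any_version(op_prefix, op_to_match):
  # Quote the prefix literally, then allow an optional "V<digits>" suffix.
  return re.fullmatch(re.escape(op_prefix) + r"(V\d+)?", op_to_match) is not None

assert matches_any_version("BatchDataset", "BatchDataset")
assert matches_any_version("BatchDataset", "BatchDatasetV2")
assert not matches_any_version("PaddedBatchDataset", "BatchDataset")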
@@ -296,6 +296,14 @@ class DummyResourceOp : public OpKernel {
   }
 };
 
+// Given an op prefix and an op to match, returns whether the op to match
+// is a regex match for any version of the op prefix. For example,
+// MatchesAnyVersionRE("BatchDataset", "BatchDataset") == true
+// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV2") == true
+// MatchesAnyVersionRE("BatchDataset", "BatchDatasetV3") == true
+// MatchesAnyVersionRE("PaddedBatchDataset", "BatchDataset") == false
+bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match);
+
 }  // namespace data
 }  // namespace tensorflow
 
@@ -32,7 +32,7 @@ tf_kernel_library(
     deps = [
         "//tensorflow/core:experimental_dataset_ops_op_lib",
         "//tensorflow/core:framework",
-        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/kernels/data:dataset_utils",
         "//tensorflow/core/kernels/data:name_utils",
     ],
 )
@@ -124,6 +124,25 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "compute_batch_size_op",
+    srcs = ["compute_batch_size_op.cc"],
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler/optimizers/data:graph_utils",
+        "//tensorflow/core/kernels/data:dataset_utils",
+        "//tensorflow/core/kernels/data:name_utils",
+        "//tensorflow/core/kernels/data:serialization_utils",
+    ],
+)
+
 tf_kernel_library(
     name = "csv_dataset_op",
     srcs = ["csv_dataset_op.cc"],
@@ -736,6 +755,7 @@ tf_kernel_library(
         ":choose_fastest_branch_dataset_op",
         ":choose_fastest_dataset_op",
         ":compression_ops",
+        ":compute_batch_size_op",
        ":csv_dataset_op",
        ":dense_to_sparse_batch_dataset_op",
        ":directed_interleave_dataset_op",
@@ -18,8 +18,8 @@ limitations under the License.
 
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/name_utils.h"
-#include "tensorflow/core/platform/regexp.h"
 
 namespace tensorflow {
 namespace data {
@@ -97,15 +97,12 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
       }
       int n = tokens.size();
       for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
-        std::string transformation_escaped =
-            RE2::QuoteMeta(dataset()->transformations_[i]);
-        std::string version_suffix = "(V\\d+)?";
-        std::string expected_re =
-            absl::StrCat(transformation_escaped, version_suffix);
-        if (!RE2::FullMatch(tokens[n - 2 - i], expected_re)) {
+        if (!MatchesAnyVersionRE(dataset()->transformations_[i],
+                                 tokens[n - 2 - i])) {
           return errors::InvalidArgument("Asserted transformation matching ",
-                                         expected_re, " at offset ", i,
-                                         " but encountered ", tokens[n - 2 - i],
+                                         dataset()->transformations_[i],
+                                         " at offset ", i, " but encountered ",
+                                         tokens[n - 2 - i],
                                          " transformation instead.");
         }
       }
@@ -0,0 +1,191 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
+#include "tensorflow/core/kernels/data/serialization_utils.h"
+#include "tensorflow/core/platform/stringprintf.h"
+
+namespace tensorflow {
+namespace data {
+namespace experimental {
+namespace {
+
+using grappler::graph_utils::GetScalarConstNodeValue;
+
+constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
+constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
+
+constexpr std::array<const char*, 4> kBatchDatasetOps = {
+    "BatchDataset",
+    "PaddedBatchDataset",
+    kMapAndBatchOp,
+    kExperimentalMapAndBatchOp,
+};
+
+constexpr std::array<const char*, 2> kMultipleInputDatasetOps = {
+    "ConcatenateDataset",
+    "ZipDataset",
+};
+
+constexpr std::array<const char*, 14> kPassThroughOps = {
+    "AssertCardinalityDataset",
+    "CacheDataset",
+    "FilterDataset",
+    "Identity",
+    "ModelDataset",
+    "OptimizeDataset",
+    "ParseExampleDataset",
+    "PrefetchDataset",
+    "RepeatDataset",
+    "ShardDataset",
+    "ShuffleAndRepeatDataset",
+    "ShuffleDataset",
+    "SkipDataset",
+    "TakeDataset",
+};
+
+template <std::size_t SIZE>
+bool IsDatasetNodeOfType(const NodeDef& node,
+                         const std::array<const char*, SIZE>& arr) {
+  for (const auto& dataset_op : arr) {
+    if (MatchesAnyVersionRE(dataset_op, node.op())) return true;
+  }
+  return false;
+}
+
+const NodeDef* GetInputNode(const NodeDef& node,
+                            const grappler::GraphView& graph,
+                            int64 input_index) {
+  if (node.input_size() == 0) return nullptr;
+  grappler::GraphView::InputPort input_port =
+      graph.GetInputPort(node.name(), input_index);
+  return graph.GetRegularFanin(input_port).node;
+}
+
+// TODO(rachelim): This op traverses the dataset graph using an allowlist-based
+// approach. As an alternative, we could instead rewrite all batching datasets'
+// drop_remainder parameter to True, then rerun the dataset graph to derive
+// new output shapes using C++ shape inference. This is more robust in cases
+// where datasets have shape inference implemented in C++. If this allowlist-
+// based approach proves hard to maintain, consider doing the alternative.
+class ComputeBatchSizeOp : public OpKernel {
+ public:
+  explicit ComputeBatchSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    DatasetBase* dataset;
+    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+    std::vector<std::pair<string, Tensor>> input_list;
+    GraphDef graph_def;
+    string dataset_node_name;
+    OP_REQUIRES_OK(ctx, AsGraphDefMinimal(ctx, dataset, &input_list, &graph_def,
+                                          &dataset_node_name));
+
+    // Create GraphView for easier traversal of graph.
+    grappler::GraphView graph_view(&graph_def);
+
+    const NodeDef* node = graph_view.GetNode(dataset_node_name);
+    OP_REQUIRES(ctx, node != nullptr,
+                errors::InvalidArgument("Node does not exist in graph"));
+    int64 batch_size = GetBatchSize(*node, graph_view);
+    Tensor* result;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
+    result->scalar<int64>()() = batch_size;
+  }
+
+ private:
+  int64 GetBatchSizeFromBatchNode(const NodeDef& node,
+                                  const grappler::GraphView& graph) {
+    int64 arg_index;
+    if (node.op() == kMapAndBatchOp ||
+        node.op() == kExperimentalMapAndBatchOp) {
+      arg_index = node.input_size() - 3;
+    } else {
+      arg_index = 1;
+    }
+
+    auto batch_size_node = GetInputNode(node, graph, arg_index);
+    int64 batch_size;
+    auto s = GetScalarConstNodeValue(*batch_size_node, &batch_size);
+    if (!s.ok()) {
+      VLOG(1) << "Could not compute static batch size. Found batching dataset ("
+              << node.name() << "), but failed to get its input batch size: "
+              << s.error_message();
+      return -1;
+    }
+    return batch_size;
+  }
+
+  // Helper function that returns the static 0th dimension of a given dataset
+  // node in the graph. It starts from a node in the graph and recursively
+  // traverses its inputs until it finds a valid BatchDataset operation,
+  // and returns its batch size. If the batch size cannot be determined,
+  // returns -1.
+  //
+  // During recursion, it handles four kinds of cases:
+  // 1. BatchDataset type ops: Returns the value from its batch_size input node.
+  // 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops,
+  //    which are themselves all datasets, and returns the batch sizes computed
+  //    by the inputs if they are all the same.
+  // 3. Core dataset ops which cannot change the size of the 0th dimension of
+  //    dataset output elements: Recurses into the first input parameter.
+  // 4. All other ops: Fail, returning -1 for unknown.
+  // TODO(rachelim): For FlatMap type mapping dataset ops, recurse into the
+  // function definition.
+  int64 GetBatchSize(const NodeDef& node, const grappler::GraphView& graph) {
+    if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
+      return GetBatchSizeFromBatchNode(node, graph);
+    }
+    if (IsDatasetNodeOfType(node, kMultipleInputDatasetOps)) {
+      const NodeDef* input_0 = GetInputNode(node, graph, 0);
+      int64 batch_size_0 = GetBatchSize(*input_0, graph);
+      for (int i = 1; i < node.input_size(); ++i) {
+        const NodeDef* input = GetInputNode(node, graph, i);
+        auto batch_size_i = GetBatchSize(*input, graph);
+        if (batch_size_i != batch_size_0) {
+          VLOG(1) << "Could not compute batch size: inputs to " << node.name()
+                  << " (" << node.op() << ") had different batch sizes."
+                  << " Namely, input 0 had batch size " << batch_size_0
+                  << " while input " << i << " had batch size " << batch_size_i
+                  << ".";
+          return -1;
+        }
+      }
+      return batch_size_0;
+    }
+    if (IsDatasetNodeOfType(node, kPassThroughOps)) {
+      const NodeDef* input = GetInputNode(node, graph, 0);
+      return GetBatchSize(*input, graph);
+    }
+    VLOG(1) << "Encountered dataset node " << node.name() << " (" << node.op()
+            << ") that prevented further static batch size analysis.";
+
+    return -1;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ComputeBatchSize").Device(DEVICE_CPU),
+                        ComputeBatchSizeOp);
+
+}  // anonymous namespace
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
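
The kernel's recursive walk, restated as a Python sketch for exposition; Node, inputs, and const_batch_size are hypothetical stand-ins for NodeDef, GraphView fanins, and GetScalarConstNodeValue, not part of the commit:

BATCH_OPS = ("BatchDataset", "PaddedBatchDataset", "MapAndBatchDataset")
MULTI_INPUT_OPS = ("ConcatenateDataset", "ZipDataset")
PASS_THROUGH_OPS = ("PrefetchDataset", "RepeatDataset", "TakeDataset")  # etc.

def get_batch_size(node):
  """Mirrors ComputeBatchSizeOp::GetBatchSize; -1 means 'unknown'."""
  if node.op in BATCH_OPS:
    # Case 1: read the batch_size argument if it is a scalar constant.
    size = node.const_batch_size()  # hypothetical; returns None if not a Const
    return -1 if size is None else size
  if node.op in MULTI_INPUT_OPS:
    # Case 2: every input dataset must agree on the batch size.
    sizes = {get_batch_size(inp) for inp in node.inputs}
    return sizes.pop() if len(sizes) == 1 else -1
  if node.op in PASS_THROUGH_OPS:
    # Case 3: ops that cannot change the 0th dimension of output elements.
    return get_batch_size(node.inputs[0])
  return -1  # Case 4: any other op defeats the analysis.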
@@ -144,25 +144,11 @@ Status ApplyRewrites(OpKernelContext* ctx,
 Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                       std::function<RewriterConfig(void)> config_factory,
                       bool record_fingerprint, DatasetBase** rewritten_input) {
-  SerializationContext::Params params;
   std::vector<std::pair<string, Tensor>> input_list;
-  params.input_list = &input_list;
-  params.external_state_policy =
-      SerializationContext::ExternalStatePolicy::kIgnore;
-  params.fail_if_unimplemented = false;
-  params.serialize_data_tensors = false;
-  params.preserve_random_seeds = false;
-  SerializationContext serialization_ctx(params);
   GraphDef graph_def;
-  TF_RETURN_IF_ERROR(
-      AsGraphDef(ctx, input, std::move(serialization_ctx), &graph_def));
-
   string output_node;
-  for (const auto& node : graph_def.node()) {
-    if (node.op() == "_Retval") {
-      output_node = node.input(0);
-    }
-  }
+  TF_RETURN_IF_ERROR(
+      AsGraphDefMinimal(ctx, input, &input_list, &graph_def, &output_node));
 
   VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
   TF_RETURN_IF_ERROR(
@@ -53,6 +53,29 @@ Status FindStatefulOps(const GraphDef& graph_def,
 
 }  // namespace
 
+Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
+                         std::vector<std::pair<string, Tensor>>* input_list,
+                         GraphDef* result, string* dataset_node) {
+  SerializationContext::Params params;
+  params.input_list = input_list;
+  params.external_state_policy =
+      SerializationContext::ExternalStatePolicy::kIgnore;
+  params.fail_if_unimplemented = false;
+  params.serialize_data_tensors = false;
+  params.preserve_random_seeds = false;
+  SerializationContext serialization_ctx(params);
+  TF_RETURN_IF_ERROR(
+      AsGraphDef(ctx, input, std::move(serialization_ctx), result));
+
+  // Symbolic `_Retval` node indicates which node corresponds to the dataset.
+  for (const auto& node : result->node()) {
+    if (node.op() == "_Retval") {
+      *dataset_node = node.input(0);
+    }
+  }
+  return Status::OK();
+}
+
 Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                   SerializationContext&& serialization_ctx,
                   GraphDef* graph_def) {
@@ -27,6 +27,15 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                   SerializationContext&& serialization_ctx,
                   GraphDef* graph_def);
 
+// Returns a GraphDef representation of the given dataset using the minimal
+// serialization parameters (i.e. ignoring external state, not serializing
+// data tensors, not failing if there are datasets which do not have AsGraphDef
+// implemented). Sets the `dataset_node` parameter to the dataset's
+// node name in the resulting GraphDef.
+Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
+                         std::vector<std::pair<string, Tensor>>* input_list,
+                         GraphDef* result, string* dataset_node);
+
 }  // namespace data
 }  // namespace tensorflow
 
@@ -145,6 +145,11 @@ REGISTER_OP("UncompressElement")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::DatasetIteratorShape);
 
+REGISTER_OP("ComputeBatchSize")
+    .Input("input_dataset: variant")
+    .Output("batch_size: int64")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("CSVDataset")
     .Input("filenames: string")
     .Input("compression_type: string")
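
The registered op is reachable from Python through the generated wrapper, as the library change further below does (a sketch; _variant_tensor is an internal attribute):

import tensorflow as tf
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops

dataset = tf.data.Dataset.range(32).batch(4)
# The op consumes the dataset's variant tensor and produces a scalar int64.
batch_size = ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access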
@@ -230,5 +230,59 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
       _ = distribute._RebatchDataset(dataset, num_replicas=2)
 
 
+class ComputeBatchSizeTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeKnown(self):
+    # When drop_remainder=True, batch size can be inferred from the type spec.
+    dataset = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
+    dataset = dataset_ops.Dataset.zip((dataset, dataset))
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(4, self.evaluate(batch_size))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeKnownAndMismatched(self):
+    # Return -1 when different components have different batch sizes.
+    dataset = dataset_ops.Dataset.range(32)
+    dataset = dataset_ops.Dataset.zip((dataset.batch(4, drop_remainder=True),
+                                       dataset.batch(8, drop_remainder=True)))
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(-1, self.evaluate(batch_size))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeUnknown(self):
+    dataset = dataset_ops.Dataset.range(32).batch(4)
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(4, self.evaluate(batch_size))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeWithPassthrough(self):
+    dataset = dataset_ops.Dataset.range(32).batch(4)
+    dataset = dataset.take(5)
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(4, self.evaluate(batch_size))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeWithPassthroughInvalid(self):
+    dataset = dataset_ops.Dataset.range(32).batch(4)
+    dataset = dataset.map(lambda x: x + 1)
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(-1, self.evaluate(batch_size))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeWithZip(self):
+    dataset = dataset_ops.Dataset.range(32).batch(4)
+    dataset = dataset_ops.Dataset.zip((dataset, dataset))
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(4, self.evaluate(batch_size))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testComputeBatchSizeWithZipMismatched(self):
+    dataset = dataset_ops.Dataset.range(32)
+    dataset = dataset_ops.Dataset.zip((dataset.batch(4), dataset.batch(8)))
+    batch_size = distribute.compute_batch_size(dataset)
+    self.assertEqual(-1, self.evaluate(batch_size))
+
+
 if __name__ == "__main__":
   test.main()
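
Why testComputeBatchSizeUnknown exercises the op rather than the static fast path: without drop_remainder=True the element spec carries no batch size. A quick illustration:

import tensorflow as tf

print(tf.data.Dataset.range(32).batch(4).element_spec)
# TensorSpec(shape=(None,), dtype=tf.int64, name=None)
print(tf.data.Dataset.range(32).batch(4, drop_remainder=True).element_spec)
# TensorSpec(shape=(4,), dtype=tf.int64, name=None)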
@@ -20,6 +20,8 @@ from __future__ import print_function
 from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -169,4 +171,54 @@ def replicate(dataset, devices):
   return datasets
 
 
+def compute_batch_size(dataset):
+  """An operation that returns the batch size of the dataset.
+
+  This op tries to infer the batch size statically by walking up the dataset
+  tree from the final dataset node and returning the batch size of the first
+  batching dataset (such as from .batch() and .padded_batch()) that it
+  encounters. This differs from using the `element_spec` of a dataset in that
+  it does not account for partial batches.
+
+  This operation may fail if it encounters contradictory batch sizes (for
+  example, if the dataset is created by zipping together two datasets with
+  different batch sizes), if there are no explicit batching transformations,
+  or if there are operations downstream from the batching transformation that
+  may modify its batch size. In these cases, it returns -1.
+
+  Args:
+    dataset: A `tf.data.Dataset` object.
+
+  Returns:
+    A `tf.int64` Tensor representing the batch size of the dataset sans
+    partial batches. If this cannot be inferred statically, the value of this
+    tensor will be -1.
+  """
+
+  def get_static_batch_dim(output_shape):
+    if output_shape.rank is None:
+      return None
+    return output_shape.dims[0].value
+
+  batch_dims = [
+      get_static_batch_dim(ts._to_legacy_output_shapes())  # pylint: disable=protected-access
+      for ts in nest.flatten(dataset_ops.get_structure(dataset))
+  ]
+
+  if all(d is not None for d in batch_dims):
+
+    if all(d == batch_dims[0] for d in batch_dims):
+      # If all batch dimensions are known and equal, return that directly.
+      batch_dim = batch_dims[0]
+    else:
+      # If all batch dimensions are known but not all equal, return -1.
+      batch_dim = -1
+
+    return constant_op.constant(
+        batch_dim, dtype=dtypes.int64, name="static_batch_size")
+
+  # If any batch dimensions are unknown, use compute_batch_size op.
+  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
+
+
 _AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
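
The two code paths of compute_batch_size side by side (a sketch using the same internal modules as above):

from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops

# Static path: all component shapes have a known, equal leading dimension,
# so a tf.constant is returned without running the ComputeBatchSize op.
ds = dataset_ops.Dataset.range(32).batch(4, drop_remainder=True)
print(distribute.compute_batch_size(ds))  # ==> 4

# Fallback path: the leading dimension is unknown, so the op walks the
# serialized dataset graph and still recovers 4.
ds = dataset_ops.Dataset.range(32).batch(4)
print(distribute.compute_batch_size(ds))  # ==> 4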
@@ -792,6 +792,10 @@ tf_module {
     name: "ComputeAccidentalHits"
     argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "ComputeBatchSize"
+    argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Concat"
     argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -792,6 +792,10 @@ tf_module {
     name: "ComputeAccidentalHits"
     argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "ComputeBatchSize"
+    argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Concat"
     argspec: "args=[\'concat_dim\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "