[tf.data] Change the behavior of RebatchDataset when 1) drop_remainder = True, or 2) the global batch size is not divisible by the number of workers. In these cases, instead of mutating the batch size directly, we add a
`.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).batch(new_batch_size))` after the batch. This has three effects: 1) it changes the behavior of _RebatchDataset so that at each step (num_workers minibatches), the total number of examples is the same as the global batch size (vs. before, when it was rounded up whenever global_batch_size was not divisible by num_workers); 2) it preserves the behavior of `drop_remainder` with respect to the global batch; 3) it is probably less performant, since from_tensor_slices and batch both require data copies.

PiperOrigin-RevId: 261233882
parent 166c742245
commit ef9f0e8f2f
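For illustration, the new behavior is roughly equivalent to the following sketch in terms of public tf.data APIs (a simplification for a single-component dataset; the actual change is a Grappler graph rewrite, and the helper name `rebatch` here is hypothetical):

    import tensorflow as tf

    def rebatch(dataset, global_batch_size, num_workers, drop_remainder=False):
      # Minibatch size is the global batch size divided by num_workers,
      # rounded up: (global_batch_size + num_workers - 1) // num_workers.
      minibatch_size = -(-global_batch_size // num_workers)
      batched = dataset.batch(global_batch_size, drop_remainder=drop_remainder)
      # Re-slice each global batch into minibatches whose sizes sum exactly to
      # the size of that global batch; the last minibatch of a step may be
      # smaller when the division is uneven.
      return batched.flat_map(
          lambda batch: tf.data.Dataset.from_tensor_slices(batch).batch(
              minibatch_size, drop_remainder=False))

    # e.g. global_batch_size=32, num_workers=5 yields minibatches of sizes
    # 7, 7, 7, 7, 4 per step, which sum to the global batch size of 32.
    ds = rebatch(tf.data.Dataset.range(1024), 32, 5)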
@@ -88,18 +88,27 @@ void ReplaceReferences(const string& from, const string& to,
 
 void AddFunctionOutputWithUniqueName(StringPiece prefix,
                                      StringPiece output_tensor_name,
-                                     FunctionDef* function, DataType dt) {
+                                     FunctionDef* fdef, DataType dtype) {
   string name = string(prefix);
-  int id = function->signature().output_arg_size();
-  while (ContainsFunctionOutputWithName(name, *function)) {
+  int id = fdef->signature().output_arg_size();
+  while (ContainsFunctionOutputWithName(name, *fdef)) {
     name = strings::StrCat(prefix, "/_", id);
     ++id;
   }
-  auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+  auto* output = fdef->mutable_signature()->mutable_output_arg()->Add();
   output->set_name(name);
-  output->set_type(dt);
+  output->set_type(dtype);
 
-  (*function->mutable_ret())[name] = string(output_tensor_name);
+  (*fdef->mutable_ret())[name] = string(output_tensor_name);
+}
+
+OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
+                               DataType dtype) {
+  auto* input_arg = fdef->mutable_signature()->mutable_input_arg()->Add();
+  input_arg->set_type(dtype);
+  input_arg->set_name(name);
+
+  return input_arg;
 }
 
 NodeDef* AddNode(StringPiece name, StringPiece op,
@@ -61,7 +61,11 @@ void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
 // is unique, and maps to output_tensor_name in the ret dict.
 void AddFunctionOutputWithUniqueName(StringPiece prefix,
                                      StringPiece output_tensor_name,
-                                     FunctionDef* function, DataType dt);
+                                     FunctionDef* fdef, DataType dtype);
 
+// Adds an input to a FunctionDef.
+OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef,
+                               DataType dtype);
+
 // Adds a node to a FunctionDef.
 NodeDef* AddNode(StringPiece name, StringPiece op,
@@ -60,6 +60,18 @@ TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
   EXPECT_EQ(function.ret().at("y/_1"), "two");
 }
 
+TEST(FunctionUtilsTest, AddFunctionInput) {
+  FunctionDef fdef;
+  auto arg0 = AddFunctionInput("arg0", &fdef, DT_INT32);
+  auto arg1 = AddFunctionInput("arg1", &fdef, DT_BOOL);
+  EXPECT_EQ(fdef.signature().input_arg().data()[0], arg0);
+  EXPECT_EQ(arg0->name(), "arg0");
+  EXPECT_EQ(arg0->type(), DT_INT32);
+  EXPECT_EQ(fdef.signature().input_arg().data()[1], arg1);
+  EXPECT_EQ(arg1->name(), "arg1");
+  EXPECT_EQ(arg1->type(), DT_BOOL);
+}
+
 TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
   FunctionDef function = test::function::XTimesTwo();
   EXPECT_FALSE(ContainsFunctionNodeWithName(
@@ -158,6 +158,46 @@ NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
       graph);
 }
 
+Status GetScalarConstNodeValueHelper(
+    const NodeDef& node, DataType dtype,
+    const std::function<void(const Tensor&)>& get_value) {
+  if (node.op() != kConstOpName)
+    return errors::InvalidArgument("Node ", node.name(),
+                                   " is not a Const node. Op: ", node.op());
+
+  Tensor tensor;
+  TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor));
+  if (!TensorShapeUtils::IsScalar(tensor.shape())) {
+    return errors::InvalidArgument(
+        "Node ", node.name(),
+        " should be a scalar but has shape: ", tensor.shape());
+  }
+
+  if (tensor.dtype() != dtype) {
+    return errors::InvalidArgument(
+        "Node ", node.name(), " should have type ", DataTypeString(dtype),
+        " but has type: ", DataTypeString(tensor.dtype()));
+  }
+
+  get_value(tensor);
+
+  return Status::OK();
+}
+
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, int64* value) {
+  return GetScalarConstNodeValueHelper(
+      node, DT_INT64,
+      [value](const Tensor& tensor) { *value = tensor.scalar<int64>()(); });
+}
+
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, bool* value) {
+  return GetScalarConstNodeValueHelper(
+      node, DT_BOOL,
+      [value](const Tensor& tensor) { *value = tensor.scalar<bool>()(); });
+}
+
 bool Compare(const GraphDef& g1, const GraphDef& g2) {
   if (g1.node_size() != g2.node_size()) {
     return false;
@@ -240,12 +280,12 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
   return graph.GetRegularFanin(input_port).node;
 }
 
-Status GetDatasetOutputTypesAttr(const NodeDef& node, AttrValue* output_types) {
+Status GetDatasetOutputTypesAttr(const NodeDef& node,
+                                 DataTypeVector* output_types) {
   // We don't name the output_types attr consistently, so should check for both.
   for (const string& attr_name : {"output_types", "Toutput_types"}) {
     if (node.attr().contains(attr_name)) {
-      *output_types = node.attr().at(attr_name);
-      return Status::OK();
+      return GetNodeAttr(node, attr_name, output_types);
     }
   }
   return errors::InvalidArgument("Could not find output_types attr for node: ",
@@ -80,6 +80,21 @@ NodeDef* AddScalarConstNode(int64 v, MutableGraphView* graph);
 template <>
 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph);
 
+// Retrieves the value of a const node. Returns an error
+// if the node is not const, or its value is of a different type.
+template <typename T>
+Status GetScalarConstNodeValue(const NodeDef& node, T* value) {
+  // is_same is an idiomatic hack for making it compile if not instantiated.
+  // Replacing with false will result in a compile-time error.
+  static_assert(!std::is_same<T, T>::value,
+                "Invalid specialization of this method for type T.");
+}
+
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, int64* value);
+template <>
+Status GetScalarConstNodeValue(const NodeDef& node, bool* value);
+
 // Checks whether the two graphs are the same.
 bool Compare(const GraphDef& g1, const GraphDef& g2);
 
@@ -114,7 +129,8 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                       int64 i);
 
 // Gets the attr corresponding to a dataset node's output types, if it exists.
-Status GetDatasetOutputTypesAttr(const NodeDef& node, AttrValue* output_types);
+Status GetDatasetOutputTypesAttr(const NodeDef& node,
+                                 DataTypeVector* output_types);
 
 // Returns the list of indices of all nodes with the given op or empty list if
 // no such node exists.
@@ -85,6 +85,64 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) {
   EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
 }
 
+TEST(GraphUtilsTest, GetScalarConstNodeInt64) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* int64_node = AddScalarConstNode<int64>(128, &graph);
+  int64 result;
+  EXPECT_TRUE(GetScalarConstNodeValue<int64>(*int64_node, &result).ok());
+  EXPECT_EQ(result, 128);
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeBool) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
+  bool result;
+  EXPECT_TRUE(GetScalarConstNodeValue<bool>(*bool_node, &result).ok());
+  EXPECT_EQ(result, true);
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* non_const = AddScalarPlaceholder(DT_INT64, &graph);
+  int64 result;
+  Status s = GetScalarConstNodeValue<int64>(*non_const, &result);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_message(),
+            "Node Placeholder is not a Const node. Op: Placeholder");
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) {
+  GraphDef graph_def;
+  MutableGraphView graph(&graph_def);
+  NodeDef* int64_node = AddScalarConstNode<int64>(128, &graph);
+  bool result;
+  Status s = GetScalarConstNodeValue<bool>(*int64_node, &result);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_message(),
+            "Node Const should have type bool but has type: int64");
+}
+
+TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) {
+  NodeDef node;
+  node.set_name("Const");
+  node.set_op("Const");
+
+  (*node.mutable_attr())["dtype"].set_type(DT_INT64);
+  auto tensor = (*node.mutable_attr())["value"].mutable_tensor();
+  tensor->set_dtype(DT_INT64);
+  tensor->mutable_tensor_shape()->mutable_dim()->Add()->set_size(1);
+  tensor->add_int64_val(128);
+
+  int64 result;
+  Status s = GetScalarConstNodeValue<int64>(node, &result);
+  EXPECT_FALSE(s.ok());
+  EXPECT_EQ(s.error_message(),
+            "Node Const should be a scalar but has shape: [1]");
+}
+
 TEST(GraphUtilsTest, Compare) {
   GraphDef graph_def_a;
   MutableGraphView graph_a(&graph_def_a);
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils/functions.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -50,14 +51,19 @@ constexpr char kConstOp[] = "Const";
 constexpr char kIdentityOp[] = "Identity";
 constexpr char kSubOp[] = "Sub";
 constexpr char kTruncateDivOp[] = "TruncateDiv";
+constexpr char kOutputShapesAttr[] = "output_shapes";
+constexpr char kOutputTypesAttr[] = "output_types";
+constexpr char kTOutputTypesAttr[] = "Toutput_types";
+constexpr char kBatchOp[] = "BatchDataset";
+constexpr char kBatchV2Op[] = "BatchDatasetV2";
+constexpr char kPaddedBatchOp[] = "PaddedBatchDataset";
+constexpr char kPaddedBatchV2Op[] = "PaddedBatchDatasetV2";
+constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
+constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset";
 
 constexpr std::array<const char*, 6> kBatchDatasetOps = {
-    "BatchDataset",
-    "BatchDatasetV2",
-    "ExperimentalMapAndBatchDataset",
-    "MapAndBatchDataset",
-    "PaddedBatchDataset",
-    "PaddedBatchDatasetV2"};
+    kBatchOp, kBatchV2Op, kMapAndBatchOp, kExperimentalMapAndBatchOp,
+    kPaddedBatchOp, kPaddedBatchV2Op};
 
 constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ConcatenateDataset",
@@ -117,17 +123,24 @@ constexpr std::array<const char*, 9> kSourceDatasetOps = {
     "TFRecordDataset",
 };
 
-NodeDef* AddBinaryNode(const string& input_x, const string& input_y,
-                       const string& op, DataType type,
-                       MutableGraphView* graph) {
+NodeDef MakeBinaryNode(const string& input_x, const string& input_y,
+                       const string& op, DataType dtype) {
   NodeDef node;
   node.set_op(op);
   node.add_input(input_x);
   node.add_input(input_y);
-  graph_utils::SetUniqueGraphNodeName(op, graph->graph(), &node);
-  AddNodeAttr("T", type, &node);
+  AddNodeAttr("T", dtype, &node);
 
-  return graph->AddNode(std::move(node));
+  return node;
+}
+
+NodeDef* AddBinaryNode(const string& input_x, const string& input_y,
+                       const string& op, DataType type, FunctionDef* fdef) {
+  NodeDef* node = fdef->add_node_def();
+  *node = MakeBinaryNode(input_x, input_y, op, type);
+  function_utils::SetUniqueFunctionNodeName(op, fdef, node);
+
+  return node;
 }
 
 // Adds a Const node to the FunctionDef.
@@ -161,6 +174,30 @@ Status AddConstIntNode(gtl::ArraySlice<int32> values, const TensorShape& shape,
   return Status::OK();
 }
 
+Status AddConstInt64Node(int64 value, FunctionDef* fdef, NodeDef** result) {
+  *result = fdef->add_node_def();
+  Tensor t(value);
+  TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const")
+                         .Attr("dtype", DT_INT64)
+                         .Attr("value", t)
+                         .Finalize(*result));
+  function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result);
+
+  return Status::OK();
+}
+
+Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) {
+  *result = fdef->add_node_def();
+  Tensor t(value);
+  TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const")
+                         .Attr("dtype", DT_BOOL)
+                         .Attr("value", t)
+                         .Finalize(*result));
+  function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result);
+
+  return Status::OK();
+}
+
 Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef,
                     NodeDef** result) {
   *result = fdef->add_node_def();
@@ -271,58 +308,69 @@ Status GetBatchDim(AttrValue output_shapes, int* batch_dim) {
 Status UpdateOutputShapes(const string& node_name, int64 num_workers,
                           MutableGraphView* graph) {
   NodeDef* node = graph->GetNode(node_name);
-  if (node->attr().contains("output_shapes")) {
-    AttrValue output_shapes = node->attr().at("output_shapes");
+  if (node->attr().contains(kOutputShapesAttr)) {
+    AttrValue output_shapes = node->attr().at(kOutputShapesAttr);
     for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) {
       if (!shape.unknown_rank() && shape.dim(0).size() != -1) {
         shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_workers);
       }
     }
-    (*node->mutable_attr())["output_shapes"] = output_shapes;
+    (*node->mutable_attr())[kOutputShapesAttr] = output_shapes;
   }
   return Status::OK();
 }
 
+// Helper function to get the index of the batch_size input for a given batch
+// node.
+int64 GetBatchSizeArgIndex(const NodeDef& batch_node) {
+  if (batch_node.op() == kExperimentalMapAndBatchOp ||
+      batch_node.op() == kMapAndBatchOp) {
+    // For MapAndBatch we take the 3rd last input.
+    return batch_node.input_size() - 3;
+  }
+  // For all the batching datasets the batch_size is input number 1 except for
+  // MapAndBatchDataset.
+  return 1;
+}
+
+Status MakeNewBatchSizeNode(const string& global_batch_size_name,
+                            int64 num_workers, FunctionDef* fdef,
+                            NodeDef** result) {
+  NodeDef* one_node;
+  TF_RETURN_IF_ERROR(AddConstInt64Node(1, fdef, &one_node));
+  NodeDef* num_workers_node;
+  TF_RETURN_IF_ERROR(AddConstInt64Node(num_workers, fdef, &num_workers_node));
+
+  NodeDef* numerator_node =
+      AddBinaryNode(global_batch_size_name,
+                    strings::StrCat(num_workers_node->name(), ":output:0"),
+                    kAddOp, DT_INT64, fdef);
+  numerator_node = AddBinaryNode(
+      strings::StrCat(numerator_node->name(), ":z:0"),
+      strings::StrCat(one_node->name(), ":output:0"), kSubOp, DT_INT64, fdef);
+
+  *result =
+      AddBinaryNode(strings::StrCat(numerator_node->name(), ":z:0"),
+                    strings::StrCat(num_workers_node->name(), ":output:0"),
+                    kTruncateDivOp, DT_INT64, fdef);
+  return Status::OK();
+}
+
 // Given a "batch" dataset node, we replace the `batch_size` input with a new
-// input that corresponds to the original input divided by `num_workers`. If
-// `num_workers` does not divide `batch_size` evenly, the value is rounded up.
+// input that corresponds to the original input divided by `num_workers`.
 Status MutateBatchSize(const NodeDef& node, int64 num_workers,
                        MutableGraphView* graph) {
   // For all the batching datasets the batch_size is input number 1 except for
   // MapAndBatchDataset.
-  int64 batch_size_arg_index = 1;
-  if (node.op() == "ExperimentalMapAndBatchDataset" ||
-      node.op() == "MapAndBatchDataset") {
-    // For MapAndBatch we take the 3rd last input.
-    batch_size_arg_index = node.input_size() - 3;
-  }
+  int64 batch_size_arg_index = GetBatchSizeArgIndex(node);
   NodeDef* batch_size_node =
       graph_utils::GetInputNode(node, *graph, batch_size_arg_index);
-  NodeDef* new_batch_size_node;
-  if (batch_size_node->op() == kConstOp) {
-    Tensor batch_size_tensor;
-    TF_RETURN_IF_ERROR(
-        GetNodeAttr(*batch_size_node, "value", &batch_size_tensor));
-    if (!TensorShapeUtils::IsScalar(batch_size_tensor.shape())) {
-      return errors::Internal("Batch size node shape should be scalar");
-    }
-    int64 batch_size = batch_size_tensor.scalar<int64>()();
-    batch_size = (batch_size + num_workers - 1) / num_workers;
-    new_batch_size_node =
-        graph_utils::AddScalarConstNode<int64>(batch_size, graph);
-  } else {
-    NodeDef* one_node = graph_utils::AddScalarConstNode<int64>(1, graph);
-    NodeDef* num_workers_node =
-        graph_utils::AddScalarConstNode<int64>(num_workers, graph);
-    NodeDef* numerator_node =
-        AddBinaryNode(batch_size_node->name(), num_workers_node->name(), kAddOp,
-                      DT_INT64, graph);
-    numerator_node = AddBinaryNode(numerator_node->name(), one_node->name(),
-                                   kSubOp, DT_INT64, graph);
-    new_batch_size_node =
-        AddBinaryNode(numerator_node->name(), num_workers_node->name(),
-                      kTruncateDivOp, DT_INT64, graph);
-  }
+  int64 batch_size;
+  TF_RETURN_IF_ERROR(
+      graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size));
+  DCHECK_EQ(batch_size % num_workers, 0);
+  batch_size = batch_size / num_workers;
+  NodeDef* new_batch_size_node =
+      graph_utils::AddScalarConstNode<int64>(batch_size, graph);
   // We don't call UpdateFanouts here because CSE elimination might lead to
   // multiple nodes sharing the same batch size constant node. This is also
   // why we don't delete batch_size_node as well.
@@ -331,6 +379,181 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers,
   return Status::OK();
 }
 
+Status AddFlatMapNode(const string& input_dataset,
+                      gtl::ArraySlice<string> other_arguments,
+                      gtl::ArraySlice<DataType> t_arguments,
+                      const FunctionDef& flat_map_fn,
+                      const AttrValue& output_shapes,
+                      const DataTypeVector& output_types,
+                      FunctionLibraryDefinition* flib, MutableGraphView* graph,
+                      NodeDef** result) {
+  TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn));
+  AttrValue f;
+  f.mutable_func()->set_name(flat_map_fn.signature().name());
+
+  NodeDef flat_map_node;
+  flat_map_node.set_op("FlatMapDataset");
+  flat_map_node.add_input(input_dataset);
+  for (const auto& arg : other_arguments) {
+    flat_map_node.add_input(arg);
+  }
+  AddNodeAttr("f", f, &flat_map_node);
+  AddNodeAttr("Targuments", t_arguments, &flat_map_node);
+  AddNodeAttr(kOutputShapesAttr, output_shapes, &flat_map_node);
+  AddNodeAttr(kOutputTypesAttr, output_types, &flat_map_node);
+
+  graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(),
+                                      &flat_map_node);
+  *result = graph->AddNode(std::move(flat_map_node));
+  return Status::OK();
+}
+
+// def flat_map_fn(*batched_components):
+//   ds = tf.data.Dataset.from_tensor_slices(batched_components)
+//   return ds.batch(minibatch_size, drop_remainder=False)
+Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, int64 num_workers,
+                                FunctionDef* result) {
+  NodeDef* tensor_slice_node = result->add_node_def();
+  tensor_slice_node->set_op("TensorSliceDataset");
+  for (int i = 0; i < dtypes.size(); ++i) {
+    auto* input_arg = function_utils::AddFunctionInput(
+        strings::StrCat("args_", i), result, dtypes.at(i));
+    tensor_slice_node->add_input(input_arg->name());
+  }
+  AddNodeAttr(kTOutputTypesAttr, dtypes, tensor_slice_node);
+
+  // The output_shapes attr here doesn't make a difference, since we
+  // set the output_shapes of the external FlatMap node.
+  AttrValue shapes;
+  SetUnknownShapes(dtypes.size(), &shapes);
+  AddNodeAttr(kOutputShapesAttr, shapes, tensor_slice_node);
+  function_utils::SetUniqueFunctionNodeName("rebatch/from_tensor_slices",
+                                            result, tensor_slice_node);
+
+  NodeDef* false_node;
+  TF_RETURN_IF_ERROR(AddConstBoolNode(false, result, &false_node));
+  NodeDef* batch_node = result->add_node_def();
+  batch_node->set_op(kBatchV2Op);
+  batch_node->add_input(
+      strings::StrCat(tensor_slice_node->name(), ":handle:0"));
+
+  // `batch_size` input
+  // Here, we capture the original batch size from outside the flat map fn.
+  auto* original_batch_size =
+      function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64);
+  NodeDef* new_batch_size;
+  TF_RETURN_IF_ERROR(MakeNewBatchSizeNode(
+      original_batch_size->name(), num_workers, result, &new_batch_size));
+  batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0"));
+
+  // `drop_remainder` input
+  batch_node->add_input(strings::StrCat(false_node->name(), ":output:0"));
+  AddNodeAttr(kOutputTypesAttr, dtypes, batch_node);
+  AddNodeAttr(kOutputShapesAttr, shapes, batch_node);
+  function_utils::SetUniqueFunctionNodeName("rebatch/batch", result,
+                                            batch_node);
+  function_utils::AddFunctionOutputWithUniqueName(
+      "output", strings::StrCat(batch_node->name(), ":handle:0"), result,
+      DT_VARIANT);
+  // Because TensorSliceDataset is stateful, we set the function to stateful.
+  result->mutable_signature()->set_is_stateful(true);
+
+  return Status::OK();
+}
+
+// Rewrite graph to add
+// `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).
+//            batch(minibatch_size, drop_remainder=False))`
+// after the batch node. This ensures that the sum of the minibatch sizes
+// in a step adds up to the global batch size. However, since this adds
+// additional data copies (both from_tensor_slices and batch), we only use
+// this approach when necessary, i.e. when we need to drop remainder on the
+// global batch, or when num_workers does not divide the global batch size
+// evenly.
+Status AppendFlatMap(const NodeDef& batch_node, int64 num_workers,
+                     FunctionLibraryDefinition* flib, MutableGraphView* graph) {
+  // `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x).
+  //            batch(minibatch_size, drop_remainder=False))`
+  FunctionDef flat_map_fn;
+  FunctionDefLibrary lib = flib->ToProto();
+  graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib,
+                                          &flat_map_fn);
+  DataTypeVector dtypes;
+  TF_RETURN_IF_ERROR(
+      graph_utils::GetDatasetOutputTypesAttr(batch_node, &dtypes));
+  TF_RETURN_IF_ERROR(
+      CreateFlatMapFnWithBatch(dtypes, num_workers, &flat_map_fn));
+
+  int64 batch_size_index = GetBatchSizeArgIndex(batch_node);
+
+  NodeDef* flat_map_node;
+
+  AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr);
+  for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) {
+    if (!shape.unknown_rank() && shape.dim(0).size() != -1) {
+      // Because the flat map function uses drop_remainder = False,
+      // the shape might be unknown.
+      auto old_dim = shape.dim(0).size();
+      auto new_dim = old_dim % num_workers == 0 ? old_dim / num_workers : -1;
+      shape.mutable_dim(0)->set_size(new_dim);
+    }
+  }
+
+  TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"),
+                                    {batch_node.input(batch_size_index)},
+                                    {DT_INT64}, flat_map_fn, output_shapes,
+                                    dtypes, flib, graph, &flat_map_node));
+
+  TF_RETURN_IF_ERROR(
+      graph->UpdateFanouts(batch_node.name(), flat_map_node->name()));
+
+  return Status::OK();
+}
+
+// There are several things we do here, depending on the values of
+// batch_size and drop_remainder.
+// (1) If batch size is known and divisible by num_workers, and drop_remainder
+//     is known to be False, we mutate the batch size directly.
+//     .batch(global_batch_size) -> .batch(global_batch_size // num_workers)
+// (2) Otherwise, we add a flat_map transformation to preserve the global batch
+//     size across the workers and to preserve the drop remainder behavior.
+bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, int64 num_workers,
+                                   MutableGraphView* graph) {
+  int64 batch_size_arg_index = GetBatchSizeArgIndex(batch_node);
+  NodeDef* batch_size_node =
+      graph_utils::GetInputNode(batch_node, *graph, batch_size_arg_index);
+
+  int64 batch_size;
+  Status s =
+      graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size);
+  // If batch size is unknown or indivisible by num workers, we don't
+  // mutate it directly.
+  if (!s.ok() || batch_size % num_workers != 0) return false;
+
+  if (batch_node.op() == kBatchOp || batch_node.op() == kPaddedBatchOp) {
+    // These ops don't have a `drop_remainder` input, and behave like
+    // drop_remainder is False.
+    return true;
+  }
+
+  // drop_remainder is the final input on the other batch nodes.
+  NodeDef* drop_remainder_node = graph_utils::GetInputNode(
+      batch_node, *graph, batch_node.input_size() - 1);
+  bool drop_remainder;
+  s = graph_utils::GetScalarConstNodeValue(*drop_remainder_node,
+                                           &drop_remainder);
+  return s.ok() && !drop_remainder;
+}
+
+Status RewriteBatchNode(const NodeDef& batch_node, int64 num_workers,
+                        FunctionLibraryDefinition* flib,
+                        MutableGraphView* graph) {
+  if (ShouldMutateBatchSizeDirectly(batch_node, num_workers, graph)) {
+    return MutateBatchSize(batch_node, num_workers, graph);
+  }
+  return AppendFlatMap(batch_node, num_workers, flib, graph);
+}
+
 Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
                      bool use_fallback, GraphDef* output);
 
@@ -346,7 +569,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
                            bool use_fallback, FunctionLibraryDefinition* flib,
                            MutableGraphView* graph) {
   if (IsDatasetNodeOfType(node, kBatchDatasetOps)) {
-    TF_RETURN_IF_ERROR(MutateBatchSize(node, num_workers, graph));
+    TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_workers, flib, graph));
  } else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
     // For all multiple input datasets, all inputs are datasets themselves.
     for (int i = 0; i < node.input_size(); ++i) {
@@ -403,7 +626,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
 }
 
 // Add nodes to the function to reshape arg to shape (-1, new_batch_dim, ...)
-Status ReshapeComponent(int new_batch_dim, StringPiece arg, DataType dtype,
+Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype,
                         FunctionDef* fdef, string* result) {
   // Const with value [0]
   NodeDef* const_vec_0;
@@ -453,47 +676,50 @@ Status ReshapeComponent(int new_batch_dim, StringPiece arg, DataType dtype,
   return Status::OK();
 }
 
-Status CreateFlatMapFn(int new_batch_dim, const AttrValue& types,
-                       FunctionDef* result) {
+// def flat_map_fn(*batched_components):
+//   return tf.data.Dataset.from_tensor_slices(
+//       [tf.reshape(c, (-1, new_batch_size, ...))
+//        for c in batched_components])
+Status CreateFlatMapFnWithReshape(int new_batch_dim,
+                                  const DataTypeVector& types,
+                                  FunctionDef* result) {
   std::vector<NodeDefBuilder::NodeOut> tensor_slice_dataset_inputs;
 
   // For each component of the dataset, we reshape it from shape
   // (old_batch_size, ...) to (-1, new_batch_size, ...)
   // where new_batch_size = (old_batch_size + num_workers - 1) // num_workers
-  for (int i = 0; i < types.list().type_size(); ++i) {
-    string arg = strings::StrCat("args_", i);
-    auto* input_arg = result->mutable_signature()->mutable_input_arg()->Add();
-    input_arg->set_type(types.list().type(i));
-    input_arg->set_name(arg);
+  for (int i = 0; i < types.size(); ++i) {
+    auto* input_arg = function_utils::AddFunctionInput(
+        strings::StrCat("args_", i), result, types.at(i));
 
     string reshape_node_name;
-    TF_RETURN_IF_ERROR(ReshapeComponent(
-        new_batch_dim, arg, types.list().type(i), result, &reshape_node_name));
+    TF_RETURN_IF_ERROR(ReshapeComponent(new_batch_dim, input_arg->name(),
+                                        types.at(i), result,
+                                        &reshape_node_name));
 
     tensor_slice_dataset_inputs.emplace_back(
-        strings::StrCat(reshape_node_name, ":output"), 0, types.list().type(i));
+        strings::StrCat(reshape_node_name, ":output"), 0, types.at(i));
   }
 
   // The output_shapes attr here doesn't make a difference, since we
   // set the output_shapes of the external FlatMap node.
   AttrValue shapes;
-  SetUnknownShapes(types.list().type_size(), &shapes);
+  SetUnknownShapes(types.size(), &shapes);
 
   NodeDef* tensor_slice_dataset = result->add_node_def();
   TF_RETURN_IF_ERROR(NodeDefBuilder("", "TensorSliceDataset")
                          .Input(tensor_slice_dataset_inputs)
                          .Attr("Toutput_types", types)
-                         .Attr("output_shapes", shapes)
+                         .Attr(kOutputShapesAttr, shapes)
                          .Finalize(tensor_slice_dataset));
   function_utils::SetUniqueFunctionNodeName("rebatch/tensor_slice_dataset",
                                             result, tensor_slice_dataset);
 
-  auto* output_arg = result->mutable_signature()->mutable_output_arg()->Add();
-  output_arg->set_name("output");
-  output_arg->set_type(DT_VARIANT);
+  function_utils::AddFunctionOutputWithUniqueName(
+      "output", strings::StrCat(tensor_slice_dataset->name(), ":handle:0"),
+      result, DT_VARIANT);
+  // Because TensorSliceDataset is stateful, we set the function to stateful.
   result->mutable_signature()->set_is_stateful(true);
-  (*result->mutable_ret())["output"] =
-      strings::StrCat(tensor_slice_dataset->name(), ":handle:0");
 
   return Status::OK();
 }
@@ -525,12 +751,12 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
   // because of the use of the "Reshape" op. This ensures that the error is
   // surfaced correctly.
   AttrValue output_shapes;
-  if (!fetch_node->attr().contains("output_shapes")) {
+  if (!fetch_node->attr().contains(kOutputShapesAttr)) {
     return errors::InvalidArgument(
         "Cannot use rebatching fallback without output_shapes attr. Node: ",
         fetch_node->name(), " Op: ", fetch_node->op());
   } else {
-    output_shapes = fetch_node->attr().at("output_shapes");
+    output_shapes = fetch_node->attr().at(kOutputShapesAttr);
   }
   int batch_dim;
   TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim));
@@ -543,35 +769,25 @@ Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_workers,
   // Create the flat map fn
   FunctionDef flat_map_fn;
   FunctionDefLibrary lib = flib->ToProto();
-  graph_utils::SetUniqueGraphFunctionName("flat_map_fn", &lib, &flat_map_fn);
+  graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib,
+                                          &flat_map_fn);
 
   // Get types of input arguments from the output types of the final dataset.
-  AttrValue output_types;
+  DataTypeVector output_types;
   TF_RETURN_IF_ERROR(
       graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types));
+  TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_workers,
+                                                output_types, &flat_map_fn));
+
+  NodeDef* flat_map_node;
+  TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"),
+                                    {}, {}, flat_map_fn, output_shapes,
+                                    output_types, flib, graph, &flat_map_node));
   TF_RETURN_IF_ERROR(
-      CreateFlatMapFn(batch_dim / num_workers, output_types, &flat_map_fn));
+      UpdateOutputShapes(flat_map_node->name(), num_workers, graph));
 
-  TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn));
-  AttrValue fn;
-  fn.mutable_func()->set_name(flat_map_fn.signature().name());
-
-  NodeDef flat_map_node;
   TF_RETURN_IF_ERROR(
-      NodeDefBuilder("", "FlatMapDataset")
-          .Input(fetch_node->name(), 0, DT_VARIANT)
-          .Input(std::vector<NodeDefBuilder::NodeOut>())  // other_arguments
-          .Attr("f", fn)
-          .Attr("Targuments", std::vector<DataType>())
-          .Attr("output_types", output_types)
-          .Attr("output_shapes", output_shapes)
-          .Finalize(&flat_map_node));
-  graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(),
-                                      &flat_map_node);
-  NodeDef* added = graph->AddNode(std::move(flat_map_node));
-  TF_RETURN_IF_ERROR(UpdateOutputShapes(added->name(), num_workers, graph));
-
-  TF_RETURN_IF_ERROR(graph->UpdateFanouts(fetch_node->name(), added->name()));
+      graph->UpdateFanouts(fetch_node->name(), flat_map_node->name()));
 
   return Status::OK();
 }
@@ -593,8 +809,8 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers,
       RecursivelyHandleOp(*sink_node, num_workers, use_fallback, &flib, &graph);
   if (!s.ok()) {
     if (use_fallback) {
-      VLOG(1) << "Couldn't find a batch transformation. Using a fallback method"
-                 " to rebatch dataset.";
+      VLOG(1) << "Failed to rebatch by rewriting the batch transformation ("
+              << s << "). Using a fallback method instead.";
       // If RecursivelyHandleOp fails, we reset `graph` to use the original,
       // graph, since that function may have mutated `graph`.
       *output = item.graph;
@@ -48,96 +48,98 @@ def _flat_shapes(dataset):
   return nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
 
 
-@parameterized.named_parameters(("WithDropRemainder", True),
-                                ("WithoutDropRemainder", False))
 @test_util.run_all_in_graph_and_eager_modes
-class RebatchDatasetTest(test_base.DatasetTestBase):
+class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
 
+  drop_remainder_cases = [("WithDropRemainder", True),
+                          ("WithoutDropRemainder", False)]
+
+  @parameterized.named_parameters(drop_remainder_cases)
   def testBasic(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[8] if drop_remainder else [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
     expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testScalarInputError(self, _):
+  def testScalarInputError(self):
     dataset = dataset_ops.Dataset.range(1024)
+    distribute._RebatchDataset(dataset.batch(4), num_workers=4)
     with self.assertRaisesRegexp(ValueError, "at least one dimension"):
       distribute._RebatchDataset(dataset, num_workers=4)
 
-  def testNotDivisible(self, drop_remainder):
+  @parameterized.named_parameters(drop_remainder_cases)
+  def testBatchNotDivisibleByNumWorkers(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
-    expected_output = [[k for k in range(i, i + 7)] for i in range(0, 1022, 7)]  # pylint: disable=g-complex-comprehension
-    if not drop_remainder:
-      expected_output.append([1022, 1023])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    expected_output = []
+    i = 0
+    for _ in range(32):  # number of steps
+      # first four minibatches have seven elements
+      for _ in range(4):
+        expected_output.append([k for k in range(i, i + 7)])
+        i += 7
+      # last minibatch has four elements
+      expected_output.append([k for k in range(i, i + 4)])
+      i += 4
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testTupleOutput(self, drop_remainder):
-    dataset = (
-        dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(
-            32, drop_remainder=drop_remainder))
+  def testTupleOutput(self):
+    dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
     expected_output = [([k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                         [k for k in range(i, i + 8)])
                        for i in range(0, 1024, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testNestedDictionaryOutput(self, drop_remainder):
+  def testNestedDictionaryOutput(self):
     dataset = dataset_ops.Dataset.range(1024).map(
-        lambda x: {"a": x, "b": {"c": x}}).batch(
-            32, drop_remainder=drop_remainder)
+        lambda x: {"a": x, "b": {"c": x}}).batch(32)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
     expected_output = [{"a": [k for k in range(i, i + 8)],  # pylint: disable=g-complex-comprehension
                         "b": {"c": [k for k in range(i, i + 8)]}}
                        for i in range(0, 1024, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testFinalPartialBatchOriginal(self, drop_remainder):
+  @parameterized.named_parameters(drop_remainder_cases)
+  def testFinalPartialBatch(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1032).batch(
         32, drop_remainder=drop_remainder)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[8] if drop_remainder else [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1032, 8)]  # pylint: disable=g-complex-comprehension
+    # if drop_remainder, the final partial batch is dropped, even though it
+    # makes up a complete minibatch.
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
+    if not drop_remainder:
+      expected_output.append([k for k in range(1024, 1032)])
    self.assertDatasetProduces(rebatched_dataset, expected_output)
 
+  @parameterized.named_parameters(drop_remainder_cases)
   def testFinalPartialBatchAfterRebatch(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(34).batch(
         32, drop_remainder=drop_remainder)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[8] if drop_remainder else [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
 
     expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
     if not drop_remainder:
       expected_output += [[32, 33]]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testMultipleBatches(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(128).batch(
-        4, drop_remainder=drop_remainder)
-    dataset = dataset.batch(8, drop_remainder=drop_remainder)
-    self.assertEqual(
-        [[8, 4]] if drop_remainder else [[None, None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testMultipleBatches(self):
+    dataset = dataset_ops.Dataset.range(128).batch(4).batch(8)
+    self.assertEqual([[None, None]],
+                     [ts.as_list() for ts in _flat_shapes(dataset)])
     # Each element is a list of 8 elements where each element is a list of 4.
     expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                         for j in range(i, i + 32, 4)]  # generates 8 elements
@@ -145,39 +147,30 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     self.assertDatasetProduces(dataset, expected_output)

     rebatched_dataset = distribute._RebatchDataset(dataset, 4)
-    self.assertEqual(
-        [[2, 4]] if drop_remainder else [[None, None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None, None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Each element is a list of 2 elements where each element is a list of 4.
     expected_output = [[[j, j + 1, j + 2, j + 3]  # pylint: disable=g-complex-comprehension
                         for j in range(i, i + 8, 4)]  # generates 2 elements
                        for i in range(0, 128, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testMapAndBatch(self, drop_remainder):
+  def testMapAndBatch(self):
     dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(
-            math_ops.square, 32, drop_remainder=drop_remainder))
+        batching.map_and_batch(math_ops.square, 32))
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [[k**2 for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                        for i in range(0, 1024, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testMapAndBatchWithCapturedInput(self, drop_remainder):
+  def testMapAndBatchWithCapturedInput(self):
     captured_t = variables.Variable(42)
     dataset = dataset_ops.Dataset.range(1024).apply(
-        batching.map_and_batch(
-            lambda x: captured_t, 32, drop_remainder=drop_remainder))
+        batching.map_and_batch(lambda x: captured_t, 32))
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual([[32 if drop_remainder else None]],
-                     [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual([[8 if drop_remainder else None]],
+    self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [[42 for _ in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
                        for i in range(0, 1024, 8)]
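Note on the new shape expectations: with drop_remainder removed from these tests, the rebatched batch dimension is always statically unknown ([[None, None]], [[None]]). The minibatch sizes the tests still assert are consistent with splitting each global batch into chunks of ceil(batch_size / num_workers); the sketch below is plain Python inferred from those assertions, for illustration only, not the tf.data implementation:

def split_batch(batch, num_workers):
  # Chunks of size ceil(len(batch) / num_workers); chunk sizes always sum
  # to the original batch size.
  size = -(-len(batch) // num_workers)  # ceiling division
  return [batch[i:i + size] for i in range(0, len(batch), size)]

print(split_batch(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(split_batch(list(range(5)), 2))  # [[0, 1, 2], [3, 4]] -> sizes (3, 2)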
@@ -185,22 +178,19 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     self.assertDatasetProduces(
         rebatched_dataset, expected_output, requires_initialization=True)

-  def testPaddedBatch(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(128).batch(4).padded_batch(
-        8, padded_shapes=[5], drop_remainder=drop_remainder)
+  def testPaddedBatch(self):
+    dataset = dataset_ops.Dataset.range(128).batch(
+        4, drop_remainder=True).padded_batch(
+            8, padded_shapes=[5])
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8, 5]] if drop_remainder else [[None, 5]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
     # Each element is a list of 8 elements in which each element is a list of 5
     # elements, first four are numbers and the last one is a padded zero.
     expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
                         for j in range(i, i + 32, 4)]  # generates 8 elements
                        for i in range(0, 128, 32)]
     self.assertDatasetProduces(dataset, expected_output)
-    self.assertEqual(
-        [[2, 5]] if drop_remainder else [[None, 5]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None, 5]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Each element is a list of 2 elements in which each element is a list of 5
     # elements, first four are numbers and the last one is a padded zero.
     expected_output = [[[j, j + 1, j + 2, j + 3, 0]  # pylint: disable=g-complex-comprehension
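As a quick sanity check, the padded-batch expectation above is reproducible in plain Python: each inner element is four consecutive integers plus one trailing zero of padding up to padded_shapes=[5] (illustrative only):

expected = [[[j, j + 1, j + 2, j + 3, 0]   # 4 values + 1 zero of padding
             for j in range(i, i + 32, 4)]  # 8 inner elements per batch
            for i in range(0, 128, 32)]     # 4 batches of 8
assert expected[0][0] == [0, 1, 2, 3, 0]
assert len(expected) == 4 and all(len(batch) == 8 for batch in expected)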
@@ -208,32 +198,22 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
                        for i in range(0, 128, 8)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testConcatenate(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        8, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testConcatenate(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(8)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset1.concatenate(dataset2)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[2 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = ([[i, i + 1] for i in range(0, 64, 2)] +
                        [[i, i + 1] for i in range(0, 32, 2)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testConcatenateDifferentShapes(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        16, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testConcatenateDifferentShapes(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(16)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset1.concatenate(dataset2)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
     self.assertEqual(
         [[None]],
         [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
@@ -241,73 +221,56 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
                        [[i, i + 1] for i in range(0, 32, 2)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testZip(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        8, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testZip(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(8)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8], [8]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[2], [2]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None], [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [([i, i + 1], [i, i + 1]) for i in range(0, 32, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testZipDifferentShapes(self, drop_remainder):
-    dataset1 = dataset_ops.Dataset.range(64).batch(
-        16, drop_remainder=drop_remainder)
-    dataset2 = dataset_ops.Dataset.range(32).batch(
-        8, drop_remainder=drop_remainder)
+  def testZipDifferentShapes(self):
+    dataset1 = dataset_ops.Dataset.range(64).batch(16)
+    dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[16], [8]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual(
-        [[4], [2]] if drop_remainder else [[None], [None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None], [None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = [([2 * i, 2 * i + 1, 2 * i + 2, 2 * i + 3], [i, i + 1])
                        for i in range(0, 32, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testUnsupportedTransformError(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(1024).batch(
-        32, drop_remainder=drop_remainder).apply(sleep.sleep(10))
+  def testUnsupportedTransformError(self):
+    dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
     with self.assertRaises(errors.InvalidArgumentError):
       rebatched_dataset = distribute._RebatchDataset(
           dataset, num_workers=4, use_fallback=False)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())

-  def testUnsupportedTransformInFlatMapError(self, drop_remainder):
+  def testUnsupportedTransformInFlatMapError(self):
     dataset = dataset_ops.Dataset.range(2).flat_map(
         lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder).apply(sleep.sleep(10)))
+            32).apply(sleep.sleep(10)))
     with self.assertRaises(errors.InvalidArgumentError):
       rebatched_dataset = distribute._RebatchDataset(
           dataset, num_workers=4, use_fallback=False)
       next_element = self.getNext(rebatched_dataset)
       self.evaluate(next_element())

-  def testFlatMapBatching(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(
-        2).flat_map(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder))
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testFlatMapBatching(self):
+    dataset = dataset_ops.Dataset.range(2).flat_map(
+        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
+            32))
     # Two elements where each element is range(32)
     expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)

     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # Two elements where each element is a list of 4 elements where each element
     # is a list of 8.
     expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
@@ -315,21 +278,18 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
                        for i in range(0, 32, 8)]  # generates 4 elements
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testInterleaveBatching(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(
-        2).interleave(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder), cycle_length=2)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testInterleaveBatching(self):
+    dataset = dataset_ops.Dataset.range(2).interleave(
+        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
+            32),
+        cycle_length=2)
     # Two elements where each element is range(32)
     expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)

     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # List of 4 elements where each element is a list of 8 numbering from 0 to
     # 31 repeated twice.
     expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
@@ -337,22 +297,19 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
                        for _ in range(2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testParallelInterleaveBatching(self, drop_remainder):
-    dataset = dataset_ops.Dataset.range(
-        2).interleave(lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=drop_remainder), cycle_length=2,
-        num_parallel_calls=2)
-    self.assertEqual(
-        [[32 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(dataset)])
+  def testParallelInterleaveBatching(self):
+    dataset = dataset_ops.Dataset.range(2).interleave(
+        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
+            32),
+        cycle_length=2,
+        num_parallel_calls=2)
     # Two elements where each element is range(32)
     expected_output = [[k for k in range(32)] for _ in range(2)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)

     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)
-    self.assertEqual(
-        [[8 if drop_remainder else None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # List of 4 elements where each element is a list of 8 numbering from 0 to
     # 31 repeated twice in collated fashion i.e [0...8], [0...8] etc.
     expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
@@ -360,17 +317,17 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
                        for _ in range(2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testGroupByWindowStaticBatch(self, drop_remainder):
+  def testGroupByWindowStaticBatch(self):
     dataset = dataset_ops.Dataset.from_tensor_slices(
         [[array_ops.constant(i, dtype=dtypes.int64)] * 3 for i in range(40)])
     reduce_fn = lambda bucket_id, ds: ds.batch(  # pylint: disable=g-long-lambda
-        batch_size=10, drop_remainder=drop_remainder)
+        batch_size=10)
     dataset = dataset.apply(
         grouping.group_by_window(
             key_func=lambda x: x[0] % 4, reduce_func=reduce_fn, window_size=10))
     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=2)

-    self.assertEqual([[5, 3] if drop_remainder else [None, 3]],
+    self.assertEqual([[None, 3]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     # pylint: disable=g-complex-comprehension
     expected_output = [[[j + i * 4 + k * 20] * 3
@@ -379,10 +336,15 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
                        for k in range(2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)

-  def testGroupByWindowDynamicBatch(self, drop_remainder):
+  def testGroupByWindowDynamicBatch(self):
+    # {0, 1, 0, 1, ...}
     dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
-    reduce_fn = lambda bucket_id, ds: ds.batch(  # pylint: disable=g-long-lambda
-        batch_size=(bucket_id + 1) * 5, drop_remainder=drop_remainder)
+
+    def reduce_fn(key, ds):
+      # key == 0 -> .batch(5)
+      # key == 1 -> .batch(10)
+      return ds.batch(batch_size=(key + 1) * 5)
+
     dataset = dataset.apply(
         grouping.group_by_window(
             key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
@@ -390,15 +352,64 @@ class RebatchDatasetTest(test_base.DatasetTestBase):

     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])
-    pairs = [(3, 0), (3, 0), (3, 0)]
-    if not drop_remainder:
-      pairs.extend([(1, 0)])
-    pairs.extend([(5, 1), (5, 1)])
+    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
+    # the batches of 10 (value == 1) split into minibatches of (5, 5)
+    # [(batch_size, value), ...]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1)]
     pairs = pairs * 2
     expected_output = [[value] * batch_size for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)

-  def testScanAfterBatch(self, drop_remainder):
+  def testGroupByWindowDynamicBatchWithPartialBatch(self):
+    # {0, 1, 0, 1, ...}
+    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
+
+    def reduce_fn(key, ds):
+      # key == 0 -> .batch(5)
+      # key == 1 -> .batch(10)
+      return ds.batch(batch_size=(key + 1) * 5)
+
+    dataset = dataset.apply(
+        grouping.group_by_window(
+            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
+    dataset = distribute._RebatchDataset(dataset, num_workers=2)
+
+    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
+
+    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
+    # the batches of 10 (value == 1) split into minibatches of (5, 5)
+    # [(batch_size, value), ...]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
+             (3, 0), (2, 0), (3, 0), (1, 0), (5, 1), (4, 1)]
+    expected_output = [[value] * batch_size for batch_size, value in pairs]
+    self.assertDatasetProduces(dataset, expected_output)
+
+  def testGroupByWindowDynamicBatchWithPartialBatchWithDropRemainder(self):
+    # This test exercises nested batch functionality, dynamic batch size
+    # and drop_remainder=True together.
+    dataset = dataset_ops.Dataset.range(40).map(lambda x: x % 2)
+
+    def reduce_fn(key, ds):
+      # key == 0 -> .batch(5)
+      # key == 1 -> .batch(10)
+      return ds.batch(batch_size=(key + 1) * 5, drop_remainder=True)
+
+    dataset = dataset.apply(
+        grouping.group_by_window(
+            key_func=lambda x: x, reduce_func=reduce_fn, window_size=11))
+    dataset = distribute._RebatchDataset(dataset, num_workers=2)
+
+    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
+
+    # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
+    # the batches of 10 (value == 1) split into minibatches of (5, 5)
+    # [(batch_size, value), ...]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (5, 1), (5, 1), (3, 0), (2, 0)]
+    expected_output = [[value] * batch_size for batch_size, value in pairs]
+    self.assertDatasetProduces(dataset, expected_output)
+
+  def testScanAfterBatch(self):
     dataset = dataset_ops.Dataset.range(40).batch(10).apply(
         scan_ops.scan(np.int64(2), lambda state, value: (state, value * state)))
     dataset = distribute._RebatchDataset(dataset, num_workers=2)
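The pairs lists in these group_by_window tests follow from the splitting noted in the test comments: each per-key batch is divided into chunks of ceil(batch_size / num_workers). A plain-Python check of the asserted sizes (illustrative, not the tf.data kernel):

def ceil_chunks(batch_size, num_workers):
  # Minibatch sizes one batch splits into; they always sum to batch_size.
  size = -(-batch_size // num_workers)  # ceiling division
  return [min(size, batch_size - i) for i in range(0, batch_size, size)]

print(ceil_chunks(5, 2))   # [3, 2] -> the (3, 0), (2, 0) pairs for key 0
print(ceil_chunks(10, 2))  # [5, 5] -> the (5, 1), (5, 1) pairs for key 1
print(ceil_chunks(9, 2))   # [5, 4] -> the trailing (5, 1), (4, 1) pairs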
@@ -408,7 +419,7 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     expected_output = [[i * 2 for i in range(j*5, (j+1)*5)] for j in range(8)]  # pylint: disable=g-complex-comprehension
     self.assertDatasetProduces(dataset, expected_output)

-  def testMakeBatchedFeaturesDataset(self, drop_remainder):
+  def testMakeBatchedFeaturesDataset(self):
     # Set up
     fn = os.path.join(self.get_temp_dir(), "tf_record.txt")
     writer = python_io.TFRecordWriter(fn)
@@ -429,13 +440,11 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
         features={"value": parsing_ops.FixedLenFeature([], dtypes.int64)},
         shuffle=False,
         num_epochs=1,
-        drop_final_batch=drop_remainder)
+        drop_final_batch=False)

     rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=4)

-    self.assertEqual([[32 if drop_remainder else None]],
-                     [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual([[8 if drop_remainder else None]],
+    self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])

     expected_output = [{
@@ -74,7 +74,11 @@ def _AutoShardDatasetV1(input_dataset, num_workers, index):  # pylint: disable=i


 class _RebatchDataset(dataset_ops.UnaryDataset):
-  """A `Dataset` that divides the batch size by `num_workers`."""
+  """A `Dataset` that divides the batch size by `num_workers`.
+
+  For each batch in the input dataset, the resulting dataset will produce
+  `num_replicas` minibatches whose sizes add up to the original batch size.
+  """

   def __init__(self, input_dataset, num_workers, use_fallback=True):
     self._input_dataset = input_dataset
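A minimal usage sketch of the documented contract; _RebatchDataset is internal tf.data API, so the import path below is an assumption and the loop presumes eager execution:

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

# One global batch of 8 rebatched for 4 workers: four minibatches whose
# sizes add up to the original batch size.
dataset = tf.data.Dataset.range(8).batch(8)
rebatched = distribute._RebatchDataset(dataset, num_workers=4)
for minibatch in rebatched:
  print(minibatch.numpy())  # [0 1], then [2 3], [4 5], [6 7]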
@@ -85,8 +89,14 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
         raise ValueError(
             "Input shape should have at least one dimension. "
            "Perhaps your input dataset is not batched?")
-      output_dims = [d for d in output_shapes.dims]
-      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
+      output_dims = [d.value for d in output_shapes.dims]
+
+      if output_dims[0] is not None and output_dims[0] % num_workers == 0:
+        output_dims[0] = output_dims[0] // num_workers
+      else:
+        # Set the batch dimension to unknown. If the global batch size does not
+        # divide num_workers evenly, the minibatches may have different sizes.
+        output_dims[0] = None
       return tensor_shape.TensorShape(output_dims)

     input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
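The static-shape rule above reduces to a small pure function; a sketch in plain Python, with None standing in for an unknown dimension:

def rebatched_dims(dims, num_workers):
  # dims[0] is the batch dimension; None means statically unknown.
  dims = list(dims)
  if dims[0] is not None and dims[0] % num_workers == 0:
    dims[0] //= num_workers
  else:
    # Not evenly divisible (or unknown): minibatches may differ in size,
    # so the static batch dimension becomes unknown.
    dims[0] = None
  return dims

print(rebatched_dims([32, 5], 4))    # [8, 5]
print(rebatched_dims([10, 5], 4))    # [None, 5]
print(rebatched_dims([None, 5], 4))  # [None, 5]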