Merge branch 'master' into upstream-staging-terminology-3
This commit is contained in:
commit
19d836b4fc
@ -894,6 +894,22 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "tensor_list_ops_test",
|
||||
size = "small",
|
||||
srcs = ["tensor_list_ops_test.py"],
|
||||
# TensorList ops are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = "cpu_ondemand",
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:list_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python/eager:function",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "ternary_ops_test",
|
||||
size = "small",
|
||||
|
@ -1820,7 +1820,7 @@ TEST_F(OpTest, Diag) {
|
||||
do {
|
||||
dims = RandomDims(1);
|
||||
size = TensorShape(dims).num_elements();
|
||||
} while (size * size < tf_xla_max_tensor_size);
|
||||
} while (size * size > tf_xla_max_tensor_size);
|
||||
return ExpectTfAndXlaOutputsAreClose(
|
||||
OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type));
|
||||
});
|
||||
|
105
tensorflow/compiler/tests/tensor_list_ops_test.py
Normal file
105
tensorflow/compiler/tests/tensor_list_ops_test.py
Normal file
@ -0,0 +1,105 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ops which manipulate lists of tensors via bridge."""
|
||||
|
||||
# pylint: disable=g-bad-name
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
def scalar_shape():
  """Returns an empty int32 tensor, used as the element shape of a scalar."""
  no_dims = []
  return ops.convert_to_tensor(no_dims, dtype=dtypes.int32)
|
||||
|
||||
|
||||
class ListOpsTest(xla_test.XLATestCase):
  """Tests TensorList reserve/push/pop/element-shape ops inside test_scope()."""

  def testElementShape(self):
    """The element shape can be queried as either int32 or int64."""
    with self.cached_session() as sess, self.test_scope():
      dim = array_ops.placeholder(dtypes.int32)
      reserved = list_ops.tensor_list_reserve(
          element_shape=(dim, 15),
          num_elements=20,
          element_dtype=dtypes.float32)
      shape_as_int32 = list_ops.tensor_list_element_shape(
          reserved, shape_type=dtypes.int32)
      shape_as_int64 = list_ops.tensor_list_element_shape(
          reserved, shape_type=dtypes.int64)
      self.assertAllEqual(sess.run(shape_as_int32, {dim: 10}), (10, 15))
      self.assertAllEqual(sess.run(shape_as_int64, {dim: 7}), (7, 15))

  def testPushPop(self):
    """Elements pushed onto a list are popped in reverse order."""
    with self.cached_session() as sess, self.test_scope():
      num = array_ops.placeholder(dtypes.int32)
      elem_dims = (7, 15)
      tensor_list = list_ops.tensor_list_reserve(
          element_shape=elem_dims,
          num_elements=num,
          element_dtype=dtypes.float32)
      # Push a block of ones, then a block of twos.
      for fill_value in (1.0, 2.0):
        tensor_list = list_ops.tensor_list_push_back(
            tensor_list, constant_op.constant(fill_value, shape=elem_dims))
      tensor_list, popped_first = list_ops.tensor_list_pop_back(
          tensor_list, element_dtype=dtypes.float32)
      _, popped_second = list_ops.tensor_list_pop_back(
          tensor_list, element_dtype=dtypes.float32)
      feed = {num: 10}
      self.assertAllEqual(sess.run(popped_first, feed),
                          2.0 * np.ones(elem_dims))
      self.assertAllEqual(sess.run(popped_second, feed),
                          1.0 * np.ones(elem_dims))

  def testPushPopSeparateLists(self):
    """Lists branched from a common prefix do not affect one another."""
    with self.cached_session() as sess, self.test_scope():
      num = array_ops.placeholder(dtypes.int32)
      base = list_ops.tensor_list_reserve(
          element_shape=scalar_shape(),
          num_elements=num,
          element_dtype=dtypes.float32)
      base = list_ops.tensor_list_push_back(base, constant_op.constant(1.0))
      # Two independent branches, each extending the shared one-element list.
      branch_a = list_ops.tensor_list_push_back(base,
                                                constant_op.constant(2.0))
      branch_b = list_ops.tensor_list_push_back(base,
                                                constant_op.constant(3.0))
      _, base_top = list_ops.tensor_list_pop_back(
          base, element_dtype=dtypes.float32)
      branch_a, a_top = list_ops.tensor_list_pop_back(
          branch_a, element_dtype=dtypes.float32)
      branch_a, a_next = list_ops.tensor_list_pop_back(
          branch_a, element_dtype=dtypes.float32)
      branch_b, b_top = list_ops.tensor_list_pop_back(
          branch_b, element_dtype=dtypes.float32)
      branch_b, b_next = list_ops.tensor_list_pop_back(
          branch_b, element_dtype=dtypes.float32)
      fetches = [base_top, [a_top, a_next], [b_top, b_next]]
      result = sess.run(fetches, {num: 20})
      self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])

  def testEmptyTensorList(self):
    """empty_tensor_list is rejected with a pointer to TensorListReserve."""
    dim = 7
    with self.cached_session() as sess, self.test_scope():
      p = array_ops.placeholder(dtypes.int32)
      tensor_list = list_ops.empty_tensor_list(
          element_shape=(p, 15), element_dtype=dtypes.float32)
      tensor_list = list_ops.tensor_list_push_back(
          tensor_list, constant_op.constant(1.0, shape=(dim, 15)))
      _, popped = list_ops.tensor_list_pop_back(
          tensor_list, element_dtype=dtypes.float32)
      expected_error = "Use TensorListReserve instead"
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   expected_error):
        self.assertEqual(sess.run(popped, {p: dim}),
                         1.0 * np.ones((dim, 15)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -95,6 +95,7 @@ tf_kernel_library(
|
||||
"stateless_random_ops.cc",
|
||||
"strided_slice_op.cc",
|
||||
"tensor_array_ops.cc",
|
||||
"tensor_list_ops.cc",
|
||||
"tile_ops.cc",
|
||||
"topk_op.cc",
|
||||
"training_ops.cc",
|
||||
@ -158,6 +159,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:conv_ops",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:list_kernels",
|
||||
"//tensorflow/core/kernels:no_op",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//tensorflow/core/kernels:pooling_ops",
|
||||
|
226
tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
Normal file
226
tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
Normal file
@ -0,0 +1,226 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// XLA TensorList operators.
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/concat_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Writes the shape of a tensor list's first tuple element to
// `*tensor_list_shape`.
//
// `op` must have a tuple shape. Returns the error from XlaBuilder::GetShape if
// the shape of `op` cannot be determined, or the status of the XLA-to-TF shape
// conversion otherwise.
Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
                          TensorShape* tensor_list_shape) {
  const auto shape_or = builder->GetShape(op);
  if (!shape_or.ok()) return shape_or.status();

  // NOTE: the name `shape` appears verbatim in the TF_RET_CHECK message.
  const xla::Shape& shape = shape_or.ValueOrDie();
  TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));

  const xla::Shape& first_element_shape =
      xla::ShapeUtil::GetTupleElementShape(shape, 0);
  return XLAShapeToTensorShape(first_element_shape, tensor_list_shape);
}
|
||||
|
||||
class TensorListReserveOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape element_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
|
||||
int64 num_elements;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
|
||||
|
||||
TensorShape tensor_shape;
|
||||
tensor_shape.AddDim(num_elements);
|
||||
tensor_shape.AppendShape(element_shape);
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
|
||||
tensor_shape.dim_sizes()),
|
||||
xla::ConstantR0<int32>(b, 0)}));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorListReserve")
|
||||
.CompileTimeConstInput("element_shape")
|
||||
.CompileTimeConstInput("num_elements"),
|
||||
TensorListReserveOp);
|
||||
|
||||
class EmptyTensorListOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
ctx->CtxFailure(
|
||||
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
|
||||
"size. Use TensorListReserve instead."));
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp);
|
||||
|
||||
class TensorListElementShapeOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
|
||||
shape.RemoveDim(0);
|
||||
|
||||
switch (shape_type_) {
|
||||
case DT_INT64:
|
||||
ctx->SetOutput(0, xla::ConstantR1<int64>(b, shape.dim_sizes()));
|
||||
break;
|
||||
case DT_INT32: {
|
||||
std::vector<int32> size;
|
||||
for (int64 s : shape.dim_sizes()) {
|
||||
size.push_back(s);
|
||||
}
|
||||
ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ctx->CtxFailure(
|
||||
errors::InvalidArgument("Unsupported shape type requested"));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataType shape_type_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp);
|
||||
|
||||
class TensorListPushBackOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
xla::XlaOp list = ctx->Input(0);
|
||||
TensorShape elem_shape = ctx->InputShape(1);
|
||||
|
||||
xla::XlaOp ta = xla::GetTupleElement(list, 0);
|
||||
xla::XlaOp index = xla::GetTupleElement(list, 1);
|
||||
xla::XlaOp value = ctx->Input(1);
|
||||
|
||||
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
|
||||
auto start_indices =
|
||||
xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
|
||||
xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
|
||||
|
||||
TensorShape slice_shape = elem_shape;
|
||||
slice_shape.InsertDim(0, 1LL);
|
||||
auto update = xla::Reshape(value, slice_shape.dim_sizes());
|
||||
|
||||
// TODO(phawkins): We don't check the index is in bounds --- there is no
|
||||
// error mechanism in XLA.
|
||||
ctx->SetOutput(
|
||||
0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
|
||||
index + xla::ConstantR0<int32>(b, 1)}));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp);
|
||||
|
||||
class TensorListPopBackOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
xla::XlaOp state = ctx->Input(0);
|
||||
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
|
||||
|
||||
xla::XlaOp ta = xla::GetTupleElement(state, 0);
|
||||
xla::XlaOp index = xla::GetTupleElement(state, 1);
|
||||
|
||||
index = index - xla::ConstantR0<int32>(b, 1);
|
||||
|
||||
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
|
||||
auto start_indices =
|
||||
xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
|
||||
xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}}));
|
||||
|
||||
auto slice_shape = shape.dim_sizes();
|
||||
slice_shape[0] = 1LL;
|
||||
|
||||
// TODO(phawkins): We don't check the index is in bounds --- there is no
|
||||
// error mechanism in XLA.
|
||||
xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
|
||||
// Remove the leading '1' dimension.
|
||||
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
|
||||
|
||||
ctx->SetOutput(0, xla::Tuple(b, {ta, index}));
|
||||
ctx->SetOutput(1, xla::Reshape(read, value_shape));
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
|
||||
// Makes the host Tensor that will refer to the expression.
|
||||
Tensor* output = nullptr;
|
||||
auto shape = builder()->GetShape(handle);
|
||||
if (!shape.ok()) {
|
||||
SetStatus(shape.status());
|
||||
return;
|
||||
}
|
||||
|
||||
Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
|
||||
Tensor** output) {
|
||||
// The step's default allocator is the dummy XlaCompilationAllocator which
|
||||
// simply allocates a metadata buffer to hold the expression to which it
|
||||
// corresponds.
|
||||
TensorShape tensor_shape;
|
||||
if (expected_output_dtype(index) == DT_VARIANT) {
|
||||
// tensor_data() is not supported for variant Tensor (i.e.,
|
||||
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
|
||||
// XlaExpression inside the Tensor's tensor_data() does not work for
|
||||
// variant. Instead construct a uint8 tensor and store the expression in its
|
||||
// value.
|
||||
// TODO(jpienaar): This should be refactored to stop masquerading
|
||||
// XlaExpressions as Tensors.
|
||||
*output = new Tensor();
|
||||
TensorShape tensor_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context_->allocate_temp(DT_UINT8, tensor_shape, *output));
|
||||
context_->set_output(index, **output);
|
||||
} else {
|
||||
TensorShape tensor_shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
|
||||
TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
|
||||
// Makes the host Tensor that will refer to the expression.
|
||||
Tensor* output = nullptr;
|
||||
auto shape_or = builder()->GetShape(handle);
|
||||
if (!shape_or.ok()) {
|
||||
SetStatus(shape_or.status());
|
||||
return;
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context_,
|
||||
XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
|
||||
OP_REQUIRES_OK(context_,
|
||||
context_->allocate_output(index, tensor_shape, &output));
|
||||
allocate_output(index, shape_or.ValueOrDie(), &output));
|
||||
|
||||
// The expression is stored in the tensor's data buffer. Fill in the
|
||||
// fields now.
|
||||
|
@ -255,6 +255,11 @@ class XlaOpKernelContext {
|
||||
// Returns the tensor of input `name`.
|
||||
const Tensor& GetInputTensorByName(absl::string_view name);
|
||||
|
||||
// Wraps OpKernelContext's allocate_output method while providing special
|
||||
// behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the
|
||||
// type to allow mapping for variant to more generic types.
|
||||
Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
|
||||
|
||||
OpKernelContext* const context_;
|
||||
};
|
||||
|
||||
|
@ -1945,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
|
||||
}
|
||||
} break;
|
||||
case TUPLE:
|
||||
LOG(FATAL) << "Should not be called on tuple shapes: "
|
||||
<< ShapeUtil::HumanString(subshape());
|
||||
break;
|
||||
return InvalidArgument("Should not be called on tuple shapes: %s",
|
||||
ShapeUtil::HumanString(subshape()));
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
|
||||
return InvalidArgument("Is called on unsupported shape: %s",
|
||||
ShapeUtil::HumanString(subshape()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -1841,42 +1841,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "inliner",
|
||||
srcs = ["inliner.cc"],
|
||||
hdrs = ["inliner.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "inliner_test",
|
||||
srcs = ["inliner_test.cc"],
|
||||
deps = [
|
||||
":cpu_plugin",
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":inliner",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "computation_placer",
|
||||
srcs = ["computation_placer.cc"],
|
||||
@ -3492,6 +3456,39 @@ cc_library(
|
||||
deps = ["//tensorflow/core:lib"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_inliner",
|
||||
srcs = ["map_inliner.cc"],
|
||||
hdrs = ["map_inliner.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":hlo_query",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "map_inliner_test",
|
||||
srcs = ["map_inliner_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":map_inliner",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_casting_utils_test",
|
||||
srcs = ["hlo_casting_utils_test.cc"],
|
||||
|
@ -94,6 +94,7 @@ cc_library(
|
||||
":target_machine_features",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:scatter_expander",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
@ -127,7 +128,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:indexed_array_analysis",
|
||||
"//tensorflow/compiler/xla/service:inliner",
|
||||
"//tensorflow/compiler/xla/service:llvm_compiler",
|
||||
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
|
@ -86,8 +86,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/scatter_expander.h"
|
||||
@ -249,7 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
&pipeline, module->config().debug_options(),
|
||||
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
|
||||
|
||||
pipeline.AddPass<Inliner>();
|
||||
pipeline.AddPass<MapInliner>();
|
||||
|
||||
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
|
||||
// pass.
|
||||
|
@ -465,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kIota:
|
||||
TF_RET_CHECK(proto.dimensions_size() <= 1)
|
||||
<< "Iota instruction should have at most 1 dimension but sees "
|
||||
TF_RET_CHECK(proto.dimensions_size() == 1)
|
||||
<< "Iota instruction should have 1 dimension but sees "
|
||||
<< proto.dimensions_size();
|
||||
instruction = CreateIota(proto.shape(), proto.dimensions(0));
|
||||
break;
|
||||
|
@ -195,13 +195,15 @@ class ListScheduler {
|
||||
return entry;
|
||||
}
|
||||
|
||||
// Returns the number of bytes freed if the HLO instruction is scheduled.
|
||||
// If the instruction calls subcomputations, we count the memory used by the
|
||||
// subcomputations as memory "defined" by the instruction. This is not
|
||||
// entirely accurate, because subcomputation memory will be freed after the
|
||||
// instruction finishes. But it is more accurate than not taking
|
||||
// subcomputations into account at all. In the future, we may improve
|
||||
// accounting for subcomputation memory (b/65409243).
|
||||
// Returns the number of bytes freed *after* the HLO instruction finishes.
|
||||
// The current List algorithm only considers two states for an instruction:
|
||||
// right before it runs, and after it finishes. We don't represent memory
|
||||
// usage during the execution of an instruction. But if the instruction calls
|
||||
// subcomputations, they are only live during the instruction's execution.
|
||||
// We end up counting the memory used by subcomputations as memory "defined"
|
||||
// by the instruction. This is not entirely accurate, but it is more accurate
|
||||
// than not taking subcomputations into account at all. In the future, we may
|
||||
// improve accounting for subcomputation memory (b/65409243).
|
||||
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
|
||||
int64 freed_bytes = 0;
|
||||
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
|
||||
@ -223,7 +225,18 @@ class ListScheduler {
|
||||
}
|
||||
}
|
||||
}
|
||||
return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
|
||||
int64 bytes_defined;
|
||||
if (max_subcomputation_bytes > 0 &&
|
||||
(entry.instruction->opcode() == HloOpcode::kWhile ||
|
||||
entry.instruction->opcode() == HloOpcode::kCall ||
|
||||
entry.instruction->opcode() == HloOpcode::kConditional)) {
|
||||
// The output buffer of while/call/conditional is always aliased with the
|
||||
// output buffer of the root instruction in the body. Don't double count.
|
||||
bytes_defined = max_subcomputation_bytes;
|
||||
} else {
|
||||
bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
|
||||
}
|
||||
return freed_bytes - bytes_defined;
|
||||
}
|
||||
|
||||
// Constructs the scheduling priority of the given instruction.
|
||||
|
@ -1304,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) {
|
||||
|
||||
ENTRY %SelectScalarS32True.v4 () -> s32[] {
|
||||
%constant.2 = pred[] constant(true)
|
||||
%constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
|
||||
%constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
|
||||
%constant = s32[] constant(42)
|
||||
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/overflow_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
|
||||
<< "Maximal sharding is expected to have single device assignment, but "
|
||||
<< proto.tile_assignment_devices().size() << " has provided.";
|
||||
|
||||
TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
|
||||
TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
|
||||
|
||||
// RE: the product of tile assignment tensor dimensions must be
|
||||
// equal to tile_assignment_devices.size().
|
||||
int64 product_of_dimensions = 1;
|
||||
for (auto dimension : proto.tile_assignment_dimensions()) {
|
||||
TF_RET_CHECK(dimension > 0);
|
||||
product_of_dimensions =
|
||||
MultiplyWithoutOverflow(product_of_dimensions, dimension);
|
||||
TF_RET_CHECK(product_of_dimensions > 0);
|
||||
}
|
||||
TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
|
||||
|
||||
// Some versions of gcc cannot infer the TileAssignment constructor from a
|
||||
// braced initializer-list, so create one manually.
|
||||
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
|
||||
|
@ -45,8 +45,8 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:inliner",
|
||||
"//tensorflow/compiler/xla/service:layout_assignment",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
"//tensorflow/compiler/xla/service:while_loop_simplifier",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -28,9 +28,9 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -32,10 +32,10 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// InlinerVisitor traverses the HLO computation and inlines maps.
|
||||
class InlinerVisitor : public DfsHloVisitorWithDefault {
|
||||
// MapInlinerVisitor traverses the HLO computation and inlines maps.
|
||||
class MapInlinerVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
explicit InlinerVisitor(HloComputation* computation)
|
||||
explicit MapInlinerVisitor(HloComputation* computation)
|
||||
: computation_(computation) {}
|
||||
|
||||
// Default visitor action is to do nothing and return OK.
|
||||
@ -49,24 +49,23 @@ class InlinerVisitor : public DfsHloVisitorWithDefault {
|
||||
StatusOr<bool> Run(HloComputation* computation);
|
||||
|
||||
private:
|
||||
// Current HloComputation instance the InlinerVisitor is traversing.
|
||||
// Current HloComputation instance the MapInlinerVisitor is traversing.
|
||||
HloComputation* computation_;
|
||||
|
||||
// Whether algebraic simplification has occurred.
|
||||
bool changed_ = false;
|
||||
};
|
||||
|
||||
StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
|
||||
StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
|
||||
changed_ = false;
|
||||
computation_ = computation;
|
||||
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
|
||||
return changed_;
|
||||
}
|
||||
|
||||
Status InlinerVisitor::HandleMap(HloInstruction* map) {
|
||||
Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
|
||||
HloComputation* function = map->to_apply();
|
||||
HloInstruction& root = *function->root_instruction();
|
||||
// TODO(b/29249531): Add DCE pass to remove unused HloComputations.
|
||||
// Only inlining functions that are simply a single operation until a better
|
||||
// profitability model for inlining is defined.
|
||||
if (hlo_query::AllOperandsAreParameters(root)) {
|
||||
@ -112,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> Inliner::Run(HloModule* module) {
|
||||
InlinerVisitor visitor(/*computation=*/nullptr);
|
||||
StatusOr<bool> MapInliner::Run(HloModule* module) {
|
||||
MapInlinerVisitor visitor(/*computation=*/nullptr);
|
||||
bool changed = false;
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
|
@ -13,27 +13,27 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A pass which performs inlining. Which can result, for example, in functions
|
||||
// that were previously being mapped by Map instead directly applied to the
|
||||
// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
|
||||
class Inliner : public HloModulePass {
|
||||
// A pass which performs map inlining. This replaces kMap instructions with
|
||||
// their equivalent sequence of array operations. For example:
|
||||
// map({X, Y}, add) -> add(X, Y).
|
||||
class MapInliner : public HloModulePass {
|
||||
public:
|
||||
~Inliner() override = default;
|
||||
absl::string_view name() const override { return "inline"; }
|
||||
~MapInliner() override = default;
|
||||
absl::string_view name() const override { return "map-inline"; }
|
||||
|
||||
// Run inlining on the given computation. Returns whether the computation was
|
||||
// changed.
|
||||
// Run map inlining on the given computation. Returns whether the computation
|
||||
// was changed.
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers;
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using InlinerTest = HloVerifiedTestBase;
|
||||
using MapInlinerTest = HloVerifiedTestBase;
|
||||
|
||||
// Test that `map` with `max` is transformed to `max`
|
||||
TEST_F(InlinerTest, MapMax) {
|
||||
TEST_F(MapInlinerTest, MapMax) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
auto max_builder = HloComputation::Builder(TestName());
|
||||
@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(max_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
|
||||
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
|
||||
op::Maximum(lhs, rhs));
|
||||
@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) {
|
||||
}
|
||||
|
||||
// Test that `constant` function is changed to `broadcast`.
|
||||
TEST_F(InlinerTest, MapConstant) {
|
||||
TEST_F(MapInlinerTest, MapConstant) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
auto const2_builder = HloComputation::Builder(TestName());
|
||||
@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(const2_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
|
||||
root = hlo_module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Broadcast(op::Constant()));
|
||||
@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
|
||||
}
|
||||
|
||||
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||
TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
// Note that the parameter ordinals are in the opposite order to their
|
||||
@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(max_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
|
||||
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
|
||||
op::Subtract(rhs, lhs));
|
||||
@ -146,7 +146,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
|
||||
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
|
||||
}
|
||||
|
||||
TEST_F(InlinerTest, MapParameter) {
|
||||
TEST_F(MapInlinerTest, MapParameter) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
auto param_builder = HloComputation::Builder(TestName());
|
||||
@ -167,7 +167,7 @@ TEST_F(InlinerTest, MapParameter) {
|
||||
hlo_module->AddEmbeddedComputation(std::move(param_f32));
|
||||
hlo_module->AddEntryComputation(std::move(computation));
|
||||
|
||||
Inliner inliner;
|
||||
MapInliner inliner;
|
||||
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
|
||||
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
|
||||
|
@ -869,11 +869,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (Rank(shape) != shape.dimensions_size()) {
|
||||
return InvalidArgument(
|
||||
"shape's rank is mismatched with dimension count; rank=%d "
|
||||
"dimensions_size=%d",
|
||||
Rank(shape), shape.dimensions_size());
|
||||
if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) {
|
||||
return InvalidArgument("sparse arrays must have rank > 0");
|
||||
}
|
||||
for (int64 i = 0; i < Rank(shape); ++i) {
|
||||
int64 dimension = shape.dimensions(i);
|
||||
|
@ -4,6 +4,7 @@ package(default_visibility = [
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
|
||||
|
||||
exports_files(glob([
|
||||
@ -165,10 +166,6 @@ cc_library(
|
||||
"stderr_reporter.h",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
defines = select({
|
||||
":with_tflite_flex": ["TFLITE_FLEX"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkopts = [
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
@ -276,6 +273,7 @@ cc_test(
|
||||
"testdata/0_subgraphs.bin",
|
||||
"testdata/2_subgraphs.bin",
|
||||
"testdata/empty_model.bin",
|
||||
"testdata/multi_add_flex.bin",
|
||||
"testdata/test_model.bin",
|
||||
"testdata/test_model_broken.bin",
|
||||
],
|
||||
@ -283,6 +281,26 @@ cc_test(
|
||||
":framework",
|
||||
"//tensorflow/contrib/lite/c:c_api_internal",
|
||||
"//tensorflow/contrib/lite/core/api",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
# Test model framework with the flex library linked into the target.
|
||||
tf_cc_test(
|
||||
name = "model_flex_test",
|
||||
size = "small",
|
||||
srcs = ["model_flex_test.cc"],
|
||||
data = [
|
||||
"testdata/multi_add_flex.bin",
|
||||
],
|
||||
tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC.
|
||||
deps = [
|
||||
":framework",
|
||||
"//tensorflow/contrib/lite/core/api",
|
||||
"//tensorflow/contrib/lite/delegates/flex:delegate",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
|
@ -2,7 +2,7 @@
|
||||
# This is a TF Lite delegate that is powered by TensorFlow's Eager.
|
||||
#
|
||||
package(default_visibility = [
|
||||
"//visibility:public",
|
||||
"//visibility:private",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
@ -50,6 +50,7 @@ cc_library(
|
||||
hdrs = [
|
||||
"delegate.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":buffer_map",
|
||||
":delegate_data",
|
||||
@ -66,6 +67,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
|
||||
} // namespace delegate
|
||||
} // namespace flex
|
||||
|
||||
// Corresponding weak declaration found in lite/model.cc.
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
|
||||
AcquireFlexDelegate() {
|
||||
return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
|
||||
tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
|
||||
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
|
||||
std::unique_ptr<flex::DelegateData> delegate_data;
|
||||
if (!flex::DelegateData::Create(&delegate_data).ok()) {
|
||||
|
@ -349,6 +349,10 @@ class Interpreter {
|
||||
return context_.allow_fp32_relax_to_fp16;
|
||||
}
|
||||
|
||||
// Owning handle to a TfLiteDelegate instance.
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
|
||||
|
||||
// Allow a delegate to look at the graph and modify the graph to handle
|
||||
// parts of the graph themselves. After this is called, the graph may
|
||||
// contain new nodes that replace 1 or more nodes.
|
||||
@ -574,19 +578,11 @@ class Interpreter {
|
||||
TfLiteExternalContextType type,
|
||||
TfLiteExternalContext* ctx);
|
||||
|
||||
using TfLiteDelegatePtr =
|
||||
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
|
||||
|
||||
// Variant of the public ModifyGraphWithDelegate method that additionally
|
||||
// assumes ownership of the provided delegate.
|
||||
// WARNING: This is an experimental API and subject to change.
|
||||
template <typename Delegate>
|
||||
TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate,
|
||||
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate,
|
||||
bool allow_dynamic_tensors = false) {
|
||||
TfLiteDelegatePtr delegate(typed_delegate.release(),
|
||||
[](TfLiteDelegate* delegate) {
|
||||
delete static_cast<Delegate*>(delegate);
|
||||
});
|
||||
// Note that we retain ownership of the delegate even if graph modification
|
||||
// fails, as delegate use will be in an indeterminate state at that point.
|
||||
owned_delegates_.push_back(std::move(delegate));
|
||||
@ -676,6 +672,7 @@ class Interpreter {
|
||||
// List of delegates that have been installed and are owned by this
|
||||
// interpreter instance. Useful if client delegate ownership is burdensome.
|
||||
// WARNING: This is an experimental API and subject to change.
|
||||
// TODO(b/116667551): Use TfLiteExternalContext for storing state.
|
||||
std::vector<TfLiteDelegatePtr> owned_delegates_;
|
||||
|
||||
std::unique_ptr<MemoryPlanner> memory_planner_;
|
||||
|
@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test {
|
||||
template <typename Delegate>
|
||||
static TfLiteStatus ModifyGraphWithDelegate(
|
||||
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
|
||||
return interpreter->ModifyGraphWithDelegate(std::move(delegate));
|
||||
Interpreter::TfLiteDelegatePtr tflite_delegate(
|
||||
delegate.release(), [](TfLiteDelegate* delegate) {
|
||||
delete reinterpret_cast<Delegate*>(delegate);
|
||||
});
|
||||
return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -113,6 +113,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// input configuration.
|
||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
|
||||
const int batch_size = input->dims->data[0];
|
||||
const int max_time = input->dims->data[1];
|
||||
const int fw_num_units = fw_input_weights->dims->data[0];
|
||||
|
@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (input1->type == kTfLiteUInt8) { \
|
||||
auto input1_offset = -input1->params.zero_point; \
|
||||
auto input2_offset = -input2->params.zero_point; \
|
||||
const int left_shift = 20; \
|
||||
const double twice_max_input_scale = \
|
||||
2 * std::max(input1->params.scale, input2->params.scale); \
|
||||
const double real_input1_multiplier = \
|
||||
input1->params.scale / twice_max_input_scale; \
|
||||
const double real_input2_multiplier = \
|
||||
input2->params.scale / twice_max_input_scale; \
|
||||
const int left_shift = 8; \
|
||||
\
|
||||
int32 input1_multiplier; \
|
||||
int input1_shift; \
|
||||
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
|
||||
QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
|
||||
&input1_multiplier, &input1_shift); \
|
||||
int32 input2_multiplier; \
|
||||
int input2_shift; \
|
||||
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
|
||||
QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
|
||||
&input2_multiplier, &input2_shift); \
|
||||
\
|
||||
ComparisonParams op_params; \
|
||||
op_params.left_shift = left_shift; \
|
||||
op_params.input1_offset = input1_offset; \
|
||||
op_params.input1_multiplier = input1_multiplier; \
|
||||
op_params.input1_shift = -input1_shift; \
|
||||
op_params.input1_shift = input1_shift; \
|
||||
op_params.input2_offset = input2_offset; \
|
||||
op_params.input2_multiplier = input2_multiplier; \
|
||||
op_params.input2_shift = -input2_shift; \
|
||||
op_params.input2_shift = input2_shift; \
|
||||
if (requires_broadcast) { \
|
||||
reference_ops::Broadcast4DSlow##opname##WithScaling( \
|
||||
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
|
||||
|
@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
|
||||
ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
|
||||
{TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
|
||||
TensorType_UINT8, BuiltinOperator_GREATER);
|
||||
model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
|
||||
model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
|
||||
}
|
||||
|
||||
TEST(ComparisonsTest, GreaterEqualQuantized) {
|
||||
const float kMin = -1.f;
|
||||
const float kMax = 128.f;
|
||||
|
@ -27,9 +27,6 @@ limitations under the License.
|
||||
#ifndef TFLITE_MCU
|
||||
#include "tensorflow/contrib/lite/nnapi_delegate.h"
|
||||
#endif
|
||||
#if defined(TFLITE_FLEX)
|
||||
#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
|
||||
#endif
|
||||
#include "tensorflow/contrib/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -43,6 +40,25 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
|
||||
|
||||
const char* kEmptyTensorName = "";
|
||||
|
||||
// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
|
||||
// we avoid the absl dependency for binary size reasons.
|
||||
#ifdef __has_attribute
|
||||
#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
|
||||
#else
|
||||
#define TFLITE_HAS_ATTRIBUTE(x) 0
|
||||
#endif
|
||||
|
||||
#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
|
||||
// Using weak symbols for the flex delegate allows automatic injection of the
|
||||
// delegate simply by adding it as a dependency. See also the strong override in
|
||||
// lite/delegates/flex/delegate.cc.
|
||||
__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
|
||||
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
|
||||
}
|
||||
#else
|
||||
Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
|
||||
#endif
|
||||
|
||||
#ifndef TFLITE_MCU
|
||||
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
|
||||
// otherwise make a copy of the model in a buffer.
|
||||
@ -450,13 +466,14 @@ TfLiteStatus InterpreterBuilder::operator()(
|
||||
}
|
||||
(**interpreter).SetVariables(std::move(variables));
|
||||
|
||||
#if defined(TFLITE_FLEX)
|
||||
if (auto delegate = FlexDelegate::Create()) {
|
||||
(**interpreter)
|
||||
.ModifyGraphWithDelegate(std::move(delegate),
|
||||
/*allow_dynamic_tensors=*/true);
|
||||
// TODO(b/116667551): Only create the flex delegate if the model has flex ops.
|
||||
if (AcquireFlexDelegate != nullptr) {
|
||||
if (auto flex_delegate = AcquireFlexDelegate()) {
|
||||
(**interpreter)
|
||||
.ModifyGraphWithDelegate(std::move(flex_delegate),
|
||||
/*allow_dynamic_tensors=*/true);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
45
tensorflow/contrib/lite/model_flex_test.cc
Normal file
45
tensorflow/contrib/lite/model_flex_test.cc
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// Ensures that a model with TensorFlow ops can be imported as long as the
|
||||
// appropriate delegate is linked into the client.
|
||||
TEST(FlexModel, WithFlexDelegate) {
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/contrib/lite/testdata/multi_add_flex.bin");
|
||||
ASSERT_TRUE(model);
|
||||
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(InterpreterBuilder(*model,
|
||||
ops::builtin::BuiltinOpResolver{})(&interpreter),
|
||||
kTfLiteOk);
|
||||
ASSERT_TRUE(interpreter);
|
||||
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/contrib/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/testing/util.h"
|
||||
|
||||
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
|
||||
@ -193,6 +194,27 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that loading a model with TensorFlow ops fails when the flex delegate is
|
||||
// not linked into the target.
|
||||
TEST(FlexModel, FailureWithoutFlexDelegate) {
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/contrib/lite/testdata/multi_add_flex.bin");
|
||||
ASSERT_TRUE(model);
|
||||
|
||||
// Note that creation will succeed when using the BuiltinOpResolver, but
|
||||
// unless the appropriate delegate is linked into the target or the client
|
||||
// explicitly installs the delegate, execution will fail.
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
ASSERT_EQ(InterpreterBuilder(*model,
|
||||
ops::builtin::BuiltinOpResolver{})(&interpreter),
|
||||
kTfLiteOk);
|
||||
ASSERT_TRUE(interpreter);
|
||||
|
||||
// As the flex ops weren't resolved implicitly by the flex delegate, runtime
|
||||
// allocation and execution will fail.
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
|
||||
}
|
||||
|
||||
// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
|
||||
// buffer. But the buffer is provided to be only 1 element.
|
||||
TEST(BasicFlatBufferModel, TestBrokenMmap) {
|
||||
|
BIN
tensorflow/contrib/lite/testdata/multi_add_flex.bin
vendored
Normal file
BIN
tensorflow/contrib/lite/testdata/multi_add_flex.bin
vendored
Normal file
Binary file not shown.
@ -40,7 +40,7 @@ cc_binary(
|
||||
srcs = [
|
||||
"benchmark_main.cc",
|
||||
],
|
||||
copts = common_copts + ["-DTFLITE_FLEX"],
|
||||
copts = common_copts,
|
||||
linkopts = tflite_linkopts() + select({
|
||||
"//tensorflow:android": [
|
||||
"-pie", # Android 5.0 and later supports only PIE
|
||||
@ -49,8 +49,9 @@ cc_binary(
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":benchmark_tflite_model_plus_flex_lib",
|
||||
":benchmark_tflite_model_lib",
|
||||
":logging",
|
||||
"//tensorflow/contrib/lite/delegates/flex:delegate",
|
||||
],
|
||||
)
|
||||
|
||||
@ -110,25 +111,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_tflite_model_plus_flex_lib",
|
||||
srcs = [
|
||||
"benchmark_tflite_model.cc",
|
||||
"logging.h",
|
||||
],
|
||||
hdrs = ["benchmark_tflite_model.h"],
|
||||
copts = common_copts + ["-DTFLITE_FLEX"],
|
||||
deps = [
|
||||
":benchmark_model_lib",
|
||||
":logging",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/delegates/flex:delegate",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/profiling:profile_summarizer",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_params",
|
||||
srcs = [
|
||||
|
@ -23,9 +23,6 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#ifdef TFLITE_FLEX
|
||||
#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
|
||||
#endif // TFLITE_FLEX
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/op_resolver.h"
|
||||
@ -305,15 +302,6 @@ void BenchmarkTfLiteModel::Init() {
|
||||
|
||||
interpreter->UseNNAPI(use_nnapi);
|
||||
|
||||
#ifdef TFLITE_FLEX
|
||||
TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
|
||||
delegate_ = FlexDelegate::Create();
|
||||
if (delegate_) {
|
||||
interpreter->ModifyGraphWithDelegate(delegate_.get(),
|
||||
/*allow_dynamic_tensors=*/true);
|
||||
}
|
||||
#endif // TFLITE_FLEX
|
||||
|
||||
auto interpreter_inputs = interpreter->inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
|
@ -20,9 +20,6 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifdef TFLITE_FLEX
|
||||
#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
|
||||
#endif // TFLITE_FLEX
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
|
||||
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
|
||||
@ -73,9 +70,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
|
||||
void PrepareInputsAndOutputs() override;
|
||||
|
||||
private:
|
||||
#ifdef TFLITE_FLEX
|
||||
std::unique_ptr<FlexDelegate> delegate_;
|
||||
#endif // TFLITE_FLEX
|
||||
std::unique_ptr<tflite::FlatBufferModel> model;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
std::vector<InputLayerInfo> inputs;
|
||||
|
@ -4,22 +4,33 @@ op {
|
||||
in_arg {
|
||||
name: "arguments"
|
||||
description: <<END
|
||||
A list of tensors whose types are Targuments, corresponding to the inputs the
|
||||
function should be mapped over.
|
||||
A list of tensors whose types are `Targuments`, corresponding to the inputs
|
||||
the function should be mapped over.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "captured_inputs"
|
||||
description: <<END
|
||||
A list of tensors whose types are `Tcaptured`, corresponding to the captured
|
||||
inputs of the defun.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
A list of output tensors whose types are output_types and whose dimensions 0
|
||||
are the same as the dimensions 0 of the tensors in arguments, and whose
|
||||
remaining dimensions correspond to those in output_shapes.
|
||||
A list of output tensors whose types are `output_types` and whose dimensions
|
||||
0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
|
||||
remaining dimensions correspond to those in `output_shapes`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
description: "A list of types."
|
||||
}
|
||||
attr {
|
||||
name: "Tcaptured"
|
||||
description: "A list of types."
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
description: "A list of types."
|
||||
@ -29,6 +40,6 @@ END
|
||||
description: "A list of shapes."
|
||||
}
|
||||
summary: <<END
|
||||
Maps a function on the list of tensors unpacked from inputs on dimension 0.
|
||||
Maps a function on the list of tensors unpacked from arguments on dimension 0.
|
||||
END
|
||||
}
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListPushBackBatch"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "EmptyTensorList"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListConcatLists"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListElementShape"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListFromTensor"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListGather"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListGetItem"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListLength"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListPopBack"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListPushBack"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListReserve"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListScatter"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListSetItem"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "TensorListStack"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
|
||||
map_defun_node->add_input(input.name());
|
||||
}
|
||||
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
|
||||
AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
|
||||
|
||||
// Set return values to match output names
|
||||
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
|
||||
|
@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
|
||||
func.set_name(function_name);
|
||||
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
|
||||
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
|
||||
graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
|
||||
graph_transforms::SetNodeAttr("output_types", output_types, node);
|
||||
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
|
||||
graph_transforms::SetNodeAttr("f", func, node);
|
||||
@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
|
||||
*lib.add_function() = outer;
|
||||
*lib.add_function() = inner;
|
||||
FunctionDef* vectorized;
|
||||
Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
|
||||
LOG(ERROR) << s;
|
||||
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
|
||||
EXPECT_TRUE(
|
||||
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
|
||||
|
@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
|
||||
~MapDefunOp() override {}
|
||||
|
||||
Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
|
||||
// Validates inputs and gets the size of their leading dimension.
|
||||
*batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
|
||||
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
|
||||
if (ctx->input(i).dims() == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have rank at least 1. Input ", i,
|
||||
" has a rank of 0.");
|
||||
} else if (ctx->input(i).dim_size(0) != *batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have the same dimension 0. Input ", i,
|
||||
" has leading dimension ", ctx->input(i).dim_size(0),
|
||||
", while all previous inputs have leading dimension ", batch_size);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
ComputeOptions* compute_opts = nullptr;
|
||||
|
||||
@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
// all calls to the function are complete. This struct also encapsulates
|
||||
// all the components that need to be passed to each MapFunctionCallFrame.
|
||||
|
||||
const std::vector<Tensor> args;
|
||||
OpInputList args;
|
||||
const std::vector<TensorShape> arg_shapes;
|
||||
OpInputList captured_inputs;
|
||||
const int64 batch_size;
|
||||
|
||||
// Output of a compute call
|
||||
@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
|
||||
// Create a copy of output_shapes because every `Compute` may expect a
|
||||
// different output shape.
|
||||
ComputeOptions(std::vector<Tensor> args,
|
||||
ComputeOptions(OpInputList args, OpInputList captured_inputs,
|
||||
std::vector<TensorShape> arg_shapes, int64 batch_size,
|
||||
const std::vector<PartialTensorShape>& output_shapes_attr)
|
||||
: args(std::move(args)),
|
||||
: args(args),
|
||||
arg_shapes(std::move(arg_shapes)),
|
||||
captured_inputs(captured_inputs),
|
||||
batch_size(batch_size),
|
||||
output_shapes(output_shapes_attr) {}
|
||||
};
|
||||
|
||||
// Get inputs to Compute and check that they are valid.
|
||||
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
|
||||
int64 batch_size =
|
||||
ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
|
||||
OpInputList arguments;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
|
||||
OpInputList captured_inputs;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
|
||||
|
||||
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
|
||||
if (ctx->input(i).dims() == 0) {
|
||||
int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
if (arguments[i].dims() == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have rank at least 1. Input ", i,
|
||||
" has a rank of 0.");
|
||||
} else if (ctx->input(i).dim_size(0) != batch_size) {
|
||||
} else if (arguments[i].dim_size(0) != batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have the same dimension 0. Input ", i,
|
||||
" has leading dimension ", ctx->input(i).dim_size(0),
|
||||
@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Tensor> args;
|
||||
std::vector<TensorShape> arg_shapes;
|
||||
args.reserve(ctx->num_inputs());
|
||||
arg_shapes.reserve(ctx->num_inputs());
|
||||
arg_shapes.reserve(arguments.size());
|
||||
|
||||
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
|
||||
args.push_back(ctx->input(i));
|
||||
arg_shapes.push_back(ctx->input(i).shape());
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
arg_shapes.push_back(arguments[i].shape());
|
||||
arg_shapes.at(i).RemoveDim(0);
|
||||
}
|
||||
|
||||
*compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
|
||||
batch_size, output_shapes_);
|
||||
*compute_opts =
|
||||
new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
|
||||
batch_size, output_shapes_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
}
|
||||
|
||||
Status GetArg(int index, Tensor* val) const override {
|
||||
if (index < 0 || index >= compute_opts_->args.size()) {
|
||||
if (index < 0 || index >= compute_opts_->args.size() +
|
||||
compute_opts_->captured_inputs.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in number of function inputs.");
|
||||
}
|
||||
|
||||
if (index >= compute_opts_->args.size()) {
|
||||
// The function is calling for a captured input
|
||||
*val =
|
||||
compute_opts_->captured_inputs[index - compute_opts_->args.size()];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool result =
|
||||
val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
|
||||
val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
|
||||
compute_opts_->arg_shapes.at(index));
|
||||
if (!result) {
|
||||
return errors::Internal("GetArg failed.");
|
||||
@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
// Ensure alignment
|
||||
*val = tensor::DeepCopy(*val);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -97,6 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
|
||||
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
|
||||
errors::Internal("Could not find handle ", handle),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, args.size() == fbody->arg_nodes.size(),
|
||||
errors::InvalidArgument(
|
||||
"Wrong number of arguments to the op; function expects ",
|
||||
fbody->arg_nodes.size(), " but PartitionedCall received ",
|
||||
args.size()),
|
||||
done);
|
||||
// We need to pass global op_registry as default_registry when creating
|
||||
// graph. So that graph optimization passes can lookup all possible ops
|
||||
// by name.
|
||||
|
@ -30566,6 +30566,52 @@ op {
|
||||
type: "func"
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "MapDefun"
|
||||
input_arg {
|
||||
name: "arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "captured_inputs"
|
||||
type_list_attr: "Tcaptured"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_list_attr: "output_types"
|
||||
}
|
||||
attr {
|
||||
name: "Targuments"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "Tcaptured"
|
||||
type: "list(type)"
|
||||
default_value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
type: "func"
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "MapIncompleteSize"
|
||||
output_arg {
|
||||
@ -71843,6 +71889,48 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Substr"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type: DT_STRING
|
||||
}
|
||||
input_arg {
|
||||
name: "pos"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "len"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type: DT_STRING
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "unit"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "BYTE"
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
s: "BYTE"
|
||||
s: "UTF8_CHAR"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Sum"
|
||||
input_arg {
|
||||
|
@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset")
|
||||
|
||||
REGISTER_OP("MapDefun")
|
||||
.Input("arguments: Targuments")
|
||||
.Input("captured_inputs: Tcaptured")
|
||||
.Output("output: output_types")
|
||||
.Attr("Targuments: list(type) >= 1")
|
||||
.Attr("Tcaptured: list(type) >= 0 = []")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("f: func")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
|
||||
DataTypeVector t_args;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
|
||||
if (output_shapes.size() != c->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"`output_shapes` must be the same length as `output_types` (",
|
||||
@ -918,10 +922,11 @@ REGISTER_OP("MapDefun")
|
||||
}
|
||||
|
||||
int64 dim_zero = -1;
|
||||
for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
|
||||
for (size_t i = 0; i < t_args.size(); ++i) {
|
||||
if (c->Rank(c->input(i)) == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Inputs must have rank at least 1. Input ", i, " has rank of 0");
|
||||
"Arguments must have rank at least 1. Input ", i,
|
||||
" has rank of 0.");
|
||||
}
|
||||
auto dim_handle = c->Dim(c->input(i), 0);
|
||||
if (c->ValueKnown(dim_handle)) {
|
||||
@ -929,7 +934,7 @@ REGISTER_OP("MapDefun")
|
||||
dim_zero = c->Value(dim_handle);
|
||||
} else if (c->Value(dim_handle) != dim_zero) {
|
||||
return errors::InvalidArgument(
|
||||
"Inputs must have the same dimension 0.");
|
||||
"Arguments must have the same dimension 0.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15262,6 +15262,10 @@ op {
|
||||
name: "arguments"
|
||||
type_list_attr: "Targuments"
|
||||
}
|
||||
input_arg {
|
||||
name: "captured_inputs"
|
||||
type_list_attr: "Tcaptured"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_list_attr: "output_types"
|
||||
@ -15272,6 +15276,15 @@ op {
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "Tcaptured"
|
||||
type: "list(type)"
|
||||
default_value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
has_minimum: true
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
@ -33748,6 +33761,19 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "unit"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "BYTE"
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
s: "BYTE"
|
||||
s: "UTF8_CHAR"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Sum"
|
||||
|
25
tensorflow/python/data/experimental/benchmarks/BUILD
Normal file
25
tensorflow/python/data/experimental/benchmarks/BUILD
Normal file
@ -0,0 +1,25 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_test(
|
||||
name = "map_benchmark",
|
||||
size = "medium",
|
||||
srcs = ["map_benchmark.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
@ -19,7 +19,6 @@ from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
@ -27,128 +26,15 @@ import numpy as np
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import error_ops
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_NUMPY_RANDOM_SEED = 42
|
||||
|
||||
|
||||
class MapDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
def testMapIgnoreError(self):
|
||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||
|
||||
dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.check_numerics(x, "message")).apply(
|
||||
error_ops.ignore_errors()))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testParallelMapIgnoreError(self):
|
||||
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
|
||||
|
||||
dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components).map(
|
||||
lambda x: array_ops.check_numerics(x, "message"),
|
||||
num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for x in [1., 2., 3., 5.]:
|
||||
self.assertEqual(x, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testReadFileIgnoreError(self):
|
||||
|
||||
def write_string_to_file(value, filename):
|
||||
with open(filename, "w") as f:
|
||||
f.write(value)
|
||||
|
||||
filenames = [
|
||||
os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
|
||||
]
|
||||
for filename in filenames:
|
||||
write_string_to_file(filename, filename)
|
||||
|
||||
dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(filenames).map(
|
||||
io_ops.read_file,
|
||||
num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# All of the files are present.
|
||||
sess.run(init_op)
|
||||
for filename in filenames:
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Delete one of the files.
|
||||
os.remove(filenames[0])
|
||||
|
||||
# Attempting to read filenames[0] will fail, but ignore_errors()
|
||||
# will catch the error.
|
||||
sess.run(init_op)
|
||||
for filename in filenames[1:]:
|
||||
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testCaptureResourceInMapFn(self):
|
||||
|
||||
def _build_ds(iterator):
|
||||
|
||||
def _map_fn(x):
|
||||
get_next = iterator.get_next()
|
||||
return x * get_next
|
||||
|
||||
return dataset_ops.Dataset.range(10).map(_map_fn)
|
||||
|
||||
def _build_graph():
|
||||
captured_iterator = dataset_ops.Dataset.range(
|
||||
10).make_initializable_iterator()
|
||||
ds = _build_ds(captured_iterator)
|
||||
iterator = ds.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
return captured_iterator.initializer, init_op, get_next
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
captured_init_op, init_op, get_next = _build_graph()
|
||||
with self.session(graph=g) as sess:
|
||||
sess.run(captured_init_op)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEquals(i * i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
class MapDatasetBenchmark(test.Benchmark):
|
||||
|
||||
# The purpose of this benchmark is to compare the performance of chaining vs
|
@ -8,57 +8,16 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_test(
|
||||
name = "batch_dataset_op_test",
|
||||
name = "bucket_by_sequence_length_test",
|
||||
size = "medium",
|
||||
srcs = ["batch_dataset_op_test.py"],
|
||||
srcs = ["bucket_by_sequence_length_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss", # (b/79552534)
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bucketing_test",
|
||||
size = "medium",
|
||||
srcs = ["bucketing_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/experimental/ops:grouping",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
@ -67,16 +26,44 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "csv_dataset_op_test",
|
||||
size = "medium",
|
||||
srcs = ["csv_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
cuda_py_test(
|
||||
name = "copy_to_device_test",
|
||||
size = "small",
|
||||
srcs = ["copy_to_device_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python/compat:compat",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
],
|
||||
tags = ["no_windows_gpu"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "counter_test",
|
||||
size = "small",
|
||||
srcs = ["counter_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python/data/experimental/ops:counter",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "csv_dataset_test",
|
||||
size = "medium",
|
||||
srcs = ["csv_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
@ -97,25 +84,18 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_constructor_op_test",
|
||||
size = "medium",
|
||||
srcs = ["dataset_constructor_op_test.py"],
|
||||
name = "dense_to_sparse_batch_test",
|
||||
srcs = ["dense_to_sparse_batch_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual",
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"nomac", # b/62040583
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -124,11 +104,6 @@ py_test(
|
||||
size = "medium",
|
||||
srcs = ["directed_interleave_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
@ -140,15 +115,68 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "enumerate_dataset_test",
|
||||
size = "small",
|
||||
srcs = ["enumerate_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/experimental/ops:enumerate_ops",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "filter_dataset_op_test",
|
||||
size = "medium",
|
||||
srcs = ["filter_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "function_buffering_resource_test",
|
||||
size = "small",
|
||||
srcs = ["function_buffering_resource_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
],
|
||||
tags = ["no_windows_gpu"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "get_single_element_test",
|
||||
size = "small",
|
||||
srcs = ["get_single_element_test.py"],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -164,14 +192,69 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "group_by_reducer_test",
|
||||
size = "medium",
|
||||
srcs = ["group_by_reducer_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/experimental/ops:grouping",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "group_by_window_test",
|
||||
size = "medium",
|
||||
srcs = ["group_by_window_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/experimental/ops:grouping",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ignore_errors_test",
|
||||
srcs = ["ignore_errors_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:error_ops",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "indexed_dataset_ops_test",
|
||||
srcs = ["indexed_dataset_ops_test.py"],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -185,14 +268,134 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "interleave_dataset_op_test",
|
||||
name = "make_batched_features_dataset_test",
|
||||
size = "medium",
|
||||
srcs = ["interleave_dataset_op_test.py"],
|
||||
srcs = ["make_batched_features_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":reader_dataset_ops_test_base",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python/data/experimental/ops:readers",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "make_csv_dataset_test",
|
||||
size = "medium",
|
||||
srcs = ["make_csv_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python/data/experimental/ops:readers",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "make_tf_record_dataset_test",
|
||||
size = "medium",
|
||||
srcs = ["make_tf_record_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":reader_dataset_ops_test_base",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/data/experimental/ops:readers",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "map_and_batch_test",
|
||||
size = "medium",
|
||||
srcs = ["map_and_batch_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "map_defun_op_test",
|
||||
size = "small",
|
||||
srcs = ["map_defun_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python/data/experimental/ops:map_defun",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "override_threadpool_test",
|
||||
size = "small",
|
||||
srcs = ["override_threadpool_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python/data/experimental/ops:threadpool",
|
||||
"//tensorflow/python/data/experimental/ops:unique",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "parallel_interleave_test",
|
||||
size = "medium",
|
||||
srcs = ["parallel_interleave_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
@ -212,120 +415,10 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "iterator_ops_test",
|
||||
name = "parse_example_dataset_test",
|
||||
size = "small",
|
||||
srcs = ["iterator_ops_test.py"],
|
||||
srcs = ["parse_example_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/experimental/ops:iterator_ops",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "map_dataset_op_test",
|
||||
size = "medium",
|
||||
srcs = ["map_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"noasan", # times out
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/experimental/ops:error_ops",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "filter_dataset_op_test",
|
||||
size = "medium",
|
||||
srcs = ["filter_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:io_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "map_defun_op_test",
|
||||
size = "small",
|
||||
srcs = ["map_defun_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python/data/experimental/ops:map_defun",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "parsing_ops_test",
|
||||
size = "small",
|
||||
srcs = ["parsing_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -344,53 +437,20 @@ py_test(
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "prefetching_ops_test",
|
||||
name = "prefetch_to_device_test",
|
||||
size = "small",
|
||||
srcs = ["prefetching_ops_test.py"],
|
||||
srcs = ["prefetch_to_device_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:function",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python/compat:compat",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"no_windows_gpu",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "range_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["range_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/experimental/ops:counter",
|
||||
"//tensorflow/python/data/experimental/ops:enumerate_ops",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
tags = ["no_windows_gpu"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -421,41 +481,12 @@ py_library(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "reader_dataset_ops_test",
|
||||
name = "rejection_resample_test",
|
||||
size = "medium",
|
||||
srcs = ["reader_dataset_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":reader_dataset_ops_test_base",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/data/experimental/ops:readers",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "resample_test",
|
||||
size = "medium",
|
||||
srcs = ["resample_test.py"],
|
||||
srcs = ["rejection_resample_test.py"],
|
||||
shard_count = 2,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"noasan",
|
||||
"optonly",
|
||||
],
|
||||
@ -477,15 +508,27 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "scan_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["scan_dataset_op_test.py"],
|
||||
name = "restructured_dataset_test",
|
||||
size = "medium",
|
||||
srcs = ["restructured_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "scan_test",
|
||||
size = "small",
|
||||
srcs = ["scan_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -503,14 +546,12 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "shuffle_dataset_op_test",
|
||||
name = "shuffle_and_repeat_test",
|
||||
size = "medium",
|
||||
srcs = ["shuffle_dataset_op_test.py"],
|
||||
srcs = ["shuffle_and_repeat_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
"optonly",
|
||||
],
|
||||
deps = [
|
||||
@ -525,8 +566,8 @@ py_test(
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "sql_dataset_op_test_base",
|
||||
srcs = ["sql_dataset_op_test_base.py"],
|
||||
name = "sql_dataset_test_base",
|
||||
srcs = ["sql_dataset_test_base.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = [
|
||||
"//tensorflow/python/data/experimental/kernel_tests:__pkg__",
|
||||
@ -543,17 +584,13 @@ py_library(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sql_dataset_op_test",
|
||||
name = "sql_dataset_test",
|
||||
size = "small",
|
||||
srcs = ["sql_dataset_op_test.py"],
|
||||
srcs = ["sql_dataset_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":sql_dataset_op_test_base",
|
||||
":sql_dataset_test_base",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
@ -565,11 +602,7 @@ py_test(
|
||||
size = "medium",
|
||||
srcs = ["stats_dataset_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":reader_dataset_ops_test_base",
|
||||
":stats_dataset_test_base",
|
||||
@ -595,59 +628,9 @@ py_library(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "threadpool_dataset_ops_test",
|
||||
name = "tf_record_writer_test",
|
||||
size = "small",
|
||||
srcs = ["threadpool_dataset_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python/data/experimental/ops:threadpool",
|
||||
"//tensorflow/python/data/experimental/ops:unique",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "unique_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["unique_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:unique",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "writer_ops_test",
|
||||
size = "small",
|
||||
srcs = ["writer_ops_test.py"],
|
||||
tags = [
|
||||
"no_oss",
|
||||
"no_pip",
|
||||
"no_windows",
|
||||
],
|
||||
srcs = ["tf_record_writer_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -660,3 +643,45 @@ py_test(
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "unbatch_test",
|
||||
size = "medium",
|
||||
srcs = ["unbatch_test.py"],
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "unique_test",
|
||||
size = "small",
|
||||
srcs = ["unique_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/experimental/ops:unique",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
@ -1,686 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def testDenseToSparseBatchDataset(self):
|
||||
components = np.random.randint(12, size=(100,)).astype(np.int32)
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.fill([x], x)).apply(
|
||||
batching.dense_to_sparse_batch(4, [12]))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual([[i, j]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)], results.indices)
|
||||
self.assertAllEqual(
|
||||
[c for c in components[start:start + 4] for _ in range(c)],
|
||||
results.values)
|
||||
self.assertAllEqual([min(4,
|
||||
len(components) - start), 12],
|
||||
results.dense_shape)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetWithUnknownShape(self):
|
||||
components = np.random.randint(5, size=(40,)).astype(np.int32)
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.fill([x, x], x)).apply(
|
||||
batching.dense_to_sparse_batch(
|
||||
4, [5, None])).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual([[i, j, z]
|
||||
for i, c in enumerate(components[start:start + 4])
|
||||
for j in range(c)
|
||||
for z in range(c)], results.indices)
|
||||
self.assertAllEqual([
|
||||
c
|
||||
for c in components[start:start + 4] for _ in range(c)
|
||||
for _ in range(c)
|
||||
], results.values)
|
||||
self.assertAllEqual([
|
||||
min(4,
|
||||
len(components) - start), 5,
|
||||
np.max(components[start:start + 4])
|
||||
], results.dense_shape)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetWithInvalidShape(self):
|
||||
input_tensor = array_ops.constant([[1]])
|
||||
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
|
||||
dataset_ops.Dataset.from_tensors(input_tensor).apply(
|
||||
batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
|
||||
|
||||
def testDenseToSparseBatchDatasetShapeErrors(self):
|
||||
input_tensor = array_ops.placeholder(dtypes.int32)
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensors(input_tensor).apply(
|
||||
batching.dense_to_sparse_batch(4, [12]))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# Initialize with an input tensor of incompatible rank.
|
||||
sess.run(init_op, feed_dict={input_tensor: [[1]]})
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"incompatible with the row shape"):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with an input tensor that is larger than `row_shape`.
|
||||
sess.run(init_op, feed_dict={input_tensor: range(13)})
|
||||
with self.assertRaisesRegexp(errors.DataLossError,
|
||||
"larger than the row shape"):
|
||||
sess.run(get_next)
|
||||
|
||||
def testUnbatchWithUnknownRankInput(self):
|
||||
placeholder = array_ops.placeholder(dtypes.int32)
|
||||
dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
|
||||
batching.unbatch())
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_elem = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
|
||||
for i in range(4):
|
||||
self.assertEqual(i, sess.run(next_elem))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_elem)
|
||||
|
||||
def testUnbatchScalarDataset(self):
|
||||
data = tuple([math_ops.range(10) for _ in range(3)])
|
||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||
expected_types = (dtypes.int32,) * 3
|
||||
data = data.batch(2)
|
||||
self.assertEqual(expected_types, data.output_types)
|
||||
data = data.apply(batching.unbatch())
|
||||
self.assertEqual(expected_types, data.output_types)
|
||||
|
||||
iterator = data.make_one_shot_iterator()
|
||||
op = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i,) * 3, sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
|
||||
def testUnbatchDatasetWithStrings(self):
|
||||
data = tuple([math_ops.range(10) for _ in range(3)])
|
||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||
data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
|
||||
expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
|
||||
data = data.batch(2)
|
||||
self.assertEqual(expected_types, data.output_types)
|
||||
data = data.apply(batching.unbatch())
|
||||
self.assertEqual(expected_types, data.output_types)
|
||||
|
||||
iterator = data.make_one_shot_iterator()
|
||||
op = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
|
||||
def testUnbatchDatasetWithSparseTensor(self):
|
||||
st = sparse_tensor.SparseTensorValue(
|
||||
indices=[[i, i] for i in range(10)],
|
||||
values=list(range(10)),
|
||||
dense_shape=[10, 10])
|
||||
data = dataset_ops.Dataset.from_tensors(st)
|
||||
data = data.apply(batching.unbatch())
|
||||
data = data.batch(5)
|
||||
data = data.apply(batching.unbatch())
|
||||
iterator = data.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
st_row = sess.run(next_element)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testUnbatchDatasetWithDenseAndSparseTensor(self):
|
||||
st = sparse_tensor.SparseTensorValue(
|
||||
indices=[[i, i] for i in range(10)],
|
||||
values=list(range(10)),
|
||||
dense_shape=[10, 10])
|
||||
data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
|
||||
data = data.apply(batching.unbatch())
|
||||
data = data.batch(5)
|
||||
data = data.apply(batching.unbatch())
|
||||
iterator = data.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
dense_elem, st_row = sess.run(next_element)
|
||||
self.assertEqual(i, dense_elem)
|
||||
self.assertEqual([i], st_row.indices)
|
||||
self.assertEqual([i], st_row.values)
|
||||
self.assertEqual([10], st_row.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testUnbatchSingleElementTupleDataset(self):
|
||||
data = tuple([(math_ops.range(10),) for _ in range(3)])
|
||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||
expected_types = ((dtypes.int32,),) * 3
|
||||
data = data.batch(2)
|
||||
self.assertEqual(expected_types, data.output_types)
|
||||
data = data.apply(batching.unbatch())
|
||||
self.assertEqual(expected_types, data.output_types)
|
||||
|
||||
iterator = data.make_one_shot_iterator()
|
||||
op = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i,),) * 3, sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
|
||||
def testUnbatchMultiElementTupleDataset(self):
|
||||
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
|
||||
array_ops.fill([10], "hi")) for i in range(3)])
|
||||
data = dataset_ops.Dataset.from_tensor_slices(data)
|
||||
expected_types = ((dtypes.int32, dtypes.string),) * 3
|
||||
data = data.batch(2)
|
||||
self.assertAllEqual(expected_types, data.output_types)
|
||||
data = data.apply(batching.unbatch())
|
||||
self.assertAllEqual(expected_types, data.output_types)
|
||||
|
||||
iterator = data.make_one_shot_iterator()
|
||||
op = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
|
||||
sess.run(op))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(op)
|
||||
|
||||
def testUnbatchEmpty(self):
|
||||
data = dataset_ops.Dataset.from_tensors(
|
||||
(constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
|
||||
constant_op.constant([], shape=[0, 4, 0])))
|
||||
data = data.apply(batching.unbatch())
|
||||
iterator = data.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testUnbatchStaticShapeMismatch(self):
|
||||
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
|
||||
np.arange(9)))
|
||||
with self.assertRaises(ValueError):
|
||||
data.apply(batching.unbatch())
|
||||
|
||||
def testUnbatchDynamicShapeMismatch(self):
|
||||
ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
|
||||
ph2 = array_ops.placeholder(dtypes.int32, shape=None)
|
||||
data = dataset_ops.Dataset.from_tensors((ph1, ph2))
|
||||
data = data.apply(batching.unbatch())
|
||||
iterator = data.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# Mismatch in the 0th dimension.
|
||||
sess.run(
|
||||
iterator.initializer,
|
||||
feed_dict={
|
||||
ph1: np.arange(7).astype(np.int32),
|
||||
ph2: np.arange(8).astype(np.int32)
|
||||
})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(next_element)
|
||||
|
||||
# No 0th dimension (i.e. scalar value) for one component.
|
||||
sess.run(
|
||||
iterator.initializer,
|
||||
feed_dict={
|
||||
ph1: np.arange(7).astype(np.int32),
|
||||
ph2: 7
|
||||
})
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(next_element)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Default", None, None),
|
||||
("SequentialCalls", 1, None),
|
||||
("ParallelCalls", 2, None),
|
||||
("ParallelBatches", None, 10),
|
||||
)
|
||||
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
# The pipeline is TensorSliceDataset ->
|
||||
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
|
||||
components = (np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7))
|
||||
|
||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
|
||||
batching.map_and_batch(
|
||||
map_func=_map_fn,
|
||||
batch_size=batch_size,
|
||||
num_parallel_calls=num_parallel_calls,
|
||||
num_parallel_batches=num_parallel_batches))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
|
||||
[t.shape.as_list() for t in get_next])
|
||||
|
||||
with self.cached_session() as sess:
|
||||
# Batch of a finite input, where the batch_size divides the
|
||||
# total number of elements.
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Batch of a finite input, where the batch_size does not
|
||||
# divide the total number of elements.
|
||||
sess.run(init_op, feed_dict={count: 14, batch_size: 8})
|
||||
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Batch of an empty input should fail straight away.
|
||||
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Empty batch should be an initialization time error.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Even", False),
|
||||
("Uneven", True),
|
||||
)
|
||||
def testMapAndBatchPartialBatch(self, drop_remainder):
|
||||
iterator = (
|
||||
dataset_ops.Dataset.range(10).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: array_ops.reshape(x * x, [1]),
|
||||
batch_size=4,
|
||||
drop_remainder=drop_remainder)).make_one_shot_iterator())
|
||||
if drop_remainder:
|
||||
self.assertEqual([4, 1], iterator.output_shapes.as_list())
|
||||
else:
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
if not drop_remainder:
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testMapAndBatchYieldsPartialBatch(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
.apply(batching.map_and_batch(
|
||||
lambda x: array_ops.reshape(x * x, [1]), 4))
|
||||
.make_one_shot_iterator())
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
|
||||
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
|
||||
self.assertAllEqual([[64], [81]], sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testMapAndBatchParallelGetNext(self):
|
||||
iterator = (dataset_ops.Dataset.range(50000)
|
||||
.apply(batching.map_and_batch(lambda x: x, batch_size=100))
|
||||
.make_one_shot_iterator())
|
||||
elements = []
|
||||
for _ in range(100):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(5):
|
||||
got = sess.run(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
|
||||
self.assertAllEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elements)
|
||||
|
||||
def testMapAndBatchParallelGetNextDropRemainder(self):
|
||||
iterator = (
|
||||
dataset_ops.Dataset.range(49999).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: x, batch_size=100, drop_remainder=True))
|
||||
.make_one_shot_iterator())
|
||||
elements = []
|
||||
for _ in range(100):
|
||||
elements.append(iterator.get_next())
|
||||
with self.cached_session() as sess:
|
||||
for i in range(4):
|
||||
got = sess.run(elements)
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
|
||||
self.assertAllEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elements)
|
||||
|
||||
def testMapAndBatchSparse(self):
|
||||
|
||||
def _sparse(i):
|
||||
return sparse_tensor.SparseTensorValue(
|
||||
indices=[[0]], values=(i * [1]), dense_shape=[1])
|
||||
|
||||
iterator = dataset_ops.Dataset.range(10).apply(
|
||||
batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(2):
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
|
||||
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
|
||||
dense_shape=[5, 1])
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testMapAndBatchFails(self):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
array_ops.check_numerics(
|
||||
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
|
||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = (
|
||||
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
|
||||
sess.run(init_op, feed_dict={batch_size: 14})
|
||||
|
||||
def testMapAndBatchShapeMismatch(self):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
|
||||
def generator():
|
||||
yield [1]
|
||||
yield [2]
|
||||
yield [3]
|
||||
yield [[4, 5, 6]]
|
||||
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int32)
|
||||
batch_size = 4
|
||||
iterator = (
|
||||
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
|
||||
def testMapAndBatchImplicitDispose(self):
|
||||
# Tests whether a map and batch dataset will be cleaned up correctly when
|
||||
# the pipeline does not run it until exhaustion.
|
||||
# The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
|
||||
# MapAndBatchDataset(f=square_3, batch_size=100).
|
||||
components = (np.arange(1000),
|
||||
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(1000))
|
||||
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
|
||||
1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
|
||||
dataset = dataset.prefetch(5)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 0),
|
||||
("2", 5),
|
||||
("3", 10),
|
||||
("4", 90),
|
||||
("5", 95),
|
||||
("6", 99),
|
||||
)
|
||||
def testMapAndBatchOutOfRangeError(self, threshold):
|
||||
|
||||
def raising_py_fn(i):
|
||||
if i >= threshold:
|
||||
raise StopIteration()
|
||||
else:
|
||||
return i
|
||||
|
||||
iterator = (
|
||||
dataset_ops.Dataset.range(100).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
|
||||
batch_size=10)).make_one_shot_iterator())
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(threshold // 10):
|
||||
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
|
||||
if threshold % 10 != 0:
|
||||
self.assertAllEqual(
|
||||
[threshold // 10 * 10 + j for j in range(threshold % 10)],
|
||||
sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@parameterized.named_parameters(
    ("1", False, dtypes.bool),
    ("2", -42, dtypes.int8),
    ("3", -42, dtypes.int16),
    ("4", -42, dtypes.int32),
    ("5", -42, dtypes.int64),
    ("6", 42, dtypes.uint8),
    ("7", 42, dtypes.uint16),
    ("8", 42.0, dtypes.float16),
    ("9", 42.0, dtypes.float32),
    ("10", 42.0, dtypes.float64),
    ("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
  """Batches an identity-mapped `element` of the given `dtype`.

  A one-element generator is repeated 100 times and batched by 10, so ten
  batches of ten copies of `element` are expected.
  """

  def gen():
    yield element

  dataset = dataset_ops.Dataset.from_generator(gen, dtype)
  dataset = dataset.repeat(100)
  dataset = dataset.apply(
      batching.map_and_batch(lambda x: x, batch_size=10))

  get_next = dataset.make_one_shot_iterator().get_next()
  expected_batch = [element] * 10

  with self.cached_session() as sess:
    for _ in range(10):
      self.assertAllEqual(expected_batch, sess.run(get_next))
|
||||
|
||||
|
||||
class UnbatchDatasetBenchmark(test.Benchmark):
  """Wall-time benchmarks for two ways of unbatching a batched dataset."""

  def benchmarkNativeUnbatch(self):
    """Benchmarks `batching.unbatch()` over several batch sizes.

    For each batch size, five trials are run. Each trial re-initializes the
    iterator and times a single `sess.run()` on a pipeline that skips
    `elems_per_trial` unbatched elements. The median per-element time is
    printed and reported.
    """
    batch_sizes = [1, 2, 5, 10, 20, 50]
    elems_per_trial = 10000
    with ops.Graph().as_default():
      repeated = dataset_ops.Dataset.from_tensors("element").repeat(None)
      batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
      pipeline = repeated.batch(batch_size_placeholder)
      pipeline = pipeline.apply(batching.unbatch())
      pipeline = pipeline.skip(elems_per_trial)
      iterator = pipeline.make_initializable_iterator()
      next_element = iterator.get_next()

      with session.Session() as sess:
        for batch_size in batch_sizes:
          feed = {batch_size_placeholder: batch_size}
          per_element_times = []
          for _ in range(5):
            sess.run(iterator.initializer, feed_dict=feed)
            began = time.time()
            sess.run(next_element.op)
            finished = time.time()
            per_element_times.append((finished - began) / elems_per_trial)

          median_wall_time = np.median(per_element_times)
          print("Unbatch (native) batch size: %d Median wall time per element:"
                " %f microseconds" % (batch_size, median_wall_time * 1e6))
          self.report_benchmark(
              iters=10000,
              wall_time=median_wall_time,
              name="benchmark_unbatch_dataset_native_batch_size_%d" %
              batch_size)

  # Include a benchmark of the previous `unbatch()` implementation that uses
  # a composition of more primitive ops. Eventually we'd hope to generate code
  # that is as good in both cases.
  def benchmarkOldUnbatchImplementation(self):
    """Benchmarks unbatching via `flat_map(from_tensor_slices)`.

    Uses the same trial structure as `benchmarkNativeUnbatch`: five timed
    runs per batch size, reporting the median per-element wall time.
    """
    batch_sizes = [1, 2, 5, 10, 20, 50]
    elems_per_trial = 10000
    with ops.Graph().as_default():
      repeated = dataset_ops.Dataset.from_tensors("element").repeat(None)
      batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
      pipeline = repeated.batch(batch_size_placeholder)
      pipeline = pipeline.flat_map(dataset_ops.Dataset.from_tensor_slices)
      pipeline = pipeline.skip(elems_per_trial)
      iterator = pipeline.make_initializable_iterator()
      next_element = iterator.get_next()

      with session.Session() as sess:
        for batch_size in batch_sizes:
          feed = {batch_size_placeholder: batch_size}
          per_element_times = []
          for _ in range(5):
            sess.run(iterator.initializer, feed_dict=feed)
            began = time.time()
            sess.run(next_element.op)
            finished = time.time()
            per_element_times.append((finished - began) / elems_per_trial)

          median_wall_time = np.median(per_element_times)
          print("Unbatch (unfused) batch size: %d Median wall time per element:"
                " %f microseconds" % (batch_size, median_wall_time * 1e6))
          self.report_benchmark(
              iters=10000,
              wall_time=median_wall_time,
              name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
              batch_size)
|
||||
|
||||
|
||||
# Hand control to `test.main()` when this module is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -0,0 +1,322 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.bucket_by_sequence_length()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _element_length_fn(x, y=None):
  """Returns the size of the first dimension of `x`.

  Args:
    x: The value whose first-dimension size is returned.
    y: Accepted so the function can be called with two components; ignored.

  Returns:
    The first entry of `array_ops.shape(x)`.
  """
  del y  # Unused.
  dynamic_shape = array_ops.shape(x)
  return dynamic_shape[0]
|
||||
|
||||
|
||||
def _to_sparse_tensor(record):
  """Builds a `SparseTensor` from a mapping of its constructor arguments.

  Args:
    record: A mapping whose items are passed as keyword arguments to
      `sparse_tensor.SparseTensor`.

  Returns:
    The constructed `sparse_tensor.SparseTensor`.
  """
  as_sparse = sparse_tensor.SparseTensor(**record)
  return as_sparse
|
||||
|
||||
|
||||
def _format_record(array, sparse):
|
||||
if sparse:
|
||||
return {
|
||||
"values": array,
|
||||
"indices": [[i] for i in range(len(array))],
|
||||
"dense_shape": (len(array),)
|
||||
}
|
||||
return array
|
||||
|
||||
|
||||
def _get_record_type(sparse):
  """Returns the dtype structure matching `_format_record`'s output.

  Args:
    sparse: If truthy, return a dict of `dtypes.int64` for the "values",
      "indices" and "dense_shape" entries; otherwise return `dtypes.int32`.

  Returns:
    A dtype, or a dict of dtypes keyed like the sparse record.
  """
  if not sparse:
    return dtypes.int32
  return {
      component: dtypes.int64
      for component in ("values", "indices", "dense_shape")
  }
|
||||
|
||||
|
||||
def _get_record_shape(sparse):
  """Returns the shape structure matching `_format_record`'s output.

  Args:
    sparse: If truthy, return a dict of shapes for the "values", "indices"
      and "dense_shape" entries; otherwise return a single shape with one
      unknown dimension.

  Returns:
    A `tensor_shape.TensorShape`, or a dict of them keyed like the sparse
    record.
  """
  if not sparse:
    return tensor_shape.TensorShape([None])
  values_shape = tensor_shape.TensorShape([None,])
  indices_shape = tensor_shape.TensorShape([None, 1])
  dense_shape_shape = tensor_shape.TensorShape([1,])
  return {
      "values": values_shape,
      "indices": indices_shape,
      "dense_shape": dense_shape_shape
  }
|
||||
|
||||
|
||||
class BucketBySequenceLengthTest(test_base.DatasetTestBase):
  """Tests for `grouping.bucket_by_sequence_length()`."""

  def testBucket(self):
    """Produces exactly one batch per bucket, for dense and sparse input."""
    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25, 35]

    def build_dataset(sparse):
      def _generator():
        # Produce 1 batch for each bucket. The first record generated for a
        # bucket is one element shorter than the others, so each batch sums
        # to `batch_size * length - 1`.
        records = []
        for batch_size, length in zip(batch_sizes, lengths):
          for position in range(batch_size):
            record_len = length - 1 if position == 0 else length
            records.append([1] * record_len)
        random.shuffle(records)
        for record in records:
          yield (_format_record(record, sparse),)

      dataset = dataset_ops.Dataset.from_generator(
          _generator,
          (_get_record_type(sparse),),
          (_get_record_shape(sparse),))
      if not sparse:
        return dataset
      return dataset.map(lambda x: (_to_sparse_tensor(x),))

    def _test_bucket_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(
          grouping.bucket_by_sequence_length(
              _element_length_fn,
              boundaries,
              batch_sizes,
              no_padding=no_padding))
      next_batch, = dataset.make_one_shot_iterator().get_next()

      with self.cached_session() as sess:
        results = [sess.run(next_batch) for _ in range(4)]
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_batch)

      observed_batch_sizes = []
      observed_lengths = []
      for result in results:
        # Sparse results expose `dense_shape`/`values`; dense ones are arrays.
        if no_padding:
          shape = result.dense_shape
        else:
          shape = result.shape
        rows = shape[0]
        cols = shape[1]
        observed_batch_sizes.append(rows)
        observed_lengths.append(cols)
        if no_padding:
          total = result.values.sum()
        else:
          total = result.sum()
        self.assertEqual(total, rows * cols - 1)
      self.assertEqual(sum(observed_batch_sizes), sum(batch_sizes))
      self.assertEqual(sorted(batch_sizes), sorted(observed_batch_sizes))
      self.assertEqual(sorted(lengths), sorted(observed_lengths))

    for no_padding in (True, False):
      _test_bucket_by_padding(no_padding)

  def testPadToBoundary(self):
    """Pads to bucket boundaries and rejects elements past the last one."""
    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25]

    def element_gen():
      # Produce 1 batch for each bucket
      in_range = []
      for batch_size, length in zip(batch_sizes[:-1], lengths):
        in_range.extend([1] * length for _ in range(batch_size))
      random.shuffle(in_range)
      for record in in_range:
        yield (record,)
      # Finish with elements longer than the final boundary.
      for _ in range(batch_sizes[-1]):
        yield ([1] * (boundaries[-1] + 5),)

    element_len = lambda el: array_ops.shape(el)[0]
    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],))
    dataset = dataset.apply(
        grouping.bucket_by_sequence_length(
            element_len, boundaries, batch_sizes,
            pad_to_bucket_boundary=True))
    next_batch, = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      results = [sess.run(next_batch) for _ in range(3)]
      with self.assertRaisesOpError("bucket_boundaries"):
        sess.run(next_batch)

    observed_batch_sizes = [result.shape[0] for result in results]
    observed_lengths = [result.shape[1] for result in results]
    # Only the first three buckets yield batches.
    expected_batch_sizes = batch_sizes[:-1]
    self.assertEqual(sum(observed_batch_sizes), sum(expected_batch_sizes))
    self.assertEqual(sorted(expected_batch_sizes),
                     sorted(observed_batch_sizes))
    self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
                     sorted(observed_lengths))

  def testPadToBoundaryNoExtraneousPadding(self):
    """Checks exact padded batch contents for lengths 1 through 10."""
    boundaries = [3, 7, 11]
    batch_sizes = [2, 2, 2, 2]
    lengths = range(1, 11)

    def element_gen():
      for length in lengths:
        yield ([1] * length,)

    element_len = lambda element: array_ops.shape(element)[0]
    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],))
    dataset = dataset.apply(
        grouping.bucket_by_sequence_length(
            element_len, boundaries, batch_sizes,
            pad_to_bucket_boundary=True))
    next_batch, = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      results = [sess.run(next_batch) for _ in range(5)]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_batch)

    expected_results = [
        [[1, 0],
         [1, 1]],
        [[1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
    ]
    for actual, expected in zip(results, expected_results):
      self.assertAllEqual(actual, expected)

  def testTupleElements(self):
    """Checks output shapes when elements are (sequence, label) tuples."""

    def build_dataset(sparse):
      def _generator():
        text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
        label = [1, 2, 1, 2]
        for tokens, tag in zip(text, label):
          yield (_format_record(tokens, sparse), tag)

      dataset = dataset_ops.Dataset.from_generator(
          generator=_generator,
          output_types=(_get_record_type(sparse), dtypes.int32),
          output_shapes=(_get_record_shape(sparse),
                         tensor_shape.TensorShape([])))
      if not sparse:
        return dataset
      return dataset.map(lambda x, y: (_to_sparse_tensor(x), y))

    def _test_tuple_elements_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(grouping.bucket_by_sequence_length(
          element_length_func=_element_length_fn,
          bucket_batch_sizes=[2, 2, 2],
          bucket_boundaries=[0, 8],
          no_padding=no_padding))
      output_shapes = dataset.output_shapes
      self.assertEqual([None, None], output_shapes[0].as_list())
      self.assertEqual([None], output_shapes[1].as_list())

    for no_padding in (True, False):
      _test_tuple_elements_by_padding(no_padding)

  def testBucketSparse(self):
    """Tests bucketing of sparse tensors (case where `no_padding` == True).

    Test runs on following dataset:
      [
        [0],
        [0, 1],
        [0, 1, 2]
        ...
        [0, ..., max_len - 1]
      ]
    Sequences are bucketed by length and batched with
      `batch_size` < `bucket_size`.
    """

    min_len = 0
    max_len = 100
    batch_size = 7
    bucket_size = 10

    def _build_dataset():
      input_data = [range(i+1) for i in range(min_len, max_len)]

      def generator_fn():
        for record in input_data:
          yield _format_record(record, sparse=True)

      dataset = dataset_ops.Dataset.from_generator(
          generator=generator_fn,
          output_types=_get_record_type(sparse=True))
      return dataset.map(_to_sparse_tensor)

    def _compute_expected_batches():
      """Computes expected batch outputs and stores in a set."""
      all_expected_sparse_tensors = set()
      for bucket_start_len in range(min_len, max_len, bucket_size):
        bucket_end_len = bucket_start_len + bucket_size
        for batch_offset in range(0, bucket_size, batch_size):
          batch_start_len = bucket_start_len + batch_offset
          batch_end_len = min(batch_start_len + batch_size, bucket_end_len)
          # One (row, column) index and one value per stored element.
          cells = [(length - batch_start_len, val)
                   for length in range(batch_start_len, batch_end_len)
                   for val in range(length + 1)]
          expected_indices = tuple(cells)
          expected_values = tuple(val for _, val in cells)
          all_expected_sparse_tensors.add(
              (expected_indices, expected_values))
      return all_expected_sparse_tensors

    def _compute_batches(dataset):
      """Computes actual batch outputs of dataset and stores in a set."""
      next_batch = dataset.make_one_shot_iterator().get_next()
      all_sparse_tensors = set()
      with self.cached_session() as sess:
        with self.assertRaises(errors.OutOfRangeError):
          while True:
            output = sess.run(next_batch)
            indices = tuple(tuple(idx) for idx in output.indices)
            values = tuple(output.values)
            all_sparse_tensors.add((indices, values))
      return all_sparse_tensors

    dataset = _build_dataset()
    boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
    dataset = dataset.apply(grouping.bucket_by_sequence_length(
        _element_length_fn,
        boundaries,
        [batch_size] * (len(boundaries) + 1),
        no_padding=True))
    batches = _compute_batches(dataset)
    expected_batches = _compute_expected_batches()
    self.assertEqual(batches, expected_batches)
|
||||
|
||||
|
||||
# Hand control to `test.main()` when this module is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -1,824 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GroupByReducerTest(test_base.DatasetTestBase):
  """Tests for `grouping.group_by_reducer()`."""

  def checkResults(self, dataset, shapes, values):
    """Asserts `dataset` has `shapes` and yields exactly `values` in order.

    Args:
      dataset: The dataset under test.
      shapes: Expected value of `dataset.output_shapes`.
      values: Iterable of expected elements; after the last one the iterator
        must raise `OutOfRangeError`.
    """
    self.assertEqual(shapes, dataset.output_shapes)
    next_element = dataset.make_one_shot_iterator().get_next()
    with self.cached_session() as sess:
      for expected in values:
        self.assertEqual(sess.run(next_element), expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testSum(self):
    """Sums even and odd members of `range(2 * i)` separately."""
    reducer = grouping.Reducer(
        init_func=lambda _: np.int64(0),
        reduce_func=lambda total, value: total + value,
        finalize_func=lambda total: total)
    for i in range(1, 11):
      grouped = dataset_ops.Dataset.range(2 * i).apply(
          grouping.group_by_reducer(lambda value: value % 2, reducer))
      self.checkResults(
          grouped, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])

  def testAverage(self):
    """Keeps a running (mean, count) state and finalizes to the mean."""

    def reduce_fn(state, value):
      mean, count = state[0], state[1]
      updated_count = count + 1
      updated_mean = (
          mean * count + math_ops.cast(value, dtypes.float32)) / updated_count
      return updated_mean, updated_count

    reducer = grouping.Reducer(
        init_func=lambda _: (0.0, 0.0),
        reduce_func=reduce_fn,
        finalize_func=lambda mean, _: mean)
    for i in range(1, 11):
      grouped = dataset_ops.Dataset.range(2 * i).apply(
          grouping.group_by_reducer(
              lambda value: math_ops.cast(value, dtypes.int64) % 2, reducer))
      self.checkResults(
          grouped, shapes=tensor_shape.scalar(), values=[i - 1, i])

  def testConcat(self):
    """Concatenates letters grouped by the parity of a zipped counter."""
    components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
    reducer = grouping.Reducer(
        init_func=lambda _: "",
        reduce_func=lambda joined, pair: joined + pair[0],
        finalize_func=lambda joined: joined)
    for i in range(1, 11):
      letters = dataset_ops.Dataset.from_tensor_slices(components)
      counter = dataset_ops.Dataset.range(2 * i)
      grouped = dataset_ops.Dataset.zip((letters, counter)).apply(
          grouping.group_by_reducer(lambda _, index: index % 2, reducer))
      self.checkResults(
          grouped,
          shapes=tensor_shape.scalar(),
          values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])

  def testSparseSum(self):
    """Same as `testSum`, but with sparse elements and sparse state."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1], dtype=np.int64)),
          dense_shape=np.array([1, 1]))

    reducer = grouping.Reducer(
        init_func=lambda _: _sparse(np.int64(0)),
        reduce_func=lambda acc, item: _sparse(acc.values[0] + item.values[0]),
        finalize_func=lambda acc: acc.values[0])
    for i in range(1, 11):
      grouped = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
          grouping.group_by_reducer(lambda item: item.values[0] % 2, reducer))
      self.checkResults(
          grouped, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])

  def testChangingStateShape(self):
    """Lets the state grow in length and in rank on every reduce step."""

    def reduce_fn(state, _):
      # Statically known rank, but dynamic length.
      larger_dim = array_ops.concat([state[0], state[0]], 0)
      # Statically unknown rank.
      larger_rank = array_ops.expand_dims(state[1], 0)
      return larger_dim, larger_rank

    reducer = grouping.Reducer(
        init_func=lambda _: ([0], 1),
        reduce_func=reduce_fn,
        finalize_func=lambda first, second: (first, second))

    for i in range(1, 11):
      grouped = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
          grouping.group_by_reducer(lambda key: key, reducer))
      self.assertEqual([None], grouped.output_shapes[0].as_list())
      self.assertIs(None, grouped.output_shapes[1].ndims)
      next_element = grouped.make_one_shot_iterator().get_next()
      with self.cached_session() as sess:
        doubled, nested = sess.run(next_element)
        self.assertAllEqual([0] * (2**i), doubled)
        self.assertAllEqual(np.array(1, ndmin=i), nested)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

  def testTypeMismatch(self):
    """Rejects a reduce function whose state dtype differs from the initial."""
    reducer = grouping.Reducer(
        init_func=lambda _: constant_op.constant(1, dtype=dtypes.int32),
        reduce_func=lambda _, __: constant_op.constant(1, dtype=dtypes.int64),
        finalize_func=lambda state: state)

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        TypeError,
        "The element types for the new state must match the initial state."):
      dataset.apply(
          grouping.group_by_reducer(lambda _: np.int64(0), reducer))

  # TODO(b/78665031): Remove once non-scalar keys are supported.
  def testInvalidKeyShape(self):
    """Rejects a key function that returns a non-scalar key."""
    reducer = grouping.Reducer(
        init_func=lambda _: np.int64(0),
        reduce_func=lambda total, value: total + value,
        finalize_func=lambda total: total)

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        ValueError, "`key_func` must return a single tf.int64 tensor."):
      dataset.apply(
          grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))

  # TODO(b/78665031): Remove once non-int64 keys are supported.
  def testInvalidKeyType(self):
    """Rejects a key function that returns a non-int64 key."""
    reducer = grouping.Reducer(
        init_func=lambda _: np.int64(0),
        reduce_func=lambda total, value: total + value,
        finalize_func=lambda total: total)

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        ValueError, "`key_func` must return a single tf.int64 tensor."):
      dataset.apply(
          grouping.group_by_reducer(lambda _: "wrong", reducer))

  def testTuple(self):
    """Reduces tuple elements into a tuple state (collected list, sum)."""

    def init_fn(_):
      return np.array([], dtype=np.int64), np.int64(0)

    def reduce_fn(state, value):
      collected, running_sum = state
      first, second = value
      return array_ops.concat([collected, [first]], 0), running_sum + second

    def finalize_fn(collected, running_sum):
      return collected, running_sum

    reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
    pairs = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10)))
    grouped = pairs.apply(
        grouping.group_by_reducer(lambda _, __: np.int64(0), reducer))
    next_element = grouped.make_one_shot_iterator().get_next()
    with self.cached_session() as sess:
      collected, total = sess.run(next_element)
      self.assertAllEqual(collected, np.asarray(list(range(10))))
      self.assertEqual(total, 45)
|
||||
|
||||
|
||||
class GroupByWindowTest(test_base.DatasetTestBase):
  """Tests for `grouping.group_by_window()`."""

  def testSimple(self):
    """Groups squared random integers by parity into batches of up to 4."""
    components = np.random.randint(100, size=(200,)).astype(np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
        .apply(
            grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      counts = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          result = sess.run(get_next)
          # Every batch must be all-even or all-odd. Each `all()` needs its
          # own generator: with the closing parenthesis misplaced, the
          # argument to `assertTrue` becomes a generator object, which is
          # always truthy, and the check can never fail.
          self.assertTrue(
              all(x % 2 == 0 for x in result) or
              all(x % 2 == 1 for x in result))
          counts.append(result.shape[0])

      self.assertEqual(len(components), sum(counts))
      num_full_batches = len([c for c in counts if c == 4])
      self.assertGreaterEqual(num_full_batches, 24)
      self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))

  def testImmediateOutput(self):
    """Checks that windows are emitted as soon as they fill up."""
    components = np.array(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
            grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      # The input is infinite, so this test demonstrates that:
      # 1. We produce output without having to consume the entire input,
      # 2. Different buckets can produce output at different rates, and
      # 3. For deterministic input, the output is deterministic.
      for _ in range(3):
        self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
        self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
        self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
        self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))

  def testSmallGroups(self):
    """Checks that partially filled windows are flushed at end of input."""
    components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).apply(
            grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
      self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
      # The small outputs at the end are deterministically produced in key
      # order.
      self.assertAllEqual([0, 0, 0], sess.run(get_next))
      self.assertAllEqual([1], sess.run(get_next))

  def testEmpty(self):
    """Checks that a window size of zero is rejected when the op runs."""
    iterator = (
        dataset_ops.Dataset.range(4).apply(
            grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "Window size must be greater than zero, but got 0."):
        print(sess.run(get_next))

  def testReduceFuncError(self):
    """Checks that a bad padded shape in `reduce_func` fails at run time."""
    components = np.random.randint(100, size=(200,)).astype(np.int64)

    def reduce_func(_, xs):
      # Introduce an incorrect padded shape that cannot (currently) be
      # detected at graph construction time.
      return xs.padded_batch(
          4,
          padded_shapes=(tensor_shape.TensorShape([]),
                         constant_op.constant([5], dtype=dtypes.int64) * -1))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
            grouping.group_by_window(lambda x, _: x % 2, reduce_func,
                                     32)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)

  def testConsumeWindowDatasetMoreThanOnce(self):
    """Uses the window dataset twice inside a single `reduce_func`."""
    components = np.random.randint(50, size=(200,)).astype(np.int64)

    def reduce_func(key, window):
      # Apply two different kinds of padding to the input: tight
      # padding, and quantized (to a multiple of 10) padding.
      return dataset_ops.Dataset.zip((
          window.padded_batch(
              4, padded_shapes=tensor_shape.TensorShape([None])),
          window.padded_batch(
              4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
      ))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
        .apply(grouping.group_by_window(
            lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
            reduce_func, 4))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      counts = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          tight_result, multiple_of_10_result = sess.run(get_next)
          self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
          # The quantized batch must agree with the tight one on the columns
          # they share.
          self.assertAllEqual(tight_result,
                              multiple_of_10_result[:, :tight_result.shape[1]])
          counts.append(tight_result.shape[0])
      self.assertEqual(len(components), sum(counts))
|
||||
|
||||
|
||||
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
|
||||
# Currently, they use a constant batch size, though should be made to use a
|
||||
# different batch size per key.
|
||||
class BucketTest(test_base.DatasetTestBase):
|
||||
|
||||
def _dynamicPad(self, bucket, window, window_size):
  """Pairs a bucket key with a padded batch of that bucket's window.

  Args:
    bucket: The key, wrapped in a single-element dataset.
    window: Dataset of three-component elements to pad and batch.
    window_size: Not read by this method; the batch size is fixed at 32.

  Returns:
    A zipped dataset of (key dataset, padded-batch dataset).
  """
  # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
  # generic form of padded_batch that pads every component
  # dynamically and does not rely on static shape information about
  # the arguments.
  padded_shapes = (tensor_shape.TensorShape([]),
                   tensor_shape.TensorShape([None]),
                   tensor_shape.TensorShape([3]))
  key_dataset = dataset_ops.Dataset.from_tensors(bucket)
  padded_window = window.padded_batch(32, padded_shapes)
  return dataset_ops.Dataset.zip((key_dataset, padded_window))
|
||||
|
||||
def testSingleBucket(self):
  """Sends all 32 elements to bucket 0 and checks the padded batch."""

  def _map_fn(v):
    # Scalar, a length-`v` vector filled with `v`, and three copies of
    # `v` rendered as a string.
    filled = array_ops.fill([v], v)
    as_strings = array_ops.fill([3], string_ops.as_string(v))
    return (v, filled, as_strings)

  input_dataset = dataset_ops.Dataset.from_tensor_slices(
      math_ops.range(32)).map(_map_fn)

  bucketed_dataset = input_dataset.apply(
      grouping.group_by_window(
          lambda x, y, z: 0,
          lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

  iterator = bucketed_dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)

    which_bucket, bucketed_values = sess.run(get_next)

    self.assertEqual(0, which_bucket)

    expected_scalar_int = np.arange(32, dtype=np.int64)
    # Row i holds i copies of i, zero-padded to the widest row (31).
    expected_unk_int64 = np.zeros((32, 31), dtype=np.int64)
    for row in range(32):
      expected_unk_int64[row, :row] = row
    expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

    expectations = (expected_scalar_int, expected_unk_int64,
                    expected_vec3_str)
    for position, expected in enumerate(expectations):
      self.assertAllEqual(expected, bucketed_values[position])
|
||||
|
||||
def testEvenOddBuckets(self):
|
||||
|
||||
def _map_fn(v):
|
||||
return (v, array_ops.fill([v], v),
|
||||
array_ops.fill([3], string_ops.as_string(v)))
|
||||
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
|
||||
|
||||
bucketed_dataset = input_dataset.apply(
|
||||
grouping.group_by_window(
|
||||
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
|
||||
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
|
||||
|
||||
iterator = bucketed_dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
# Get two minibatches (one containing even values, one containing odds)
|
||||
which_bucket_even, bucketed_values_even = sess.run(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values_even))
|
||||
self.assertEqual(3, len(bucketed_values_odd))
|
||||
|
||||
# Ensure bucket 0 was used for all minibatch entries.
|
||||
self.assertAllEqual(0, which_bucket_even)
|
||||
self.assertAllEqual(1, which_bucket_odd)
|
||||
|
||||
# Test the first bucket outputted, the events starting at 0
|
||||
expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
|
||||
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
|
||||
for i in range(0, 32):
|
||||
expected_unk_int64[i, :2 * i] = 2 * i
|
||||
expected_vec3_str = np.vstack(
|
||||
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
|
||||
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
|
||||
|
||||
# Test the second bucket outputted, the odds starting at 1
|
||||
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
|
||||
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
|
||||
for i in range(0, 32):
|
||||
expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
|
||||
expected_vec3_str = np.vstack(
|
||||
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
|
||||
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
|
||||
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||
|
||||
def _map_fn(v):
|
||||
return {
|
||||
"x": v,
|
||||
"y": array_ops.fill([v], v),
|
||||
"z": array_ops.fill([3], string_ops.as_string(v))
|
||||
}
|
||||
|
||||
def _dynamic_pad_fn(bucket, window, _):
|
||||
return dataset_ops.Dataset.zip(
|
||||
(dataset_ops.Dataset.from_tensors(bucket),
|
||||
window.padded_batch(
|
||||
32, {
|
||||
"x": tensor_shape.TensorShape([]),
|
||||
"y": tensor_shape.TensorShape([None]),
|
||||
"z": tensor_shape.TensorShape([3])
|
||||
})))
|
||||
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
|
||||
.filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
|
||||
|
||||
bucketed_dataset = input_dataset.apply(
|
||||
grouping.group_by_window(
|
||||
lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
|
||||
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
|
||||
|
||||
iterator = bucketed_dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||
which_bucket0, bucketed_values_even0 = sess.run(get_next)
|
||||
which_bucket1, bucketed_values_even1 = sess.run(get_next)
|
||||
|
||||
# Ensure that bucket 1 was completely filtered out
|
||||
self.assertAllEqual(0, which_bucket0)
|
||||
self.assertAllEqual(0, which_bucket1)
|
||||
self.assertAllEqual(
|
||||
np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
|
||||
self.assertAllEqual(
|
||||
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
|
||||
|
||||
def testDynamicWindowSize(self):
|
||||
components = np.arange(100).astype(np.int64)
|
||||
|
||||
# Key fn: even/odd
|
||||
# Reduce fn: batches of 5
|
||||
# Window size fn: even=5, odd=10
|
||||
|
||||
def window_size_func(key):
|
||||
window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
|
||||
return window_sizes[key]
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
|
||||
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
|
||||
None, window_size_func))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
batches = 0
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
is_even = all(x % 2 == 0 for x in result)
|
||||
is_odd = all(x % 2 == 1 for x in result)
|
||||
self.assertTrue(is_even or is_odd)
|
||||
expected_batch_size = 5 if is_even else 10
|
||||
self.assertEqual(expected_batch_size, result.shape[0])
|
||||
batches += 1
|
||||
|
||||
self.assertEqual(batches, 15)
|
||||
|
||||
|
||||
def _element_length_fn(x, y=None):
  """Returns the size of the first dimension of `x`, ignoring `y`."""
  # `y` is accepted only so two-component elements can be passed; it is unused.
  del y
  first_dim = array_ops.shape(x)[0]
  return first_dim
|
||||
|
||||
|
||||
def _to_sparse_tensor(record):
  """Builds a `SparseTensor` using the entries of `record` as keyword args."""
  converted = sparse_tensor.SparseTensor(**record)
  return converted
|
||||
|
||||
|
||||
def _format_record(array, sparse):
|
||||
if sparse:
|
||||
return {
|
||||
"values": array,
|
||||
"indices": [[i] for i in range(len(array))],
|
||||
"dense_shape": (len(array),)
|
||||
}
|
||||
return array
|
||||
|
||||
|
||||
def _get_record_type(sparse):
  """Returns the dtype structure matching `_format_record(..., sparse)`.

  Args:
    sparse: Whether the record is in the sparse-component dict form.

  Returns:
    `dtypes.int32` for dense records; for sparse records, a dict mapping
    "values", "indices" and "dense_shape" to `dtypes.int64`.
  """
  if not sparse:
    return dtypes.int32
  int64 = dtypes.int64
  return {
      "values": int64,
      "indices": int64,
      "dense_shape": int64
  }
|
||||
|
||||
|
||||
def _get_record_shape(sparse):
  """Returns the shape structure matching `_format_record(..., sparse)`.

  Args:
    sparse: Whether the record is in the sparse-component dict form.

  Returns:
    A vector shape of unknown length for dense records; for sparse records, a
    dict of shapes keyed by "values", "indices" and "dense_shape".
  """
  if not sparse:
    return tensor_shape.TensorShape([None])
  return {
      "values": tensor_shape.TensorShape([None]),
      "indices": tensor_shape.TensorShape([None, 1]),
      "dense_shape": tensor_shape.TensorShape([1])
  }
|
||||
|
||||
|
||||
class BucketBySequenceLength(test_base.DatasetTestBase):
  """Tests for `grouping.bucket_by_sequence_length()`."""

  def testBucket(self):
    """Each bucket yields one batch of its configured size (dense and sparse)."""

    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25, 35]

    def build_dataset(sparse):
      def _generator():
        # Produce 1 batch for each bucket
        elements = []
        for batch_size, length in zip(batch_sizes, lengths):
          for position in range(batch_size):
            # The first record of each bucket is one element shorter, so a
            # batch of ones sums to `batch_size * length - 1`.
            record_len = length - 1 if position == 0 else length
            elements.append([1] * record_len)
        random.shuffle(elements)
        for el in elements:
          yield (_format_record(el, sparse),)

      dataset = dataset_ops.Dataset.from_generator(
          _generator,
          (_get_record_type(sparse),),
          (_get_record_shape(sparse),))
      if sparse:
        dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
      return dataset

    def _test_bucket_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(
          grouping.bucket_by_sequence_length(
              _element_length_fn,
              boundaries,
              batch_sizes,
              no_padding=no_padding))
      next_batch, = dataset.make_one_shot_iterator().get_next()

      with self.cached_session() as sess:
        batches = [sess.run(next_batch) for _ in range(4)]
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_batch)

      batch_sizes_val = []
      lengths_val = []
      # A distinct loop name keeps the fetched values apart from the tensor.
      for batch_value in batches:
        if no_padding:
          shape = batch_value.dense_shape
          sum_check = batch_value.values.sum()
        else:
          shape = batch_value.shape
          sum_check = batch_value.sum()
        batch_size = shape[0]
        length = shape[1]
        batch_sizes_val.append(batch_size)
        lengths_val.append(length)
        self.assertEqual(sum_check, batch_size * length - 1)
      self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
      self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
      self.assertEqual(sorted(lengths), sorted(lengths_val))

    for no_padding in (True, False):
      _test_bucket_by_padding(no_padding)

  def testPadToBoundary(self):
    """Pads batches to `boundary - 1`; over-long elements raise an op error."""

    boundaries = [10, 20, 30]
    batch_sizes = [10, 8, 4, 2]
    lengths = [8, 13, 25]

    def element_gen():
      # Produce 1 batch for each bucket
      elements = []
      for batch_size, length in zip(batch_sizes[:-1], lengths):
        for _ in range(batch_size):
          elements.append([1] * length)
      random.shuffle(elements)
      for el in elements:
        yield (el,)
      # Finish with elements longer than the last boundary.
      for _ in range(batch_sizes[-1]):
        el = [1] * (boundaries[-1] + 5)
        yield (el,)

    element_len = lambda el: array_ops.shape(el)[0]
    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],)).apply(
            grouping.bucket_by_sequence_length(
                element_len, boundaries, batch_sizes,
                pad_to_bucket_boundary=True))
    next_batch, = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      batches = [sess.run(next_batch) for _ in range(3)]
      with self.assertRaisesOpError("bucket_boundaries"):
        sess.run(next_batch)

    batch_sizes_val = []
    lengths_val = []
    for batch_value in batches:
      batch_sizes_val.append(batch_value.shape[0])
      lengths_val.append(batch_value.shape[1])
    # Use a new name rather than rebinding `batch_sizes`, which `element_gen`
    # reads through its closure.
    expected_batch_sizes = batch_sizes[:-1]
    self.assertEqual(sum(batch_sizes_val), sum(expected_batch_sizes))
    self.assertEqual(sorted(expected_batch_sizes), sorted(batch_sizes_val))
    self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
                     sorted(lengths_val))

  def testPadToBoundaryNoExtraneousPadding(self):
    """Checks the exact padded contents of every batch, bucket by bucket."""

    boundaries = [3, 7, 11]
    batch_sizes = [2, 2, 2, 2]
    lengths = range(1, 11)

    def element_gen():
      for length in lengths:
        yield ([1] * length,)

    element_len = lambda element: array_ops.shape(element)[0]
    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],)).apply(
            grouping.bucket_by_sequence_length(
                element_len, boundaries, batch_sizes,
                pad_to_bucket_boundary=True))
    next_batch, = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      batches = [sess.run(next_batch) for _ in range(5)]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_batch)

    self.assertAllEqual(batches[0], [[1, 0],
                                     [1, 1]])
    self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
                                     [1, 1, 1, 1, 0, 0]])
    self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
                                     [1, 1, 1, 1, 1, 1]])
    self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                                     [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

  def testTupleElements(self):
    """Checks the static output shapes for `(text, label)` elements."""

    def build_dataset(sparse):
      def _generator():
        text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
        label = [1, 2, 1, 2]
        for x, y in zip(text, label):
          yield (_format_record(x, sparse), y)

      dataset = dataset_ops.Dataset.from_generator(
          generator=_generator,
          output_types=(_get_record_type(sparse), dtypes.int32),
          output_shapes=(_get_record_shape(sparse),
                         tensor_shape.TensorShape([])))
      if sparse:
        dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
      return dataset

    def _test_tuple_elements_by_padding(no_padding):
      dataset = build_dataset(sparse=no_padding)
      dataset = dataset.apply(grouping.bucket_by_sequence_length(
          element_length_func=_element_length_fn,
          bucket_batch_sizes=[2, 2, 2],
          bucket_boundaries=[0, 8],
          no_padding=no_padding))
      text_shape, label_shape = dataset.output_shapes[0:2]
      self.assertEqual([None, None], text_shape.as_list())
      self.assertEqual([None], label_shape.as_list())

    for no_padding in (True, False):
      _test_tuple_elements_by_padding(no_padding)

  def testBucketSparse(self):
    """Tests bucketing of sparse tensors (case where `no_padding` == True).

    Test runs on following dataset:
      [
        [0],
        [0, 1],
        [0, 1, 2]
        ...
        [0, ..., max_len - 1]
      ]
    Sequences are bucketed by length and batched with
      `batch_size` < `bucket_size`.
    """

    min_len = 0
    max_len = 100
    batch_size = 7
    bucket_size = 10

    def _build_dataset():
      input_data = [range(i+1) for i in range(min_len, max_len)]

      def generator_fn():
        for record in input_data:
          yield _format_record(record, sparse=True)

      dataset = dataset_ops.Dataset.from_generator(
          generator=generator_fn,
          output_types=_get_record_type(sparse=True))
      dataset = dataset.map(_to_sparse_tensor)
      return dataset

    def _compute_expected_batches():
      """Computes expected batch outputs and stores in a set."""
      all_expected_sparse_tensors = set()
      for bucket_start_len in range(min_len, max_len, bucket_size):
        bucket_end_len = bucket_start_len + bucket_size
        for batch_offset in range(0, bucket_size, batch_size):
          batch_start_len = bucket_start_len + batch_offset
          # The last batch of a bucket is cut short at the bucket's end.
          batch_end_len = min(batch_start_len + batch_size, bucket_end_len)
          expected_indices = []
          expected_values = []
          for length in range(batch_start_len, batch_end_len):
            for val in range(length + 1):
              expected_indices.append((length - batch_start_len, val))
              expected_values.append(val)
          expected_sprs_tensor = (tuple(expected_indices),
                                  tuple(expected_values))
          all_expected_sparse_tensors.add(expected_sprs_tensor)
      return all_expected_sparse_tensors

    def _compute_batches(dataset):
      """Computes actual batch outputs of dataset and stores in a set."""
      next_batch = dataset.make_one_shot_iterator().get_next()
      all_sparse_tensors = set()
      with self.cached_session() as sess:
        with self.assertRaises(errors.OutOfRangeError):
          while True:
            output = sess.run(next_batch)
            sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
                           tuple(output.values))
            all_sparse_tensors.add(sprs_tensor)
      return all_sparse_tensors

    dataset = _build_dataset()
    boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
    dataset = dataset.apply(grouping.bucket_by_sequence_length(
        _element_length_fn,
        boundaries,
        [batch_size] * (len(boundaries) + 1),
        no_padding=True))
    batches = _compute_batches(dataset)
    expected_batches = _compute_expected_batches()
    self.assertEqual(batches, expected_batches)
|
||||
|
||||
|
||||
# Run the test cases defined above when this file is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -12,440 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for prefetching_ops."""
|
||||
"""Tests for `tf.data.experimental.copy_to_device()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
  """Tests the function-buffering resource ops from `prefetching_ops`."""

  def setUp(self):
    """Runs the base-class setup, then creates the generator-progress event."""
    # Call the base implementation so that whatever per-test state it prepares
    # is not silently skipped by this override.
    super(PrefetchingKernelsOpsTest, self).setUp()
    # Set by the generator in `_create_ds_and_iterator` once it is resumed
    # after yielding 6.0.
    self._event = threading.Event()

  def _create_ds_and_iterator(self, device0, initializable=False):
    """Builds a generator dataset of 1.0..9.0 and an iterator on `device0`.

    Args:
      device0: Device string under which the dataset and iterator are created.
      initializable: If true, make an initializable iterator; otherwise a
        one-shot iterator.

    Returns:
      A `(dataset, iterator)` tuple.
    """

    def gen():
      for i in range(1, 10):
        yield [float(i)]
        if i == 6:
          self._event.set()

    with ops.device(device0):
      ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
      if initializable:
        ds_iterator = ds.make_initializable_iterator()
      else:
        ds_iterator = ds.make_one_shot_iterator()
      return (ds, ds_iterator)

  def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
    """Creates the buffering resource on `device1` fed from `device0`.

    Args:
      ds: Dataset whose output types and shapes describe the iterator.
      ds_iterator: Iterator whose string handle is passed to the remote
        function.
      buffer_name: Shared name given to the buffering resource.
      device0: Device string passed as the resource's target device.
      device1: Device string under which the resource and its ops are created.

    Returns:
      A `(prefetch_op, reset_op, destroy_op)` tuple.
    """
    ds_iterator_handle = ds_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          output_types=[dtypes.float32],
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=3,
          shared_name=buffer_name)
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.float32])
      reset_op = prefetching_ops.function_buffering_resource_reset(
          function_buffer_resource=buffer_resource_handle)
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    return (prefetch_op, reset_op, destroy_op)

  def _assert_first_five_elements(self, sess, prefetch_op):
    """Fetches 1.0-4.0, waits for the generator event, then fetches 5.0."""
    for expected in (1.0, 2.0, 3.0, 4.0):
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [expected])
    self._event.wait()
    elem = sess.run(prefetch_op)
    self.assertEqual(elem, [5.0])

  def _assert_all_elements_then_end(self, sess, prefetch_op):
    """Fetches 1.0-9.0 and checks that two further fetches are out of range."""
    for i in range(1, 10):
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [float(i)])
    # Try fetching after its over twice to test out end of sequence.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(prefetch_op)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(prefetch_op)

  def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
    """Checks the first five prefetched elements using a one-shot iterator."""
    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})

    ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
    prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
                                                  device0, device1)

    with self.test_session(config=worker_config) as sess:
      self._assert_first_five_elements(sess, prefetch_op)
      sess.run(destroy_op)

  def testSameDeviceCPU(self):
    self._prefetch_fn_helper_one_shot("same_device_cpu",
                                      "/job:localhost/replica:0/task:0/cpu:0",
                                      "/job:localhost/replica:0/task:0/cpu:0")

  def testDifferentDeviceCPU(self):
    self._prefetch_fn_helper_one_shot("diff_device_cpu",
                                      "/job:localhost/replica:0/task:0/cpu:0",
                                      "/job:localhost/replica:0/task:0/cpu:1")

  def testDifferentDeviceCPUGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    self._prefetch_fn_helper_one_shot("cpu_gpu",
                                      "/job:localhost/replica:0/task:0/cpu:0",
                                      "/job:localhost/replica:0/task:0/gpu:0")

  def testReinitialization(self):
    """Resets the buffer mid-stream and reads the first elements again."""
    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})

    device0 = "/job:localhost/replica:0/task:0/cpu:0"
    device1 = "/job:localhost/replica:0/task:0/cpu:1"
    ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
    prefetch_op, reset_op, destroy_op = self._create_ops(
        ds, ds_iterator, "reinit", device0, device1)

    with self.test_session(config=worker_config) as sess:
      sess.run(ds_iterator.initializer)
      self._assert_first_five_elements(sess, prefetch_op)
      # Lets reset the function buffering resource and reinitialize the
      # iterator. Should be able to go through this again.
      self._event.clear()
      sess.run(reset_op)
      sess.run(ds_iterator.initializer)
      self._assert_first_five_elements(sess, prefetch_op)
      sess.run(destroy_op)

  def testReinitializationOutOfRange(self):
    """Resets the buffer after exhaustion and drains the sequence again."""
    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})

    device0 = "/job:localhost/replica:0/task:0/cpu:0"
    device1 = "/job:localhost/replica:0/task:0/cpu:1"
    ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
    prefetch_op, reset_op, destroy_op = self._create_ops(
        ds, ds_iterator, "reinit", device0, device1)

    with self.test_session(config=worker_config) as sess:
      sess.run(ds_iterator.initializer)
      self._assert_all_elements_then_end(sess, prefetch_op)

      # Now reset everything and try it out again.
      self._event.clear()
      sess.run(reset_op)
      sess.run(ds_iterator.initializer)
      self._assert_all_elements_then_end(sess, prefetch_op)

      sess.run(destroy_op)

  def testStringsGPU(self):
    """Buffers string elements from a CPU iterator onto the GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    device0 = "/job:localhost/replica:0/task:0/cpu:0"
    device1 = "/job:localhost/replica:0/task:0/gpu:0"

    ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
    ds_iterator = ds.make_one_shot_iterator()
    ds_iterator_handle = ds_iterator.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, ds.output_types, ds.output_shapes)
      return remote_iterator.get_next()

    target = constant_op.constant(device0)
    with ops.device(device1):
      buffer_resource_handle = prefetching_ops.function_buffering_resource(
          f=_remote_fn,
          output_types=[dtypes.string],
          target_device=target,
          string_arg=ds_iterator_handle,
          buffer_size=3,
          shared_name="strings")
      prefetch_op = prefetching_ops.function_buffering_resource_get_next(
          function_buffer_resource=buffer_resource_handle,
          output_types=[dtypes.string])
      destroy_op = resource_variable_ops.destroy_resource_op(
          buffer_resource_handle, ignore_lookup_error=True)

    with self.cached_session() as sess:
      for expected in (b"a", b"b", b"c"):
        self.assertEqual([expected], sess.run(prefetch_op))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(prefetch_op)

      sess.run(destroy_op)
|
||||
|
||||
|
||||
class PrefetchToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
def testPrefetchToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/cpu:1"))
|
||||
|
||||
# NOTE(mrry): This device block creates the "host" dataset and iterator on
|
||||
# /cpu:0, and ensures that the prefetching is across devices. In typical use
|
||||
# this would not be necessary, because the GPU device would not support any
|
||||
# of the dataset-related ops.
|
||||
with ops.device("/cpu:0"):
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
|
||||
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
|
||||
self.assertEqual(host_dataset.output_types, iterator.output_types)
|
||||
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
|
||||
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
|
||||
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
|
||||
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
|
||||
|
||||
next_element = iterator.get_next()
|
||||
self.assertEqual(dtypes.int64, next_element.dtype)
|
||||
self.assertEqual([], next_element.shape)
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchToSameDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0"))
|
||||
|
||||
# NOTE(mrry): This device block creates the "host" dataset and iterator on
|
||||
# /cpu:0, and ensures that the prefetching is across devices. In typical use
|
||||
# this would not be necessary, because the GPU device would not support any
|
||||
# of the dataset-related ops.
|
||||
with ops.device("/cpu:0"):
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
|
||||
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
|
||||
self.assertEqual(host_dataset.output_types, iterator.output_types)
|
||||
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
|
||||
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
|
||||
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
|
||||
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
|
||||
|
||||
next_element = iterator.get_next()
|
||||
self.assertEqual(dtypes.int64, next_element.dtype)
|
||||
self.assertEqual([], next_element.shape)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchDictToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/cpu:1"))
|
||||
|
||||
# NOTE(mrry): This device block creates the "host" dataset and iterator on
|
||||
# /cpu:0, and ensures that the prefetching is across devices. In typical use
|
||||
# this would not be necessary, because the GPU device would not support any
|
||||
# of the dataset-related ops.
|
||||
with ops.device("/cpu:0"):
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
|
||||
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
|
||||
self.assertEqual(host_dataset.output_types, iterator.output_types)
|
||||
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
|
||||
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
|
||||
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
|
||||
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
|
||||
|
||||
next_element = iterator.get_next()
|
||||
self.assertEqual(dtypes.int64, next_element["a"].dtype)
|
||||
self.assertEqual([], next_element["a"].shape)
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual({"a": i}, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchSparseTensorsToDevice(self):
|
||||
def make_tensor(i):
|
||||
return sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
|
||||
host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
|
||||
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/cpu:1"))
|
||||
|
||||
# NOTE(mrry): This device block creates the "host" dataset and iterator on
|
||||
# /cpu:0, and ensures that the prefetching is across devices. In typical use
|
||||
# this would not be necessary, because the GPU device would not support any
|
||||
# of the dataset-related ops.
|
||||
with ops.device("/cpu:0"):
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
|
||||
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
|
||||
self.assertEqual(host_dataset.output_types, iterator.output_types)
|
||||
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
|
||||
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
|
||||
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
|
||||
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
|
||||
|
||||
next_element = iterator.get_next()
|
||||
self.assertEqual(dtypes.int64, next_element.dtype)
|
||||
|
||||
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
actual = sess.run(next_element)
|
||||
self.assertAllEqual([i], actual.values)
|
||||
self.assertAllEqual([[0, 0]], actual.indices)
|
||||
self.assertAllEqual([2, 2], actual.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchToDeviceGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchToDeviceWithReInit(self):
  """Prefetches `Dataset.range(10)` to "/cpu:1" and re-initializes mid-stream.

  The iterator is initialized, five elements are consumed, the iterator is
  initialized again, and then all ten elements are expected from the start,
  followed by `OutOfRangeError`.
  """
  host_dataset = dataset_ops.Dataset.range(10)
  to_second_cpu = prefetching_ops.prefetch_to_device("/cpu:1")
  device_dataset = host_dataset.apply(to_second_cpu)

  # NOTE(mrry): This device block creates the "host" dataset and iterator on
  # /cpu:0, and ensures that the prefetching is across devices. In typical use
  # this would not be necessary, because the GPU device would not support any
  # of the dataset-related ops.
  with ops.device("/cpu:0"):
    iterator = device_dataset.make_initializable_iterator()

  # The prefetched dataset and its iterator must report the same structure
  # (types, shapes, classes) as the host dataset.
  for attr in ("output_types", "output_shapes", "output_classes"):
    self.assertEqual(getattr(host_dataset, attr),
                     getattr(device_dataset, attr))
    self.assertEqual(getattr(host_dataset, attr), getattr(iterator, attr))

  next_element = iterator.get_next()
  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  # Two CPU devices are requested so that "/cpu:1" exists in the session.
  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config) as sess:
    # First pass: consume only half of the elements.
    sess.run(iterator.initializer)
    for expected_value in range(5):
      self.assertEqual(expected_value, sess.run(next_element))
    # Second pass: re-initialization must restart from zero.
    sess.run(iterator.initializer)
    for expected_value in range(10):
      self.assertEqual(expected_value, sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
|
||||
|
||||
def testPrefetchToDeviceGpuWithReInit(self):
  """Prefetches `Dataset.range(10)` to "/gpu:0" and re-initializes mid-stream.

  Skipped when no GPU is available. After consuming five elements the
  iterator is initialized again and must yield all ten elements from the
  start before raising `OutOfRangeError`.
  """
  gpu_present = test_util.is_gpu_available()
  if not gpu_present:
    self.skipTest("No GPU available")

  host_dataset = dataset_ops.Dataset.range(10)
  to_gpu = prefetching_ops.prefetch_to_device("/gpu:0")
  device_dataset = host_dataset.apply(to_gpu)

  iterator = device_dataset.make_initializable_iterator()
  next_element = iterator.get_next()

  with self.cached_session() as sess:
    # First pass: stop after half of the elements.
    sess.run(iterator.initializer)
    for expected_value in range(5):
      self.assertEqual(expected_value, sess.run(next_element))
    # Second pass: re-initialization must restart from zero.
    sess.run(iterator.initializer)
    for expected_value in range(10):
      self.assertEqual(expected_value, sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
|
||||
|
||||
|
||||
class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||
|
||||
def testCopyToDevice(self):
|
@ -0,0 +1,51 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.Counter`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import counter
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class CounterTest(test_base.DatasetTestBase):
  """Checks `counter.Counter` with a positive and a negative step."""

  def testCounter(self):
    """Reads three elements from an ascending and a descending counter."""
    ascending = counter.Counter(start=3, step=4)
    get_next = ascending.make_one_shot_iterator().get_next()
    # Each element is a scalar int64.
    self.assertEqual([], get_next.shape.as_list())
    self.assertEqual(dtypes.int64, get_next.dtype)

    descending = counter.Counter(start=0, step=-1)
    negative_get_next = descending.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      # start=3, step=4 -> 3, 7, 11.
      for multiple in range(3):
        self.assertEqual(3 + multiple * 4, sess.run(get_next))

      # start=0, step=-1 -> 0, -1, -2.
      for expected_value in (0, -1, -2):
        self.assertEqual(expected_value, sess.run(negative_get_next))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for CsvDatasetOp."""
|
||||
"""Tests for `tf.data.experimental.CsvDataset`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -44,7 +44,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class CsvDatasetOpTest(test_base.DatasetTestBase):
|
||||
class CsvDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
|
||||
filenames = []
|
@ -1,692 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Base class for testing serializable datasets."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def remove_variants(get_next_op):
  """Replaces each variant-dtype tensor in a nest with an empty tuple.

  Every other component is passed through unchanged, so the result can be
  handed to `sess.run`.

  Args:
    get_next_op: A (possibly nested) structure of iterator outputs.

  Returns:
    The same structure with each `ops.Tensor` of dtype `dtypes.variant`
    replaced by `()`.
  """
  # TODO(b/72408568): Remove this once session.run can get
  # variant tensors.

  def _strip(component):
    is_variant = (
        isinstance(component, ops.Tensor) and
        component.dtype == dtypes.variant)
    return () if is_variant else component

  return nest.map_structure(_strip, get_next_op)
|
||||
|
||||
|
||||
class DatasetSerializationTestBase(test.TestCase):
  """Base class for testing serializable datasets."""

  def tearDown(self):
    """Removes the checkpoint files written by the test."""
    self._delete_ckpt()

  # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
  # (deprecated) saveable `SparseTensorSliceDataset`, once the API
  # `from_sparse_tensor_slices()` and related tests are deleted.
  def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
    """Runs the core tests.

    Args:
      ds_fn1: 0-argument function that returns a Dataset.
      ds_fn2: 0-argument function that returns a Dataset different from
        ds_fn1. If None, verify_restore_in_modified_graph test is not run.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    self.verify_unused_iterator(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_fully_used_iterator(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_exhausted_iterator(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_init_before_restore(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_multiple_breaks(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_reset_restored_iterator(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    self.verify_restore_in_empty_graph(
        ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
    if ds_fn2:
      self.verify_restore_in_modified_graph(
          ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    # A single break point at 0 checkpoints before any element is produced.
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self, ds_fn, num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced but does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    returned.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one which has returned an OutOfRange error.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """
    # Exhaust the iterator; `gen_outputs` saves a checkpoint at the end.
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    # Restoring that checkpoint must yield no further elements.
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertEqual(len(actual), 0)

  def verify_init_before_restore(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False,
                                 verify_exhausted=True):
    """Verifies that restoring into an already initialized iterator works.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs),
        num_outputs,
        init_before_restore=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      num_breaks: The number of break points. These are uniformly spread in
        [0, num_outputs] both inclusive.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2. Any
        falsy value (including 0) is treated as unset.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_restore_in_modified_graph(self,
                                       ds_fn1,
                                       ds_fn2,
                                       num_outputs,
                                       break_point=None,
                                       sparse_tensors=False,
                                       verify_exhausted=True):
    """Attempts to restore an iterator in a modified graph.

    Builds an input pipeline using ds_fn1, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new graph using ds_fn2, restores
    the checkpoint from ds_fn1 and verifies that the restore is successful.

    Args:
      ds_fn1: See `run_core_tests`.
      ds_fn2: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2. Any
        falsy value (including 0) is treated as unset.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Skip `break_point` items and store the remaining produced from ds_fn1
    # in `expected`.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn1, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn1 and save checkpoint.
    self.gen_outputs(
        ds_fn1, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build graph for ds_fn2 but load checkpoint for ds_fn1.
    with ops.Graph().as_default() as g:
      _, get_next_op, saver = self._build_graph(
          ds_fn2, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_restore_in_empty_graph(self,
                                    ds_fn,
                                    num_outputs,
                                    break_point=None,
                                    sparse_tensors=False,
                                    verify_exhausted=True):
    """Attempts to restore an iterator in an empty graph.

    Builds an input pipeline using ds_fn, runs it for `break_point` steps
    and saves a checkpoint. Then builds a new empty graph, restores
    the checkpoint from ds_fn and verifies that the restore is successful.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      break_point: Break point. Optional. Defaults to num_outputs/2. Any
        falsy value (including 0) is treated as unset.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point

    # Skip `break_point` items and store the remaining produced from ds_fn
    # in `expected`.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs - break_point,
        ckpt_saved=True,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Generate `break_point` items from ds_fn and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Build an empty graph but load checkpoint for ds_fn.
    with ops.Graph().as_default() as g:
      get_next_op, saver = self._build_empty_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._restore(saver, sess)
        for _ in range(num_outputs - break_point):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)

    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: See `run_core_tests`.
      num_outputs: See `run_core_tests`.
      error: Declared error when trying to save iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2. Any
        falsy value (including 0) is treated as unset.
      sparse_tensors: See `run_core_tests`.

    Raises:
      AssertionError if any test fails.
    """

    break_point = num_outputs // 2 if not break_point else break_point
    with ops.Graph().as_default() as g:
      init_op, get_next_op, saver = self._build_graph(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        for _ in range(break_point):
          sess.run(get_next_op)
        with self.assertRaises(error):
          self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             init_before_restore=False,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       with stopping at break points.

    Deep matches outputs from 1 and 2.

    Args:
      ds_fn: See `gen_outputs`.
      break_points: See `gen_outputs`.
      num_outputs: See `gen_outputs`.
      init_before_restore: See `gen_outputs`.
      sparse_tensors: See `run_core_tests`.
      verify_exhausted: See `gen_outputs`.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        init_before_restore=init_before_restore,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  init_before_restore=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, we produce outputs till `break_point` number of items
        have been produced and then checkpoint the state. The current graph
        and session are destroyed and a new graph and session are used to
        produce outputs till next checkpoint or till `num_outputs` elements
        have been produced. `break_point` must be <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists. If False, we build the
        graph from ds_fn.
      init_before_restore: Whether init should be called before saver.restore.
        This is just so that we can verify that restoring an already initialized
        iterator works.
      sparse_tensors:  Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved each break point but not at the
        end. Note that checkpoints overwrite each other so there is always only
        a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []

    def get_ops():
      # Reads the enclosing `ckpt_saved`, which flips to True after the first
      # save below; later segments then import the saved meta graph instead of
      # rebuilding from `ds_fn`.
      if ckpt_saved:
        saver = self._import_meta_graph()
        init_op, get_next_op = self._get_iterator_ops_from_collection(
            ds_fn, sparse_tensors=sparse_tensors)
      else:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
      return init_op, get_next_op, saver

    # One segment per break point, plus a final segment up to `num_outputs`.
    for i in range(len(break_points) + 1):
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = get_ops()
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          if ckpt_saved:
            if init_before_restore:
              self._initialize(init_op, sess)
            self._restore(saver, sess)
          else:
            self._initialize(init_op, sess)
          start = break_points[i - 1] if i > 0 else 0
          end = break_points[i] if i < len(break_points) else num_outputs
          num_iters = end - start
          for _ in range(num_iters):
            outputs.append(sess.run(get_next_op))
          if i == len(break_points) and verify_exhausted:
            with self.assertRaises(errors.OutOfRangeError):
              sess.run(get_next_op)
          if save_checkpoint_at_end or i < len(break_points):
            self._save(sess, saver)
            ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches shape and values of `expected` and `actual`.
    Handles scalars, numpy arrays and other python sequence containers
    e.g. list, dict.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    # Compare numpy arrays as plain (nested) lists.
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_sequence(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        # Pair keys up in sorted order so both dicts are walked consistently.
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    """Asserts that `match(expected, actual)` raises an AssertionError."""
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` breaks points in [0, num_outputs]."""
    return np.linspace(0, num_outputs, num_samples, dtype=int)

  def _build_graph(self, ds_fn, sparse_tensors=False):
    """Builds an initializable, saveable iterator over `ds_fn()`.

    The iterator's saveable is added to the SAVEABLE_OBJECTS collection and
    its ops are recorded in the "iterator_ops" collection.

    Args:
      ds_fn: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Returns:
      A tuple `(init_op, get_next, saver)`.
    """
    iterator = ds_fn().make_initializable_iterator()

    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _build_empty_graph(self, ds_fn, sparse_tensors=False):
    """Builds a saveable iterator from the structure of `ds_fn()` only.

    The iterator is created with `Iterator.from_structure` using the output
    types, shapes and classes of `ds_fn()`; no initializer is returned.

    Args:
      ds_fn: See `run_core_tests`.
      sparse_tensors: See `run_core_tests`.

    Returns:
      A tuple `(get_next, saver)`.
    """
    iterator = iterator_ops.Iterator.from_structure(
        self._get_output_types(ds_fn),
        output_shapes=self._get_output_shapes(ds_fn),
        output_classes=self._get_output_classes(ds_fn))
    saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    saver = saver_lib.Saver(allow_empty=True)
    return get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    """Records `init_op` and the flattened `get_next` in "iterator_ops".

    `init_op` is always added first. SparseTensor components are added as
    three entries: indices, values, dense_shape.
    """
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
    # do not support tuples we flatten the tensors and restore the shape in
    # `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    """Rebuilds `(init_op, get_next)` from the "iterator_ops" collection.

    Inverse of `_add_iterator_ops_to_collection`: entry 0 is the init op and
    the remaining entries are repacked into the structure of
    `ds_fn().output_types`, regrouping SparseTensor triples.
    """
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    # Index 0 holds the init op, so components start at 1.
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    """Returns `ds_fn().output_types`, built inside a throwaway graph."""
    with ops.Graph().as_default():
      return ds_fn().output_types

  def _get_output_shapes(self, ds_fn):
    """Returns `ds_fn().output_shapes`, built inside a throwaway graph."""
    with ops.Graph().as_default():
      return ds_fn().output_shapes

  def _get_output_classes(self, ds_fn):
    """Returns `ds_fn().output_classes`, built inside a throwaway graph."""
    with ops.Graph().as_default():
      return ds_fn().output_classes

  def _ckpt_path(self):
    """Returns the checkpoint path prefix inside the test's temp dir."""
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    """Returns the latest checkpoint found in the test's temp dir."""
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    """Saves the session state to `_ckpt_path()`."""
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    """Runs the table initializer, then restores the latest checkpoint."""
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    """Runs the variable and table initializers, then `init_op`."""
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    """Imports the `.meta` graph saved next to the checkpoint."""
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    """Deletes every file whose path starts with the checkpoint prefix."""
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    # Use an explicit loop: under Python 3 `map` is lazy, so a bare
    # `map(gfile.Remove, files)` would never delete anything.
    for path in files:
      gfile.Remove(path)
|
@ -0,0 +1,124 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.dense_to_sparse_batch()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class DenseToSparseBatchTest(test_base.DatasetTestBase):
|
||||
|
||||
def testDenseToSparseBatchDataset(self):
  """Sparse-batches variable-length vectors four rows at a time."""
  batch_size = 4
  components = np.random.randint(12, size=(100,)).astype(np.int32)
  dataset = dataset_ops.Dataset.from_tensor_slices(components)
  dataset = dataset.map(lambda x: array_ops.fill([x], x))
  dataset = dataset.apply(batching.dense_to_sparse_batch(batch_size, [12]))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)

    for start in range(0, len(components), batch_size):
      results = sess.run(get_next)
      rows = components[start:start + batch_size]
      # Row `i` was filled with `c` copies of `c`, so it has entries in
      # columns 0..c-1.
      expected_indices = [[i, j] for i, c in enumerate(rows) for j in range(c)]
      expected_values = [c for c in rows for _ in range(c)]
      expected_shape = [min(batch_size, len(components) - start), 12]
      self.assertAllEqual(expected_indices, results.indices)
      self.assertAllEqual(expected_values, results.values)
      self.assertAllEqual(expected_shape, results.dense_shape)

    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetWithUnknownShape(self):
  """Sparse-batches `c x c` matrices with a `[5, None]` row shape."""
  batch_size = 4
  components = np.random.randint(5, size=(40,)).astype(np.int32)
  dataset = dataset_ops.Dataset.from_tensor_slices(components)
  dataset = dataset.map(lambda x: array_ops.fill([x, x], x))
  dataset = dataset.apply(
      batching.dense_to_sparse_batch(batch_size, [5, None]))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)

    for start in range(0, len(components), batch_size):
      results = sess.run(get_next)
      rows = components[start:start + batch_size]
      expected_indices = [[i, j, z]
                          for i, c in enumerate(rows)
                          for j in range(c)
                          for z in range(c)]
      expected_values = [
          c for c in rows for _ in range(c) for _ in range(c)
      ]
      # The last dimension is expected to equal the largest value in the
      # batch.
      expected_shape = [
          min(batch_size, len(components) - start), 5,
          np.max(rows)
      ]
      self.assertAllEqual(expected_indices, results.indices)
      self.assertAllEqual(expected_values, results.values)
      self.assertAllEqual(expected_shape, results.dense_shape)

    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetWithInvalidShape(self):
  """Expects a `ValueError` when the row shape has a negative dimension."""
  input_tensor = array_ops.constant([[1]])
  with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
    dataset = dataset_ops.Dataset.from_tensors(input_tensor)
    dataset = dataset.apply(batching.dense_to_sparse_batch(4, [-2]))
    dataset.make_initializable_iterator()
|
||||
|
||||
def testDenseToSparseBatchDatasetShapeErrors(self):
  """Checks the errors raised when the fed input does not fit `[12]`."""
  input_tensor = array_ops.placeholder(dtypes.int32)
  dataset = dataset_ops.Dataset.from_tensors(input_tensor)
  dataset = dataset.apply(batching.dense_to_sparse_batch(4, [12]))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  failure_cases = (
      # An input tensor of incompatible rank.
      ([[1]], errors.InvalidArgumentError, "incompatible with the row shape"),
      # An input tensor that is larger than `row_shape`.
      (range(13), errors.DataLossError, "larger than the row shape"),
  )

  with self.cached_session() as sess:
    for fed_value, error_type, message in failure_cases:
      sess.run(init_op, feed_dict={input_tensor: fed_value})
      with self.assertRaisesRegexp(error_type, message):
        sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -12,12 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test RangeDataset."""
|
||||
"""Tests for `tf.data.experimental.enumerate_dataset()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.ops import counter
|
||||
from tensorflow.python.data.experimental.ops import enumerate_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -28,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RangeDatasetTest(test_base.DatasetTestBase):
|
||||
class EnumerateDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
def testEnumerateDataset(self):
|
||||
components = (["a", "b"], [1, 2], [37.0, 38])
|
||||
@ -52,27 +51,6 @@ class RangeDatasetTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testCounter(self):
  """Checks `Counter` output shape, dtype and its first three values."""
  ascending = counter.Counter(start=3, step=4).make_one_shot_iterator()
  get_next = ascending.get_next()
  self.assertEqual([], get_next.shape.as_list())
  self.assertEqual(dtypes.int64, get_next.dtype)

  descending = counter.Counter(start=0, step=-1).make_one_shot_iterator()
  negative_get_next = descending.get_next()

  with self.cached_session() as sess:
    for expected in (3, 3 + 4, 3 + 2 * 4):
      self.assertEqual(expected, sess.run(get_next))

    for expected in (0, -1, -2):
      self.assertEqual(expected, sess.run(negative_get_next))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,247 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the private `FunctionBufferingResource` used in prefetching."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionBufferingResourceTest(test_base.DatasetTestBase):
|
||||
|
||||
def setUp(self):
  """Runs the base-class fixture setup, then creates `self._event`.

  `self._event` is set by the generator in `_create_ds_and_iterator` once it
  is resumed after yielding element 6; tests wait on it before fetching the
  fifth element.
  """
  # Overriding `setUp` without chaining up silently skips whatever fixture
  # initialization the base test case performs.
  super(FunctionBufferingResourceTest, self).setUp()
  self._event = threading.Event()
|
||||
|
||||
def _create_ds_and_iterator(self, device0, initializable=False):
  """Builds a generator-backed dataset and an iterator on `device0`.

  The generator yields `[1.0]` .. `[9.0]` and sets `self._event` when it is
  resumed after yielding `[6.0]`.

  Args:
    device0: Device string under which the dataset and iterator are created.
    initializable: If true, make an initializable iterator; otherwise a
      one-shot iterator.

  Returns:
    A `(dataset, iterator)` tuple.
  """

  def gen():
    for value in range(1, 10):
      yield [float(value)]
      if value == 6:
        self._event.set()

  with ops.device(device0):
    ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
    make_iterator = (
        ds.make_initializable_iterator
        if initializable else ds.make_one_shot_iterator)
    ds_iterator = make_iterator()
  return (ds, ds_iterator)
|
||||
|
||||
def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
  """Creates the function buffering resource and the ops that drive it.

  Args:
    ds: Dataset whose output types and shapes describe the remote iterator.
    ds_iterator: Iterator whose string handle is passed to the remote function.
    buffer_name: Shared name given to the buffering resource.
    device0: Device string passed to the resource as its target device.
    device1: Device string under which the resource and its ops are created.

  Returns:
    A `(prefetch_op, reset_op, destroy_op)` tuple.
  """
  iterator_handle = ds_iterator.string_handle()

  @function.Defun(dtypes.string)
  def _remote_fn(h):
    remote_iterator = iterator_ops.Iterator.from_string_handle(
        h, ds.output_types, ds.output_shapes)
    return remote_iterator.get_next()

  target = constant_op.constant(device0)
  with ops.device(device1):
    resource = prefetching_ops.function_buffering_resource(
        f=_remote_fn,
        output_types=[dtypes.float32],
        target_device=target,
        string_arg=iterator_handle,
        buffer_size=3,
        shared_name=buffer_name)
    prefetch_op = prefetching_ops.function_buffering_resource_get_next(
        function_buffer_resource=resource, output_types=[dtypes.float32])
    reset_op = prefetching_ops.function_buffering_resource_reset(
        function_buffer_resource=resource)
    destroy_op = resource_variable_ops.destroy_resource_op(
        resource, ignore_lookup_error=True)

  return (prefetch_op, reset_op, destroy_op)
|
||||
|
||||
def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
  """Reads the first five elements through the buffer, then destroys it.

  Args:
    buffer_name: Shared name for the function buffering resource.
    device0: Device string for the dataset and iterator.
    device1: Device string for the buffering resource and its ops.
  """
  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})

  ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
  prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
                                                device0, device1)

  with self.test_session(config=worker_config) as sess:
    for expected in (1.0, 2.0, 3.0, 4.0):
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [expected])
    # Block until the generator has signalled the event before reading the
    # fifth element.
    self._event.wait()
    elem = sess.run(prefetch_op)
    self.assertEqual(elem, [5.0])
    sess.run(destroy_op)
|
||||
|
||||
def testSameDeviceCPU(self):
  """Runs the one-shot helper with source and buffer both on cpu:0."""
  cpu0 = "/job:localhost/replica:0/task:0/cpu:0"
  self._prefetch_fn_helper_one_shot("same_device_cpu", cpu0, cpu0)
|
||||
|
||||
def testDifferentDeviceCPU(self):
  """Runs the one-shot helper with the source on cpu:0, buffer on cpu:1."""
  source_device = "/job:localhost/replica:0/task:0/cpu:0"
  buffer_device = "/job:localhost/replica:0/task:0/cpu:1"
  self._prefetch_fn_helper_one_shot("diff_device_cpu", source_device,
                                    buffer_device)
|
||||
|
||||
def testDifferentDeviceCPUGPU(self):
  """Runs the one-shot helper from cpu:0 to gpu:0; skipped without a GPU."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  source_device = "/job:localhost/replica:0/task:0/cpu:0"
  buffer_device = "/job:localhost/replica:0/task:0/gpu:0"
  self._prefetch_fn_helper_one_shot("cpu_gpu", source_device, buffer_device)
|
||||
|
||||
def testReinitialization(self):
  """Reads five elements, resets and re-initializes, then reads them again."""
  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})

  device0 = "/job:localhost/replica:0/task:0/cpu:0"
  device1 = "/job:localhost/replica:0/task:0/cpu:1"
  ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
  prefetch_op, reset_op, destroy_op = self._create_ops(
      ds, ds_iterator, "reinit", device0, device1)

  with self.test_session(config=worker_config) as sess:
    for attempt in range(2):
      if attempt > 0:
        # Reset the function buffering resource before re-initializing the
        # iterator; the same sequence should then be readable again.
        self._event.clear()
        sess.run(reset_op)
      sess.run(ds_iterator.initializer)
      for expected in (1.0, 2.0, 3.0, 4.0):
        elem = sess.run(prefetch_op)
        self.assertEqual(elem, [expected])
      self._event.wait()
      elem = sess.run(prefetch_op)
      self.assertEqual(elem, [5.0])
    sess.run(destroy_op)
|
||||
|
||||
def testReinitializationOutOfRange(self):
  """Exhausts the buffer, resets and re-initializes, then exhausts it again."""
  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})

  device0 = "/job:localhost/replica:0/task:0/cpu:0"
  device1 = "/job:localhost/replica:0/task:0/cpu:1"
  ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
  prefetch_op, reset_op, destroy_op = self._create_ops(
      ds, ds_iterator, "reinit", device0, device1)

  with self.test_session(config=worker_config) as sess:
    for attempt in range(2):
      if attempt > 0:
        # Reset everything and go through the whole sequence once more.
        self._event.clear()
        sess.run(reset_op)
      sess.run(ds_iterator.initializer)
      for i in range(1, 10):
        elem = sess.run(prefetch_op)
        self.assertEqual(elem, [float(i)])
      # Fetch twice after the sequence is over to exercise end-of-sequence
      # handling.
      for _ in range(2):
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(prefetch_op)

    sess.run(destroy_op)
|
||||
|
||||
def testStringsGPU(self):
  """Prefetches string elements from cpu:0 to gpu:0; skipped without a GPU."""
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  device0 = "/job:localhost/replica:0/task:0/cpu:0"
  device1 = "/job:localhost/replica:0/task:0/gpu:0"

  ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
  ds_iterator = ds.make_one_shot_iterator()
  iterator_handle = ds_iterator.string_handle()

  @function.Defun(dtypes.string)
  def _remote_fn(h):
    remote_iterator = iterator_ops.Iterator.from_string_handle(
        h, ds.output_types, ds.output_shapes)
    return remote_iterator.get_next()

  target = constant_op.constant(device0)
  with ops.device(device1):
    resource = prefetching_ops.function_buffering_resource(
        f=_remote_fn,
        output_types=[dtypes.string],
        target_device=target,
        string_arg=iterator_handle,
        buffer_size=3,
        shared_name="strings")
    prefetch_op = prefetching_ops.function_buffering_resource_get_next(
        function_buffer_resource=resource, output_types=[dtypes.string])
    destroy_op = resource_variable_ops.destroy_resource_op(
        resource, ignore_lookup_error=True)

  with self.cached_session() as sess:
    for expected in ([b"a"], [b"b"], [b"c"]):
      self.assertEqual(expected, sess.run(prefetch_op))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(prefetch_op)

    sess.run(destroy_op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,199 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.group_by_reducer()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class GroupByReducerTest(test_base.DatasetTestBase):
|
||||
|
||||
def checkResults(self, dataset, shapes, values):
  """Asserts that `dataset` has `shapes` and yields exactly `values`.

  Args:
    dataset: Dataset under test.
    shapes: Expected value of `dataset.output_shapes`.
    values: Expected elements, in order; the iterator must then be exhausted.
  """
  self.assertEqual(shapes, dataset.output_shapes)
  iterator = dataset.make_one_shot_iterator()
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    for expected in values:
      self.assertEqual(sess.run(get_next), expected)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
|
||||
|
||||
def testSum(self):
  """Sums `range(2 * i)` grouped by parity, for i = 1..10."""
  reducer = grouping.Reducer(
      init_func=lambda _: np.int64(0),
      reduce_func=lambda x, y: x + y,
      finalize_func=lambda x: x)
  for i in range(1, 11):
    grouped = grouping.group_by_reducer(lambda x: x % 2, reducer)
    dataset = dataset_ops.Dataset.range(2 * i).apply(grouped)
    # Even numbers below 2 * i sum to (i - 1) * i; odd numbers sum to i * i.
    expected_sums = [(i - 1) * i, i * i]
    self.checkResults(
        dataset, shapes=tensor_shape.scalar(), values=expected_sums)
|
||||
|
||||
def testAverage(self):
  """Keeps a running `(mean, count)` state per parity group."""

  def reduce_fn(x, y):
    # x[0] is the running mean and x[1] the element count so far.
    weighted_sum = x[0] * x[1] + math_ops.cast(y, dtypes.float32)
    return weighted_sum / (x[1] + 1), x[1] + 1

  reducer = grouping.Reducer(
      init_func=lambda _: (0.0, 0.0),
      reduce_func=reduce_fn,
      finalize_func=lambda x, _: x)
  for i in range(1, 11):
    grouped = grouping.group_by_reducer(
        lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)
    dataset = dataset_ops.Dataset.range(2 * i).apply(grouped)
    expected_means = [i - 1, i]
    self.checkResults(
        dataset, shapes=tensor_shape.scalar(), values=expected_means)
|
||||
|
||||
def testConcat(self):
  """Concatenates letters into one string per parity of the paired index."""
  components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
  reducer = grouping.Reducer(
      init_func=lambda x: "",
      reduce_func=lambda x, y: x + y[0],
      finalize_func=lambda x: x)
  for i in range(1, 11):
    zipped = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.from_tensor_slices(components),
         dataset_ops.Dataset.range(2 * i)))
    dataset = zipped.apply(
        grouping.group_by_reducer(lambda x, y: y % 2, reducer))
    expected_strings = [b"acegikmoqs"[:i], b"bdfhjlnprt"[:i]]
    self.checkResults(
        dataset, shapes=tensor_shape.scalar(), values=expected_strings)
|
||||
|
||||
def testSparseSum(self):
  """Sums by parity with every element and state wrapped as a sparse value."""

  def _sparse(i):
    # A 1x1 sparse value holding `i` at position [0, 0].
    return sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0]]),
        values=(i * np.array([1], dtype=np.int64)),
        dense_shape=np.array([1, 1]))

  reducer = grouping.Reducer(
      init_func=lambda _: _sparse(np.int64(0)),
      reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
      finalize_func=lambda x: x.values[0])
  for i in range(1, 11):
    sparse_elements = dataset_ops.Dataset.range(2 * i).map(_sparse)
    dataset = sparse_elements.apply(
        grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
    expected_sums = [(i - 1) * i, i * i]
    self.checkResults(
        dataset, shapes=tensor_shape.scalar(), values=expected_sums)
|
||||
|
||||
def testChangingStateShape(self):
  """Reduces with a state whose shape changes on every step."""

  def reduce_fn(x, _):
    # Statically known rank, but dynamic length.
    doubled = array_ops.concat([x[0], x[0]], 0)
    # Statically unknown rank.
    wrapped = array_ops.expand_dims(x[1], 0)
    return doubled, wrapped

  reducer = grouping.Reducer(
      init_func=lambda x: ([0], 1),
      reduce_func=reduce_fn,
      finalize_func=lambda x, y: (x, y))

  for i in range(1, 11):
    repeated = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i)
    dataset = repeated.apply(grouping.group_by_reducer(lambda x: x, reducer))
    first_shape, second_shape = dataset.output_shapes[0], dataset.output_shapes[1]
    self.assertEqual([None], first_shape.as_list())
    self.assertIs(None, second_shape.ndims)
    iterator = dataset.make_one_shot_iterator()
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      x, y = sess.run(get_next)
      # After `i` steps the first component has doubled `i` times and the
      # second has gained `i` leading dimensions.
      self.assertAllEqual([0] * (2**i), x)
      self.assertAllEqual(np.array(1, ndmin=i), y)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
|
||||
|
||||
def testTypeMismatch(self):
  """Expects a `TypeError` when the reduced state changes dtype."""
  reducer = grouping.Reducer(
      init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
      reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
      finalize_func=lambda x: x)

  dataset = dataset_ops.Dataset.range(10)
  expected_message = (
      "The element types for the new state must match the initial state.")
  with self.assertRaisesRegexp(TypeError, expected_message):
    dataset.apply(grouping.group_by_reducer(lambda _: np.int64(0), reducer))
|
||||
|
||||
# TODO(b/78665031): Remove once non-scalar keys are supported.
|
||||
def testInvalidKeyShape(self):
  """Expects a `ValueError` when `key_func` returns a non-scalar key."""
  reducer = grouping.Reducer(
      init_func=lambda x: np.int64(0),
      reduce_func=lambda x, y: x + y,
      finalize_func=lambda x: x)

  dataset = dataset_ops.Dataset.range(10)
  expected_message = "`key_func` must return a single tf.int64 tensor."
  with self.assertRaisesRegexp(ValueError, expected_message):
    dataset.apply(
        grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
|
||||
|
||||
# TODO(b/78665031): Remove once non-int64 keys are supported.
|
||||
def testInvalidKeyType(self):
  """Expects a `ValueError` when `key_func` returns a string key."""
  reducer = grouping.Reducer(
      init_func=lambda x: np.int64(0),
      reduce_func=lambda x, y: x + y,
      finalize_func=lambda x: x)

  dataset = dataset_ops.Dataset.range(10)
  expected_message = "`key_func` must return a single tf.int64 tensor."
  with self.assertRaisesRegexp(ValueError, expected_message):
    dataset.apply(grouping.group_by_reducer(lambda _: "wrong", reducer))
|
||||
|
||||
def testTuple(self):
  """Reduces tuple elements into a tuple state under a single key."""

  def init_fn(_):
    return np.array([], dtype=np.int64), np.int64(0)

  def reduce_fn(state, value):
    collected, running_sum = state
    first, second = value
    return array_ops.concat([collected, [first]], 0), running_sum + second

  def finalize_fn(s1, s2):
    return s1, s2

  reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
  zipped = dataset_ops.Dataset.zip(
      (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10)))
  dataset = zipped.apply(
      grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
  get_next = dataset.make_one_shot_iterator().get_next()
  with self.cached_session() as sess:
    x, y = sess.run(get_next)
    # The first component collects 0..9; the second sums them to 45.
    self.assertAllEqual(x, np.asarray(list(range(10))))
    self.assertEqual(y, 45)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,367 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.group_by_window()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
|
||||
# Currently, they use a constant batch size, though should be made to use a
|
||||
# different batch size per key.
|
||||
class GroupByWindowTest(test_base.DatasetTestBase):
|
||||
|
||||
def _dynamicPad(self, bucket, window, window_size):
  """Zips the bucket key with `window` padded-batched in groups of 32.

  Args:
    bucket: Key of the bucket; wrapped in a single-element dataset.
    window: Dataset of three-component elements belonging to the bucket.
    window_size: Unused.

  Returns:
    A dataset of `(bucket, padded batch)` pairs.
  """
  # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
  # generic form of padded_batch that pads every component
  # dynamically and does not rely on static shape information about
  # the arguments.
  key_dataset = dataset_ops.Dataset.from_tensors(bucket)
  padded_shapes = (tensor_shape.TensorShape([]),
                   tensor_shape.TensorShape([None]),
                   tensor_shape.TensorShape([3]))
  padded_window = window.padded_batch(32, padded_shapes)
  return dataset_ops.Dataset.zip((key_dataset, padded_window))
|
||||
|
||||
def testSingleBucket(self):
  """Groups 32 elements under one key and checks the padded batch."""

  def _map_fn(v):
    repeated = array_ops.fill([v], v)
    as_text = array_ops.fill([3], string_ops.as_string(v))
    return (v, repeated, as_text)

  source = dataset_ops.Dataset.from_tensor_slices(math_ops.range(32))
  input_dataset = source.map(_map_fn)

  bucketed_dataset = input_dataset.apply(
      grouping.group_by_window(
          lambda x, y, z: 0,
          lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

  iterator = bucketed_dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)

    which_bucket, bucketed_values = sess.run(get_next)

    self.assertEqual(0, which_bucket)

    expected_scalar_int = np.arange(32, dtype=np.int64)
    # Row `i` holds `i` copies of `i`, zero-padded to width 31.
    expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
    for row in range(32):
      expected_unk_int64[row, :row] = row
    expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

    expected_components = (expected_scalar_int, expected_unk_int64,
                           expected_vec3_str)
    for index, expected in enumerate(expected_components):
      self.assertAllEqual(expected, bucketed_values[index])
|
||||
|
||||
def testEvenOddBuckets(self):
  """Groups 0..63 by parity and checks both 32-element padded batches."""

  def _map_fn(v):
    repeated = array_ops.fill([v], v)
    as_text = array_ops.fill([3], string_ops.as_string(v))
    return (v, repeated, as_text)

  source = dataset_ops.Dataset.from_tensor_slices(math_ops.range(64))
  input_dataset = source.map(_map_fn)

  bucketed_dataset = input_dataset.apply(
      grouping.group_by_window(
          lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
          lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

  iterator = bucketed_dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)

    # Get two minibatches (one containing even values, one containing odds).
    which_bucket_even, bucketed_values_even = sess.run(get_next)
    which_bucket_odd, bucketed_values_odd = sess.run(get_next)

    # Each minibatch has three components.
    self.assertEqual(3, len(bucketed_values_even))
    self.assertEqual(3, len(bucketed_values_odd))

    # The first minibatch comes from bucket 0, the second from bucket 1.
    self.assertAllEqual(0, which_bucket_even)
    self.assertAllEqual(1, which_bucket_odd)

    # `parity` 0 checks the evens starting at 0; `parity` 1 checks the odds
    # starting at 1.
    batches_by_parity = ((0, bucketed_values_even), (1, bucketed_values_odd))
    for parity, bucketed_values in batches_by_parity:
      expected_scalar_int = np.arange(
          parity, 32 * 2 + parity, 2, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31 * 2 + parity)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i + parity] = 2 * i + parity
      expected_vec3_str = np.vstack(
          3 * [np.arange(parity, 32 * 2 + parity, 2).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values[2])
|
||||
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
  """Filters out odd values of 0..127 and checks only bucket 0 is emitted."""

  def _map_fn(v):
    return {
        "x": v,
        "y": array_ops.fill([v], v),
        "z": array_ops.fill([3], string_ops.as_string(v))
    }

  def _dynamic_pad_fn(bucket, window, _):
    key_dataset = dataset_ops.Dataset.from_tensors(bucket)
    padded_shapes = {
        "x": tensor_shape.TensorShape([]),
        "y": tensor_shape.TensorShape([None]),
        "z": tensor_shape.TensorShape([3])
    }
    padded_window = window.padded_batch(32, padded_shapes)
    return dataset_ops.Dataset.zip((key_dataset, padded_window))

  source = dataset_ops.Dataset.from_tensor_slices(math_ops.range(128))
  mapped = source.map(_map_fn)
  input_dataset = mapped.filter(lambda d: math_ops.equal(d["x"] % 2, 0))

  bucketed_dataset = input_dataset.apply(
      grouping.group_by_window(
          lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
          lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))

  iterator = bucketed_dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)

    # Get two minibatches ([0, 2, ...] and [64, 66, ...]).
    which_bucket0, bucketed_values_even0 = sess.run(get_next)
    which_bucket1, bucketed_values_even1 = sess.run(get_next)

    # Both minibatches must come from bucket 0: bucket 1 was filtered out.
    self.assertAllEqual(0, which_bucket0)
    self.assertAllEqual(0, which_bucket1)
    first_half = np.arange(0, 64, 2, dtype=np.int64)
    self.assertAllEqual(first_half, bucketed_values_even0["x"])
    second_half = np.arange(64, 128, 2, dtype=np.int64)
    self.assertAllEqual(second_half, bucketed_values_even1["x"])
|
||||
|
||||
def testDynamicWindowSize(self):
  components = np.arange(100).astype(np.int64)

  # Key fn: parity of the element (0 for even, 1 for odd).
  # Reduce fn: one batch per window (the batch size of 20 is larger than
  #   either window size, so a window is never split).
  # Window size fn: 5 for the even key, 10 for the odd key.

  def window_size_func(key):
    sizes_by_key = constant_op.constant([5, 10], dtype=dtypes.int64)
    return sizes_by_key[key]

  source = dataset_ops.Dataset.from_tensor_slices(components)
  dataset = source.apply(
      grouping.group_by_window(
          lambda x: x % 2,
          lambda _, xs: xs.batch(20),
          None,
          window_size_func))
  iterator = dataset.make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)
    batches = 0
    # The loop only terminates when the iterator is exhausted.
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        result = sess.run(get_next)
        remainders = [x % 2 for x in result]
        all_even = all(r == 0 for r in remainders)
        all_odd = all(r == 1 for r in remainders)
        self.assertTrue(all_even or all_odd)
        self.assertEqual(5 if all_even else 10, result.shape[0])
        batches += 1

    self.assertEqual(batches, 15)
|
||||
|
||||
def testSimple(self):
  # Groups squared random integers by parity into windows of four and checks
  # that every emitted batch is parity-pure and that no element is lost.
  components = np.random.randint(100, size=(200,)).astype(np.int64)
  iterator = (
      dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
      .apply(
          grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                   4)).make_initializable_iterator())
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)
    counts = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        result = sess.run(get_next)
        # Each `all(...)` must wrap its own generator. Previously a misplaced
        # parenthesis turned the whole argument into a generator expression,
        # which is always truthy, so this check could never fail.
        self.assertTrue(
            all(x % 2 == 0 for x in result) or
            all(x % 2 == 1 for x in result))
        counts.append(result.shape[0])

    # Every input element must appear in exactly one batch.
    self.assertEqual(len(components), sum(counts))
    num_full_batches = len([c for c in counts if c == 4])
    self.assertGreaterEqual(num_full_batches, 24)
    # All full batches must precede any partial one.
    self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
|
||||
|
||||
def testImmediateOutput(self):
  components = np.array(
      [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
  endless_source = dataset_ops.Dataset.from_tensor_slices(components).repeat(
      -1)
  iterator = endless_source.apply(
      grouping.group_by_window(
          lambda x: x % 3, lambda _, xs: xs.batch(4),
          4)).make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  # Batches expected, in order, from one pass over `components`.
  expected_per_pass = (
      [0, 0, 0, 0],
      [1, 1, 1, 1],
      [2, 2, 2, 2],
      [0, 0, 0, 0],
  )

  with self.cached_session() as sess:
    sess.run(init_op)
    # The input is infinite, so this test demonstrates that:
    # 1. We produce output without having to consume the entire input,
    # 2. Different buckets can produce output at different rates, and
    # 3. For deterministic input, the output is deterministic.
    for _ in range(3):
      for expected in expected_per_pass:
        self.assertAllEqual(expected, sess.run(get_next))
|
||||
|
||||
def testSmallGroups(self):
  components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
  source = dataset_ops.Dataset.from_tensor_slices(components)
  iterator = source.apply(
      grouping.group_by_window(
          lambda x: x % 2, lambda _, xs: xs.batch(4),
          4)).make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  # Two full windows come first. The small outputs at the end are
  # deterministically produced in key order.
  expected_batches = ([0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0], [1])

  with self.cached_session() as sess:
    sess.run(init_op)
    for expected in expected_batches:
      self.assertAllEqual(expected, sess.run(get_next))
|
||||
|
||||
def testEmpty(self):
  # A window size of zero must be rejected when the iterator is evaluated.
  source = dataset_ops.Dataset.range(4)
  zero_sized_windows = grouping.group_by_window(
      lambda _: 0, lambda _, xs: xs, 0)
  iterator = source.apply(zero_sized_windows).make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        "Window size must be greater than zero, but got 0."):
      print(sess.run(get_next))
|
||||
|
||||
def testReduceFuncError(self):
  components = np.random.randint(100, size=(200,)).astype(np.int64)

  def reduce_func(_, xs):
    # Introduce an incorrect padded shape that cannot (currently) be
    # detected at graph construction time.
    scalar_shape = tensor_shape.TensorShape([])
    negative_shape = constant_op.constant([5], dtype=dtypes.int64) * -1
    return xs.padded_batch(4, padded_shapes=(scalar_shape, negative_shape))

  paired = dataset_ops.Dataset.from_tensor_slices(components).map(
      lambda x: (x, ops.convert_to_tensor([x * x])))
  iterator = paired.apply(
      grouping.group_by_window(lambda x, _: x % 2, reduce_func,
                               32)).make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)
    # The bad shape only surfaces once the reduce function actually runs.
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(get_next)
|
||||
|
||||
def testConsumeWindowDatasetMoreThanOnce(self):
  components = np.random.randint(50, size=(200,)).astype(np.int64)

  def reduce_func(key, window):
    # Apply two different kinds of padding to the input: tight
    # padding, and quantized (to a multiple of 10) padding.
    tight = window.padded_batch(
        4, padded_shapes=tensor_shape.TensorShape([None]))
    quantized = window.padded_batch(
        4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10]))
    return dataset_ops.Dataset.zip((tight, quantized))

  # Every element x becomes a vector of length x filled with x; vectors are
  # then keyed by their length divided by 10.
  vectors = dataset_ops.Dataset.from_tensor_slices(components).map(
      lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
  iterator = vectors.apply(
      grouping.group_by_window(
          lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
          reduce_func, 4)).make_initializable_iterator()
  init_op = iterator.initializer
  get_next = iterator.get_next()

  with self.cached_session() as sess:
    sess.run(init_op)
    counts = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        tight_result, multiple_of_10_result = sess.run(get_next)
        tight_width = tight_result.shape[1]
        self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
        # The quantized batch must begin with the tightly padded batch.
        self.assertAllEqual(tight_result,
                            multiple_of_10_result[:, :tight_width])
        counts.append(tight_result.shape[0])
    self.assertEqual(len(components), sum(counts))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,115 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.ignore_errors()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import error_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_NUMPY_RANDOM_SEED = 42
|
||||
|
||||
|
||||
class IgnoreErrorsTest(test_base.DatasetTestBase):
  """Checks that `error_ops.ignore_errors()` skips elements that raise."""

  def testMapIgnoreError(self):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    # `check_numerics` fails on the NaN element; that element is dropped.
    checked = dataset_ops.Dataset.from_tensor_slices(components).map(
        lambda x: array_ops.check_numerics(x, "message"))
    dataset = checked.apply(error_ops.ignore_errors())
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      for expected in [1., 2., 3., 5.]:
        self.assertEqual(expected, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testParallelMapIgnoreError(self):
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    # Same scenario as above, but with a parallel map followed by prefetch.
    checked = dataset_ops.Dataset.from_tensor_slices(components).map(
        lambda x: array_ops.check_numerics(x, "message"),
        num_parallel_calls=2)
    dataset = checked.prefetch(2).apply(error_ops.ignore_errors())
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      for expected in [1., 2., 3., 5.]:
        self.assertEqual(expected, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadFileIgnoreError(self):

    def write_string_to_file(value, filename):
      with open(filename, "w") as f:
        f.write(value)

    temp_dir = self.get_temp_dir()
    filenames = []
    for i in range(5):
      filenames.append(os.path.join(temp_dir, "file_%d.txt" % i))

    # Each file holds its own path, so a read result identifies its source.
    for path in filenames:
      write_string_to_file(path, path)

    contents = dataset_ops.Dataset.from_tensor_slices(filenames).map(
        io_ops.read_file, num_parallel_calls=2)
    dataset = contents.prefetch(2).apply(error_ops.ignore_errors())
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      # All of the files are present.
      sess.run(init_op)
      for path in filenames:
        self.assertEqual(compat.as_bytes(path), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Delete one of the files.
      os.remove(filenames[0])

      # Attempting to read filenames[0] will fail, but ignore_errors()
      # will catch the error.
      sess.run(init_op)
      for path in filenames[1:]:
        self.assertEqual(compat.as_bytes(path), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,239 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
from tensorflow.python.data.ops import readers as core_readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MakeBatchedFeaturesDatasetTest(
    reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
  """Tests for datasets produced by the base class's `make_batch_feature`."""

  def testRead(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        # Each scenario is (filenames, extra positional args forwarded to
        # `verify_records`, whether a label key is requested):
        #   file 0 only, file 1 only, both files with a label key, and
        #   both files without one.
        scenarios = [
            (self.test_filenames[0], (0,), True),
            (self.test_filenames[1], (1,), True),
            (self.test_filenames, (), True),
            (self.test_filenames, (), False),
        ]
        for filenames, file_index_args, with_label in scenarios:
          make_kwargs = {"label_key": "label"} if with_label else {}
          check_kwargs = {"label_key_provided": True} if with_label else {}
          with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
              self.outputs = self.make_batch_feature(
                  filenames=filenames,
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  **make_kwargs).make_one_shot_iterator().get_next()
              self.verify_records(
                  sess,
                  batch_size,
                  *file_index_args,
                  num_epochs=num_epochs,
                  **check_kwargs)
              with self.assertRaises(errors.OutOfRangeError):
                self._next_actual_batch(sess, **check_kwargs)

  def testReadWithEquivalentDataset(self):
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }
    parsed = core_readers.TFRecordDataset(self.test_filenames).map(
        lambda x: parsing_ops.parse_single_example(x, features))
    dataset = parsed.repeat(10).batch(2)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      expected_batches = self._next_expected_batch(
          range(self._num_files), 2, 10)
      # Only the file and record components of each expected batch are used.
      for file_batch, _, _, _, record_batch, _ in expected_batches:
        actual_batch = sess.run(next_element)
        self.assertAllEqual(file_batch, actual_batch["file"])
        self.assertAllEqual(record_batch, actual_batch["record"])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testReadWithFusedShuffleRepeatDataset(self):
    num_epochs = 5
    total_records = num_epochs * self._num_records

    def shuffled_outputs(batch_size, seed):
      # Next-element structure for file 0, shuffled with the given seed.
      return self.make_batch_feature(
          filenames=self.test_filenames[0],
          num_epochs=num_epochs,
          batch_size=batch_size,
          shuffle=True,
          shuffle_seed=seed).make_one_shot_iterator().get_next()

    for batch_size in [1, 2]:
      # Test that shuffling with same seed produces the same result.
      with ops.Graph().as_default() as g:
        with self.session(graph=g) as sess:
          outputs1 = shuffled_outputs(batch_size, 5)
          outputs2 = shuffled_outputs(batch_size, 5)
          for _ in range(total_records // batch_size):
            batch1 = self._run_actual_batch(outputs1, sess)
            batch2 = self._run_actual_batch(outputs2, sess)
            for i, component in enumerate(batch1):
              self.assertAllEqual(component, batch2[i])

      # Test that shuffling with different seeds produces a different order.
      with ops.Graph().as_default() as g:
        with self.session(graph=g) as sess:
          outputs1 = shuffled_outputs(batch_size, 5)
          outputs2 = shuffled_outputs(batch_size, 15)
          all_equal = True
          for _ in range(total_records // batch_size):
            batch1 = self._run_actual_batch(outputs1, sess)
            batch2 = self._run_actual_batch(outputs2, sess)
            for i, component in enumerate(batch1):
              # Once a mismatch has been seen nothing more is compared.
              if all_equal and not np.array_equal(component, batch2[i]):
                all_equal = False
          self.assertFalse(all_equal)

  def testParallelReadersAndParsers(self):
    num_epochs = 5
    for batch_size in [1, 2]:
      for reader_num_threads in [2, 4]:
        for parser_num_threads in [2, 4]:
          # First with a label key, then without one.
          for with_label in (True, False):
            make_kwargs = {"label_key": "label"} if with_label else {}
            check_kwargs = {"label_key_provided": True} if with_label else {}
            with ops.Graph().as_default() as g:
              with self.session(graph=g) as sess:
                self.outputs = self.make_batch_feature(
                    filenames=self.test_filenames,
                    num_epochs=num_epochs,
                    batch_size=batch_size,
                    reader_num_threads=reader_num_threads,
                    parser_num_threads=parser_num_threads,
                    **make_kwargs).make_one_shot_iterator().get_next()
                self.verify_records(
                    sess,
                    batch_size,
                    num_epochs=num_epochs,
                    interleave_cycle_length=reader_num_threads,
                    **check_kwargs)
                with self.assertRaises(errors.OutOfRangeError):
                  self._next_actual_batch(sess, **check_kwargs)

  def testDropFinalBatch(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default():
          # Basic test: read from file 0.
          outputs = self.make_batch_feature(
              filenames=self.test_filenames[0],
              label_key="label",
              num_epochs=num_epochs,
              batch_size=batch_size,
              drop_final_batch=True).make_one_shot_iterator().get_next()
          dense_outputs = [
              t for t in nest.flatten(outputs)
              if isinstance(t, ops.Tensor)  # Guard against SparseTensor.
          ]
          for tensor in dense_outputs:
            self.assertEqual(tensor.shape[0], batch_size)

  def testIndefiniteRepeatShapeInference(self):
    dataset = self.make_batch_feature(
        filenames=self.test_filenames[0],
        label_key="label",
        num_epochs=None,
        batch_size=32)
    flat_shapes = nest.flatten(dataset.output_shapes)
    flat_classes = nest.flatten(dataset.output_classes)
    for shape, clazz in zip(flat_shapes, flat_classes):
      if not issubclass(clazz, ops.Tensor):
        continue
      self.assertEqual(32, shape[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.make_csv_dataset()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -23,226 +23,16 @@ import zlib
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import readers as core_readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ReadBatchFeaturesTest(
    reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
  """Tests for datasets produced by the base class's `make_batch_feature`."""

  def testRead(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        # Each scenario is (filenames, extra positional args forwarded to
        # `verify_records`, whether a label key is requested):
        #   file 0 only, file 1 only, both files with a label key, and
        #   both files without one.
        scenarios = [
            (self.test_filenames[0], (0,), True),
            (self.test_filenames[1], (1,), True),
            (self.test_filenames, (), True),
            (self.test_filenames, (), False),
        ]
        for filenames, file_index_args, with_label in scenarios:
          make_kwargs = {"label_key": "label"} if with_label else {}
          check_kwargs = {"label_key_provided": True} if with_label else {}
          with ops.Graph().as_default() as g:
            with self.session(graph=g) as sess:
              self.outputs = self.make_batch_feature(
                  filenames=filenames,
                  num_epochs=num_epochs,
                  batch_size=batch_size,
                  **make_kwargs).make_one_shot_iterator().get_next()
              self.verify_records(
                  sess,
                  batch_size,
                  *file_index_args,
                  num_epochs=num_epochs,
                  **check_kwargs)
              with self.assertRaises(errors.OutOfRangeError):
                self._next_actual_batch(sess, **check_kwargs)

  def testReadWithEquivalentDataset(self):
    features = {
        "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        "record": parsing_ops.FixedLenFeature([], dtypes.int64),
    }
    parsed = core_readers.TFRecordDataset(self.test_filenames).map(
        lambda x: parsing_ops.parse_single_example(x, features))
    dataset = parsed.repeat(10).batch(2)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      expected_batches = self._next_expected_batch(
          range(self._num_files), 2, 10)
      # Only the file and record components of each expected batch are used.
      for file_batch, _, _, _, record_batch, _ in expected_batches:
        actual_batch = sess.run(next_element)
        self.assertAllEqual(file_batch, actual_batch["file"])
        self.assertAllEqual(record_batch, actual_batch["record"])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testReadWithFusedShuffleRepeatDataset(self):
    num_epochs = 5
    total_records = num_epochs * self._num_records

    def shuffled_outputs(batch_size, seed):
      # Next-element structure for file 0, shuffled with the given seed.
      return self.make_batch_feature(
          filenames=self.test_filenames[0],
          num_epochs=num_epochs,
          batch_size=batch_size,
          shuffle=True,
          shuffle_seed=seed).make_one_shot_iterator().get_next()

    for batch_size in [1, 2]:
      # Test that shuffling with same seed produces the same result.
      with ops.Graph().as_default() as g:
        with self.session(graph=g) as sess:
          outputs1 = shuffled_outputs(batch_size, 5)
          outputs2 = shuffled_outputs(batch_size, 5)
          for _ in range(total_records // batch_size):
            batch1 = self._run_actual_batch(outputs1, sess)
            batch2 = self._run_actual_batch(outputs2, sess)
            for i, component in enumerate(batch1):
              self.assertAllEqual(component, batch2[i])

      # Test that shuffling with different seeds produces a different order.
      with ops.Graph().as_default() as g:
        with self.session(graph=g) as sess:
          outputs1 = shuffled_outputs(batch_size, 5)
          outputs2 = shuffled_outputs(batch_size, 15)
          all_equal = True
          for _ in range(total_records // batch_size):
            batch1 = self._run_actual_batch(outputs1, sess)
            batch2 = self._run_actual_batch(outputs2, sess)
            for i, component in enumerate(batch1):
              # Once a mismatch has been seen nothing more is compared.
              if all_equal and not np.array_equal(component, batch2[i]):
                all_equal = False
          self.assertFalse(all_equal)

  def testParallelReadersAndParsers(self):
    num_epochs = 5
    for batch_size in [1, 2]:
      for reader_num_threads in [2, 4]:
        for parser_num_threads in [2, 4]:
          # First with a label key, then without one.
          for with_label in (True, False):
            make_kwargs = {"label_key": "label"} if with_label else {}
            check_kwargs = {"label_key_provided": True} if with_label else {}
            with ops.Graph().as_default() as g:
              with self.session(graph=g) as sess:
                self.outputs = self.make_batch_feature(
                    filenames=self.test_filenames,
                    num_epochs=num_epochs,
                    batch_size=batch_size,
                    reader_num_threads=reader_num_threads,
                    parser_num_threads=parser_num_threads,
                    **make_kwargs).make_one_shot_iterator().get_next()
                self.verify_records(
                    sess,
                    batch_size,
                    num_epochs=num_epochs,
                    interleave_cycle_length=reader_num_threads,
                    **check_kwargs)
                with self.assertRaises(errors.OutOfRangeError):
                  self._next_actual_batch(sess, **check_kwargs)

  def testDropFinalBatch(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
        with ops.Graph().as_default():
          # Basic test: read from file 0.
          outputs = self.make_batch_feature(
              filenames=self.test_filenames[0],
              label_key="label",
              num_epochs=num_epochs,
              batch_size=batch_size,
              drop_final_batch=True).make_one_shot_iterator().get_next()
          dense_outputs = [
              t for t in nest.flatten(outputs)
              if isinstance(t, ops.Tensor)  # Guard against SparseTensor.
          ]
          for tensor in dense_outputs:
            self.assertEqual(tensor.shape[0], batch_size)

  def testIndefiniteRepeatShapeInference(self):
    dataset = self.make_batch_feature(
        filenames=self.test_filenames[0],
        label_key="label",
        num_epochs=None,
        batch_size=32)
    flat_shapes = nest.flatten(dataset.output_shapes)
    flat_classes = nest.flatten(dataset.output_classes)
    for shape, clazz in zip(flat_shapes, flat_classes):
      if not issubclass(clazz, ops.Tensor):
        continue
      self.assertEqual(32, shape[0])
|
||||
|
||||
|
||||
class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
|
||||
@ -866,218 +656,5 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
|
||||
self.assertEqual(32, shape[0])
|
||||
|
||||
|
||||
class MakeTFRecordDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase):
  """Checks `readers.make_tf_record_dataset()` against a Python reference.

  The helpers read `self.test_filenames`, `self._num_files`,
  `self._num_records` and `self._record()`; none of them are defined in this
  class (presumably they come from `TFRecordDatasetTestBase` -- confirm there).
  """

  def _interleave(self, iterators, cycle_length):
    """Yields elements round-robin from up to `cycle_length` open iterators.

    Args:
      iterators: Iterators to draw from, consumed in the given order.
      cycle_length: Maximum number of iterators kept open at the same time.

    Yields:
      The next element of each open iterator in turn. An exhausted iterator
      is replaced by the next pending one (which first yields on the
      following pass) or, when none is pending, its slot is retired.
    """
    # Copy so that `pop(0)` below does not consume the caller's list.
    pending_iterators = list(iterators)
    open_iterators = []
    num_open = 0
    for _ in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          # Slot was retired on an earlier pass.
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length,
                           drop_final_batch,
                           use_parser_fn):
    """Generates the batches the dataset under test is expected to produce.

    Args:
      file_indices: Indices of the files to read, in order.
      batch_size: Number of records per full batch.
      num_epochs: Number of passes over the files.
      cycle_length: Number of files interleaved at once; 1 means the files
        are read sequentially.
      drop_final_batch: Whether a trailing batch smaller than `batch_size`
        is dropped instead of yielded.
      use_parser_fn: Whether each record has its first element stripped,
        mirroring the `substr(x, 1, 999)` parser used by `_read_test`.

    Yields:
      Lists of records. Batches are not reset at epoch boundaries, so a
      batch may contain records from two consecutive epochs.
    """

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          record = record[1:]
        record_batch.append(record)
        if len(record_batch) == batch_size:
          yield record_batch
          record_batch = []
    if record_batch and not drop_final_batch:
      yield record_batch

  def _verify_records(self,
                      sess,
                      outputs,
                      batch_size,
                      file_index,
                      num_epochs,
                      interleave_cycle_length,
                      drop_final_batch,
                      use_parser_fn):
    """Runs `outputs` once per expected batch and compares the results.

    Args:
      sess: Session used to evaluate `outputs`.
      outputs: Fetch producing the next batch of the dataset under test.
      batch_size: Number of records per full batch.
      file_index: Index of the single file being read, or `None` to read all
        `self._num_files` files.
      num_epochs: Number of passes over the files.
      interleave_cycle_length: Cycle length forwarded to
        `_next_expected_batch`.
      drop_final_batch: Whether a trailing partial batch is dropped.
      use_parser_fn: Whether the parser function was applied to the records.
    """
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = sess.run(outputs)
      self.assertAllEqual(expected_batch, actual_batch)

  def _read_test(self, batch_size, num_epochs, file_index=None,
                 num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
    """Builds an unshuffled dataset and checks every batch it produces.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset()`.
      num_epochs: Number of epochs passed to `make_tf_record_dataset()`.
      file_index: Index into `self.test_filenames` of the single file to
        read, or `None` to read all of them.
      num_parallel_reads: Passed through to the dataset and used as the
        interleave cycle length of the reference implementation.
      drop_final_batch: Whether the trailing partial batch is dropped.
      parser_fn: Whether to apply a parser that drops the first character of
        each record.
    """
    if file_index is None:
      file_pattern = self.test_filenames
    else:
      file_pattern = self.test_filenames[file_index]

    if parser_fn:
      fn = lambda x: string_ops.substr(x, 1, 999)
    else:
      fn = None

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        outputs = readers.make_tf_record_dataset(
            file_pattern=file_pattern,
            num_epochs=num_epochs,
            batch_size=batch_size,
            parser_fn=fn,
            num_parallel_reads=num_parallel_reads,
            drop_final_batch=drop_final_batch,
            shuffle=False).make_one_shot_iterator().get_next()
        self._verify_records(
            sess, outputs, batch_size, file_index, num_epochs=num_epochs,
            interleave_cycle_length=num_parallel_reads,
            drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
        # Once every expected batch was consumed the dataset must be empty.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(outputs)

  def testRead(self):
    """Reads single files, all files, and all files with parallel reads."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        # Basic test: read from file 0.
        self._read_test(batch_size, num_epochs, 0)

        # Basic test: read from file 1.
        self._read_test(batch_size, num_epochs, 1)

        # Basic test: read from both files.
        self._read_test(batch_size, num_epochs)

        # Basic test: read from both files, with parallel reads.
        self._read_test(batch_size, num_epochs, num_parallel_reads=8)

  def testDropFinalBatch(self):
    """Repeats the read checks with `drop_final_batch=True`."""
    for batch_size in [1, 2, 10]:
      for num_epochs in [1, 3]:
        # Read from file 0.
        self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)

        # Read from both files.
        self._read_test(batch_size, num_epochs, drop_final_batch=True)

        # Read from both files, with parallel reads.
        self._read_test(batch_size, num_epochs, num_parallel_reads=8,
                        drop_final_batch=True)

  def testParserFn(self):
    """Repeats the read checks with a parser function applied."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        for drop_final_batch in [False, True]:
          self._read_test(batch_size, num_epochs, parser_fn=True,
                          drop_final_batch=drop_final_batch)
          self._read_test(batch_size, num_epochs, num_parallel_reads=8,
                          parser_fn=True, drop_final_batch=drop_final_batch)

  def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
                    seed=None):
    """Reads a shuffled dataset twice and checks the produced records.

    Both passes must produce the same number of batches and, ignoring order,
    exactly the expected records. When `seed` is given the two passes must
    also match batch for batch.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset()`.
      num_epochs: Number of epochs passed to `make_tf_record_dataset()`.
      num_parallel_reads: Passed through to `make_tf_record_dataset()`.
      seed: Shuffle seed, or `None` for an unseeded shuffle.
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        dataset = readers.make_tf_record_dataset(
            file_pattern=self.test_filenames,
            num_epochs=num_epochs,
            batch_size=batch_size,
            num_parallel_reads=num_parallel_reads,
            shuffle=True,
            shuffle_seed=seed)
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        def _drain():
          """Re-initializes the iterator and collects batches until the end."""
          sess.run(iterator.initializer)
          batches = []
          try:
            while True:
              batches.append(sess.run(next_element))
          except errors.OutOfRangeError:
            pass
          return batches

        first_batches = _drain()
        second_batches = _drain()

        self.assertEqual(len(first_batches), len(second_batches))
        if seed is not None:
          # if you set a seed, should get the same results
          for first, second in zip(first_batches, second_batches):
            self.assertAllEqual(first, second)

        expected = []
        for f in range(self._num_files):
          for r in range(self._num_records):
            expected.extend([self._record(f, r)] * num_epochs)
        expected_sorted = sorted(expected)

        for batches in (first_batches, second_batches):
          actual = []
          for b in batches:
            actual.extend(b)
          self.assertAllEqual(expected_sorted, sorted(actual))

  def testShuffle(self):
    """Runs the shuffle check with and without a fixed seed."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        for num_parallel_reads in [1, 2]:
          # Test that all expected elements are produced
          self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
          # Test that elements are produced in a consistent order if
          # you specify a seed.
          self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
                             seed=21345)

  def testIndefiniteRepeatShapeInference(self):
    """Checks the static batch dimension is 32 when `num_epochs` is `None`."""
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
    for shape in nest.flatten(dataset.output_shapes):
      self.assertEqual(32, shape[0])
|
||||
|
||||
|
||||
# Run the test cases of this module only when it is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -0,0 +1,243 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.make_tf_record_dataset()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MakeTFRecordDatasetTest(
    reader_dataset_ops_test_base.TFRecordDatasetTestBase):
  """Checks `readers.make_tf_record_dataset()` against a Python reference.

  The helpers read `self.test_filenames`, `self._num_files`,
  `self._num_records` and `self._record()`; none of them are defined in this
  class (presumably they come from `TFRecordDatasetTestBase` -- confirm there).
  """

  def _interleave(self, iterators, cycle_length):
    """Yields elements round-robin from up to `cycle_length` open iterators.

    Args:
      iterators: Iterators to draw from, consumed in the given order.
      cycle_length: Maximum number of iterators kept open at the same time.

    Yields:
      The next element of each open iterator in turn. An exhausted iterator
      is replaced by the next pending one (which first yields on the
      following pass) or, when none is pending, its slot is retired.
    """
    # Copy so that `pop(0)` below does not consume the caller's list.
    pending_iterators = list(iterators)
    open_iterators = []
    num_open = 0
    for _ in range(cycle_length):
      if pending_iterators:
        open_iterators.append(pending_iterators.pop(0))
        num_open += 1

    while num_open:
      for i in range(min(cycle_length, len(open_iterators))):
        if open_iterators[i] is None:
          # Slot was retired on an earlier pass.
          continue
        try:
          yield next(open_iterators[i])
        except StopIteration:
          if pending_iterators:
            open_iterators[i] = pending_iterators.pop(0)
          else:
            open_iterators[i] = None
            num_open -= 1

  def _next_expected_batch(self,
                           file_indices,
                           batch_size,
                           num_epochs,
                           cycle_length,
                           drop_final_batch,
                           use_parser_fn):
    """Generates the batches the dataset under test is expected to produce.

    Args:
      file_indices: Indices of the files to read, in order.
      batch_size: Number of records per full batch.
      num_epochs: Number of passes over the files.
      cycle_length: Number of files interleaved at once; 1 means the files
        are read sequentially.
      drop_final_batch: Whether a trailing batch smaller than `batch_size`
        is dropped instead of yielded.
      use_parser_fn: Whether each record has its first element stripped,
        mirroring the `substr(x, 1, 999)` parser used by `_read_test`.

    Yields:
      Lists of records. Batches are not reset at epoch boundaries, so a
      batch may contain records from two consecutive epochs.
    """

    def _next_record(file_indices):
      for j in file_indices:
        for i in range(self._num_records):
          yield j, i

    def _next_record_interleaved(file_indices, cycle_length):
      return self._interleave([_next_record([i]) for i in file_indices],
                              cycle_length)

    record_batch = []
    for _ in range(num_epochs):
      if cycle_length == 1:
        next_records = _next_record(file_indices)
      else:
        next_records = _next_record_interleaved(file_indices, cycle_length)
      for f, r in next_records:
        record = self._record(f, r)
        if use_parser_fn:
          record = record[1:]
        record_batch.append(record)
        if len(record_batch) == batch_size:
          yield record_batch
          record_batch = []
    if record_batch and not drop_final_batch:
      yield record_batch

  def _verify_records(self,
                      sess,
                      outputs,
                      batch_size,
                      file_index,
                      num_epochs,
                      interleave_cycle_length,
                      drop_final_batch,
                      use_parser_fn):
    """Runs `outputs` once per expected batch and compares the results.

    Args:
      sess: Session used to evaluate `outputs`.
      outputs: Fetch producing the next batch of the dataset under test.
      batch_size: Number of records per full batch.
      file_index: Index of the single file being read, or `None` to read all
        `self._num_files` files.
      num_epochs: Number of passes over the files.
      interleave_cycle_length: Cycle length forwarded to
        `_next_expected_batch`.
      drop_final_batch: Whether a trailing partial batch is dropped.
      use_parser_fn: Whether the parser function was applied to the records.
    """
    if file_index is not None:
      file_indices = [file_index]
    else:
      file_indices = range(self._num_files)

    for expected_batch in self._next_expected_batch(
        file_indices, batch_size, num_epochs, interleave_cycle_length,
        drop_final_batch, use_parser_fn):
      actual_batch = sess.run(outputs)
      self.assertAllEqual(expected_batch, actual_batch)

  def _read_test(self, batch_size, num_epochs, file_index=None,
                 num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
    """Builds an unshuffled dataset and checks every batch it produces.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset()`.
      num_epochs: Number of epochs passed to `make_tf_record_dataset()`.
      file_index: Index into `self.test_filenames` of the single file to
        read, or `None` to read all of them.
      num_parallel_reads: Passed through to the dataset and used as the
        interleave cycle length of the reference implementation.
      drop_final_batch: Whether the trailing partial batch is dropped.
      parser_fn: Whether to apply a parser that drops the first character of
        each record.
    """
    if file_index is None:
      file_pattern = self.test_filenames
    else:
      file_pattern = self.test_filenames[file_index]

    if parser_fn:
      fn = lambda x: string_ops.substr(x, 1, 999)
    else:
      fn = None

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        outputs = readers.make_tf_record_dataset(
            file_pattern=file_pattern,
            num_epochs=num_epochs,
            batch_size=batch_size,
            parser_fn=fn,
            num_parallel_reads=num_parallel_reads,
            drop_final_batch=drop_final_batch,
            shuffle=False).make_one_shot_iterator().get_next()
        self._verify_records(
            sess, outputs, batch_size, file_index, num_epochs=num_epochs,
            interleave_cycle_length=num_parallel_reads,
            drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
        # Once every expected batch was consumed the dataset must be empty.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(outputs)

  def testRead(self):
    """Reads single files, all files, and all files with parallel reads."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        # Basic test: read from file 0.
        self._read_test(batch_size, num_epochs, 0)

        # Basic test: read from file 1.
        self._read_test(batch_size, num_epochs, 1)

        # Basic test: read from both files.
        self._read_test(batch_size, num_epochs)

        # Basic test: read from both files, with parallel reads.
        self._read_test(batch_size, num_epochs, num_parallel_reads=8)

  def testDropFinalBatch(self):
    """Repeats the read checks with `drop_final_batch=True`."""
    for batch_size in [1, 2, 10]:
      for num_epochs in [1, 3]:
        # Read from file 0.
        self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)

        # Read from both files.
        self._read_test(batch_size, num_epochs, drop_final_batch=True)

        # Read from both files, with parallel reads.
        self._read_test(batch_size, num_epochs, num_parallel_reads=8,
                        drop_final_batch=True)

  def testParserFn(self):
    """Repeats the read checks with a parser function applied."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        for drop_final_batch in [False, True]:
          self._read_test(batch_size, num_epochs, parser_fn=True,
                          drop_final_batch=drop_final_batch)
          self._read_test(batch_size, num_epochs, num_parallel_reads=8,
                          parser_fn=True, drop_final_batch=drop_final_batch)

  def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
                    seed=None):
    """Reads a shuffled dataset twice and checks the produced records.

    Both passes must produce the same number of batches and, ignoring order,
    exactly the expected records. When `seed` is given the two passes must
    also match batch for batch.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset()`.
      num_epochs: Number of epochs passed to `make_tf_record_dataset()`.
      num_parallel_reads: Passed through to `make_tf_record_dataset()`.
      seed: Shuffle seed, or `None` for an unseeded shuffle.
    """
    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        dataset = readers.make_tf_record_dataset(
            file_pattern=self.test_filenames,
            num_epochs=num_epochs,
            batch_size=batch_size,
            num_parallel_reads=num_parallel_reads,
            shuffle=True,
            shuffle_seed=seed)
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        def _drain():
          """Re-initializes the iterator and collects batches until the end."""
          sess.run(iterator.initializer)
          batches = []
          try:
            while True:
              batches.append(sess.run(next_element))
          except errors.OutOfRangeError:
            pass
          return batches

        first_batches = _drain()
        second_batches = _drain()

        self.assertEqual(len(first_batches), len(second_batches))
        if seed is not None:
          # if you set a seed, should get the same results
          for first, second in zip(first_batches, second_batches):
            self.assertAllEqual(first, second)

        expected = []
        for f in range(self._num_files):
          for r in range(self._num_records):
            expected.extend([self._record(f, r)] * num_epochs)
        expected_sorted = sorted(expected)

        for batches in (first_batches, second_batches):
          actual = []
          for b in batches:
            actual.extend(b)
          self.assertAllEqual(expected_sorted, sorted(actual))

  def testShuffle(self):
    """Runs the shuffle check with and without a fixed seed."""
    for batch_size in [1, 2]:
      for num_epochs in [1, 3]:
        for num_parallel_reads in [1, 2]:
          # Test that all expected elements are produced
          self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
          # Test that elements are produced in a consistent order if
          # you specify a seed.
          self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
                             seed=21345)

  def testIndefiniteRepeatShapeInference(self):
    """Checks the static batch dimension is 32 when `num_epochs` is `None`."""
    dataset = readers.make_tf_record_dataset(
        file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
    for shape in nest.flatten(dataset.output_shapes):
      self.assertEqual(32, shape[0])
|
||||
|
||||
|
||||
# Run the test cases of this module only when it is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -0,0 +1,337 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.map_and_batch()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Graph-mode tests for the `batching.map_and_batch()` transformation."""

  @parameterized.named_parameters(
      ("Default", None, None),
      ("SequentialCalls", 1, None),
      ("ParallelCalls", 2, None),
      ("ParallelBatches", None, 10),
  )
  def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
    """Test a dataset that maps a TF function across its input elements."""
    # The pipeline is TensorSliceDataset ->
    # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    # Both the repeat count and the batch size are fed at initialization time
    # so the same graph can be reused for every scenario below.
    count = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
            batching.map_and_batch(
                map_func=_map_fn,
                batch_size=batch_size,
                num_parallel_calls=num_parallel_calls,
                num_parallel_batches=num_parallel_batches))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    # The batch dimension is unknown statically because `batch_size` is a
    # placeholder.
    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                     [t.shape.as_list() for t in get_next])

    with self.cached_session() as sess:
      # Batch of a finite input, where the batch_size divides the
      # total number of elements.
      sess.run(init_op, feed_dict={count: 28, batch_size: 14})
      num_batches = (28 * 7) // 14
      for i in range(num_batches):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(14):
            self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                                result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Batch of a finite input, where the batch_size does not
      # divide the total number of elements.
      sess.run(init_op, feed_dict={count: 14, batch_size: 8})

      # We expect (num_batches - 1) full-sized batches.
      num_batches = int(math.ceil((14 * 7) / 8))
      for i in range(num_batches - 1):
        result = sess.run(get_next)
        for component, result_component in zip(components, result):
          for j in range(8):
            self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                                result_component[j])
      # The last batch only holds the remaining (14 * 7) % 8 elements.
      result = sess.run(get_next)
      for component, result_component in zip(components, result):
        for j in range((14 * 7) % 8):
          self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
                              result_component[j])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Batch of an empty input should fail straight away.
      sess.run(init_op, feed_dict={count: 0, batch_size: 8})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Empty batch should be an initialization time error.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(init_op, feed_dict={count: 14, batch_size: 0})

  @parameterized.named_parameters(
      ("Even", False),
      ("Uneven", True),
  )
  def testMapAndBatchPartialBatch(self, drop_remainder):
    """Checks shapes and contents with and without `drop_remainder`.

    Ten elements are batched by four: with `drop_remainder` the static batch
    dimension is 4 and the trailing two elements are not produced; without it
    the batch dimension is unknown and a final batch of two is produced.
    """
    iterator = (
        dataset_ops.Dataset.range(10).apply(
            batching.map_and_batch(
                lambda x: array_ops.reshape(x * x, [1]),
                batch_size=4,
                drop_remainder=drop_remainder)).make_one_shot_iterator())
    if drop_remainder:
      self.assertEqual([4, 1], iterator.output_shapes.as_list())
    else:
      self.assertEqual([None, 1], iterator.output_shapes.as_list())
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
      self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
      if not drop_remainder:
        self.assertAllEqual([[64], [81]], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testMapAndBatchYieldsPartialBatch(self):
    """Checks that a trailing partial batch is produced by default."""
    iterator = (dataset_ops.Dataset.range(10)
                .apply(batching.map_and_batch(
                    lambda x: array_ops.reshape(x * x, [1]), 4))
                .make_one_shot_iterator())
    self.assertEqual([None, 1], iterator.output_shapes.as_list())
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
      self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
      self.assertAllEqual([[64], [81]], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testMapAndBatchParallelGetNext(self):
    """Runs 100 `get_next()` ops of one iterator in a single `sess.run()`."""
    iterator = (dataset_ops.Dataset.range(50000)
                .apply(batching.map_and_batch(lambda x: x, batch_size=100))
                .make_one_shot_iterator())
    elements = []
    for _ in range(100):
      elements.append(iterator.get_next())
    with self.cached_session() as sess:
      for i in range(5):
        got = sess.run(elements)
        # The fetches are compared after ordering them by their first
        # element, so the check does not depend on which op got which batch.
        got.sort(key=lambda x: x[0])
        expected = []
        for j in range(100):
          expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
        self.assertAllEqual(got, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elements)

  def testMapAndBatchParallelGetNextDropRemainder(self):
    """Same as the parallel `get_next()` test, with `drop_remainder=True`.

    The input has 49999 elements, so only four rounds of 100 full batches are
    checked before the next round is expected to raise `OutOfRangeError`.
    """
    iterator = (
        dataset_ops.Dataset.range(49999).apply(
            batching.map_and_batch(
                lambda x: x, batch_size=100, drop_remainder=True))
        .make_one_shot_iterator())
    elements = []
    for _ in range(100):
      elements.append(iterator.get_next())
    with self.cached_session() as sess:
      for i in range(4):
        got = sess.run(elements)
        got.sort(key=lambda x: x[0])
        expected = []
        for j in range(100):
          expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
        self.assertAllEqual(got, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(elements)

  def testMapAndBatchSparse(self):
    """Checks that sparse map outputs are batched into one sparse value."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    iterator = dataset_ops.Dataset.range(10).apply(
        batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      for i in range(2):
        actual = sess.run(get_next)
        expected = sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
            dense_shape=[5, 1])
        self.assertTrue(sparse_tensor.is_sparse(actual))
        self.assertSparseValuesEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testMapAndBatchFails(self):
    """Checks that a failing input op surfaces when initializing the iterator.

    The input dataset wraps `check_numerics(1.0 / 0.0, "oops")`; running the
    initializer is expected to raise `InvalidArgumentError` matching "oops".
    """
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.check_numerics(
            constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = (
        dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
        .make_initializable_iterator())
    init_op = iterator.initializer
    with self.cached_session() as sess:
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(init_op, feed_dict={batch_size: 14})

  def testMapAndBatchShapeMismatch(self):
    """Checks the error raised when batched elements differ in shape.

    The generator yields three one-element lists followed by a nested
    three-element list; fetching the batch of four is expected to raise
    `InvalidArgumentError`.
    """

    def generator():
      yield [1]
      yield [2]
      yield [3]
      yield [[4, 5, 6]]

    dataset = dataset_ops.Dataset.from_generator(
        generator, output_types=dtypes.int32)
    batch_size = 4
    iterator = (
        dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with self.cached_session() as sess:
      sess.run(init_op)
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "number of elements does not match"):
        sess.run(get_next)

  def testMapAndBatchImplicitDispose(self):
    """Consumes only three batches and leaves the rest of the pipeline unread."""
    # Tests whether a map and batch dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
    # MapAndBatchDataset(f=square_3, batch_size=100).
    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
        1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
    dataset = dataset.prefetch(5)
    iterator = dataset.make_one_shot_iterator()
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      for _ in range(3):
        sess.run(get_next)

  @parameterized.named_parameters(
      ("1", 0),
      ("2", 5),
      ("3", 10),
      ("4", 90),
      ("5", 95),
      ("6", 99),
  )
  def testMapAndBatchOutOfRangeError(self, threshold):
    """Checks batching when the map function raises `StopIteration`.

    The mapped `py_func` raises `StopIteration` for every input at or above
    `threshold`. The test expects the full batches below the threshold, then
    a partial batch of `threshold % 10` elements (if any), then
    `OutOfRangeError`.
    """

    def raising_py_fn(i):
      if i >= threshold:
        raise StopIteration()
      else:
        return i

    iterator = (
        dataset_ops.Dataset.range(100).apply(
            batching.map_and_batch(
                lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
                batch_size=10)).make_one_shot_iterator())
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      for i in range(threshold // 10):
        self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
      if threshold % 10 != 0:
        self.assertAllEqual(
            [threshold // 10 * 10 + j for j in range(threshold % 10)],
            sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @parameterized.named_parameters(
      ("1", False, dtypes.bool),
      ("2", -42, dtypes.int8),
      ("3", -42, dtypes.int16),
      ("4", -42, dtypes.int32),
      ("5", -42, dtypes.int64),
      ("6", 42, dtypes.uint8),
      ("7", 42, dtypes.uint16),
      ("8", 42.0, dtypes.float16),
      ("9", 42.0, dtypes.float32),
      ("10", 42.0, dtypes.float64),
      ("11", b"hello", dtypes.string),
  )
  def testMapAndBatchTypes(self, element, dtype):
    """Batches a repeated scalar `element` of the given `dtype`.

    The element is repeated 100 times and batched by 10 with an identity map;
    each of the ten batches must consist of ten copies of `element`.
    """
    def gen():
      yield element

    dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
        batching.map_and_batch(lambda x: x, batch_size=10))

    get_next = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for _ in range(10):
        self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
|
||||
|
||||
|
||||
# Run the test cases of this module only when it is executed as a script.
if __name__ == "__main__":
  test.main()
|
@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
sess.close()
|
||||
thread.join()
|
||||
|
||||
def testMapDefunWithCapturedInputs(self):
  """Checks `map_defun` with a `Defun` that closes over an outer tensor."""
  captured = constant_op.constant(2)

  @function.Defun(dtypes.int32)
  def fn(x):
    # `captured` is not an argument of `fn`; it is captured from the
    # enclosing scope.
    return x + captured

  inputs = constant_op.constant([1, 2, 3, 4])
  outputs = map_defun.map_defun(fn, [inputs], [dtypes.int32], [()])
  map_defun_op = outputs[0]
  expected = inputs + captured

  expected_value = self.evaluate(expected)
  actual_value = self.evaluate(map_defun_op)
  self.assertAllEqual(expected_value, actual_value)
|
||||
|
||||
|
||||
class MapDefunBenchmark(test.Benchmark):
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline statistics gathering ops."""
|
||||
"""Tests for the private `override_threadpool()` transformation."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -32,8 +32,8 @@ from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
class OverrideThreadpoolTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 1, None),
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.parallel_interleave()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -37,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
|
||||
class ParallelInterleaveTest(test_base.DatasetTestBase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.ops.parsing_ops."""
|
||||
"""Tests for `tf.data.experimental.parse_example_dataset()."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -73,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
|
||||
i += 1
|
||||
|
||||
|
||||
class ParseExampleTest(test_base.DatasetTestBase):
|
||||
class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
def _test(self,
|
||||
input_tensor,
|
@ -0,0 +1,234 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.prefetch_to_device()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PrefetchToDeviceTest(test_base.DatasetTestBase):
  """Checks structure and element order of `prefetch_to_device()` datasets."""

  def _assertPrefetchPreservesStructure(self, host_dataset, device_dataset,
                                        iterator):
    """Asserts types, shapes and classes match the un-prefetched dataset."""
    for prop in ("output_types", "output_shapes", "output_classes"):
      expected = getattr(host_dataset, prop)
      self.assertEqual(expected, getattr(device_dataset, prop))
      self.assertEqual(expected, getattr(iterator, prop))

  def _assertYieldsZeroToNineThenEnds(self, sess, next_element):
    """Asserts `next_element` produces 0..9 and then is out of range."""
    for expected in range(10):
      self.assertEqual(expected, sess.run(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)

  def testPrefetchToDevice(self):
    source = dataset_ops.Dataset.range(10)
    prefetched = source.apply(prefetching_ops.prefetch_to_device("/cpu:1"))

    # NOTE(mrry): The explicit /cpu:0 scope places the "host" dataset and
    # iterator on a different device from the /cpu:1 prefetch target, so the
    # prefetching really crosses devices. Typical use would not need this,
    # because a GPU device would not support the dataset-related ops.
    with ops.device("/cpu:0"):
      iterator = prefetched.make_one_shot_iterator()

    self._assertPrefetchPreservesStructure(source, prefetched, iterator)

    next_element = iterator.get_next()
    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    # Two CPU devices are configured so that /cpu:1 exists in the session.
    two_cpus = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=two_cpus) as sess:
      self._assertYieldsZeroToNineThenEnds(sess, next_element)

  def testPrefetchToSameDevice(self):
    source = dataset_ops.Dataset.range(10)
    prefetched = source.apply(
        prefetching_ops.prefetch_to_device(
            "/job:localhost/replica:0/task:0/device:CPU:0"))

    # The prefetch target above and the device scope below both name CPU:0
    # (assuming the default localhost job), which is the case under test.
    with ops.device("/cpu:0"):
      iterator = prefetched.make_one_shot_iterator()

    self._assertPrefetchPreservesStructure(source, prefetched, iterator)

    next_element = iterator.get_next()
    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    with self.cached_session() as sess:
      self._assertYieldsZeroToNineThenEnds(sess, next_element)

  def testPrefetchDictToDevice(self):
    source = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
    prefetched = source.apply(prefetching_ops.prefetch_to_device("/cpu:1"))

    # NOTE(mrry): The explicit /cpu:0 scope places the "host" dataset and
    # iterator on a different device from the /cpu:1 prefetch target, so the
    # prefetching really crosses devices. Typical use would not need this,
    # because a GPU device would not support the dataset-related ops.
    with ops.device("/cpu:0"):
      iterator = prefetched.make_one_shot_iterator()

    self._assertPrefetchPreservesStructure(source, prefetched, iterator)

    next_element = iterator.get_next()
    self.assertEqual(dtypes.int64, next_element["a"].dtype)
    self.assertEqual([], next_element["a"].shape)

    two_cpus = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=two_cpus) as sess:
      for expected in range(10):
        self.assertEqual({"a": expected}, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testPrefetchSparseTensorsToDevice(self):
    def make_tensor(i):
      # One non-zero entry at [0, 0] whose value is the element index.
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])

    source = dataset_ops.Dataset.range(10).map(make_tensor)
    prefetched = source.apply(prefetching_ops.prefetch_to_device("/cpu:1"))

    # NOTE(mrry): The explicit /cpu:0 scope places the "host" dataset and
    # iterator on a different device from the /cpu:1 prefetch target, so the
    # prefetching really crosses devices. Typical use would not need this,
    # because a GPU device would not support the dataset-related ops.
    with ops.device("/cpu:0"):
      iterator = prefetched.make_one_shot_iterator()

    self._assertPrefetchPreservesStructure(source, prefetched, iterator)

    next_element = iterator.get_next()
    self.assertEqual(dtypes.int64, next_element.dtype)

    two_cpus = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=two_cpus) as sess:
      for expected in range(10):
        actual = sess.run(next_element)
        self.assertAllEqual([expected], actual.values)
        self.assertAllEqual([[0, 0]], actual.indices)
        self.assertAllEqual([2, 2], actual.dense_shape)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testPrefetchToDeviceGpu(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    prefetched = dataset_ops.Dataset.range(10).apply(
        prefetching_ops.prefetch_to_device("/gpu:0"))
    iterator = prefetched.make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      self._assertYieldsZeroToNineThenEnds(sess, next_element)

  def testPrefetchToDeviceWithReInit(self):
    source = dataset_ops.Dataset.range(10)
    prefetched = source.apply(prefetching_ops.prefetch_to_device("/cpu:1"))

    # NOTE(mrry): The explicit /cpu:0 scope places the "host" dataset and
    # iterator on a different device from the /cpu:1 prefetch target, so the
    # prefetching really crosses devices. Typical use would not need this,
    # because a GPU device would not support the dataset-related ops.
    with ops.device("/cpu:0"):
      iterator = prefetched.make_initializable_iterator()

    self._assertPrefetchPreservesStructure(source, prefetched, iterator)

    next_element = iterator.get_next()
    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    two_cpus = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=two_cpus) as sess:
      # Consume only half of the range, then re-initialize and expect the
      # iterator to start over from zero.
      sess.run(iterator.initializer)
      for expected in range(5):
        self.assertEqual(expected, sess.run(next_element))
      sess.run(iterator.initializer)
      self._assertYieldsZeroToNineThenEnds(sess, next_element)

  def testPrefetchToDeviceGpuWithReInit(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    prefetched = dataset_ops.Dataset.range(10).apply(
        prefetching_ops.prefetch_to_device("/gpu:0"))
    iterator = prefetched.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      # Consume only half of the range, then re-initialize and expect the
      # iterator to start over from zero.
      sess.run(iterator.initializer)
      for expected in range(5):
        self.assertEqual(expected, sess.run(next_element))
      sess.run(iterator.initializer)
      self._assertYieldsZeroToNineThenEnds(sess, next_element)
|
||||
|
||||
|
||||
# Call `test.main()` only when this file is executed directly, not on import.
if __name__ == "__main__":
  test.main()
|
@ -63,11 +63,11 @@ class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
|
||||
return filenames
|
||||
|
||||
|
||||
class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
|
||||
class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
|
||||
"""Base class for setting up and testing `make_batched_feature_dataset`."""
|
||||
|
||||
def setUp(self):
|
||||
super(ReadBatchFeaturesTestBase, self).setUp()
|
||||
super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
|
||||
self._num_files = 2
|
||||
self._num_records = 7
|
||||
self.test_filenames = self._createFiles()
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.rejection_resample()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -58,7 +58,7 @@ def _time_resampling(
|
||||
return end_time - start_time
|
||||
|
||||
|
||||
class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("InitialDistributionKnown", True),
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for the private `_RestructuredDataset` transformation."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -26,7 +26,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class DatasetConstructorTest(test_base.DatasetTestBase):
|
||||
class RestructuredDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
def testRestructureDataset(self):
|
||||
components = (array_ops.placeholder(dtypes.int32),
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.scan()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -34,7 +34,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ScanDatasetTest(test_base.DatasetTestBase):
|
||||
class ScanTest(test_base.DatasetTestBase):
|
||||
|
||||
def _counting_dataset(self, start, scan_fn):
|
||||
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
|
@ -69,6 +69,26 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Small Python test target built from checkpoint_input_pipeline_hook_test.py.
py_test(
    name = "checkpoint_input_pipeline_hook_test",
    size = "small",
    srcs = ["checkpoint_input_pipeline_hook_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_pip"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/experimental/ops:iterator_ops",
        "//tensorflow/python/data/kernel_tests:test_base",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/estimator:estimator_py",
    ],
)
|
||||
|
||||
py_test(
|
||||
name = "concatenate_dataset_serialization_test",
|
||||
size = "small",
|
||||
@ -580,7 +600,7 @@ py_test(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
|
||||
"//tensorflow/python/data/experimental/kernel_tests:sql_dataset_test_base",
|
||||
"//tensorflow/python/data/experimental/ops:readers",
|
||||
],
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ParseExampleDatasetSerializationTest(
|
||||
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
|
||||
dataset_serialization_test_base.DatasetSerializationTestBase):
|
||||
|
||||
def ParseExampleDataset(self, num_repeat, batch_size):
|
||||
|
@ -19,7 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
|
||||
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
|
||||
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||
from tensorflow.python.data.experimental.ops import readers
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SqlDatasetSerializationTest(
|
||||
sql_dataset_op_test_base.SqlDatasetTestBase,
|
||||
sql_dataset_test_base.SqlDatasetTestBase,
|
||||
dataset_serialization_test_base.DatasetSerializationTestBase):
|
||||
|
||||
def _build_dataset(self, num_repeats):
|
||||
|
@ -1,85 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Integration test for dataset serialization."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
class SerializationIntegrationTest(test.TestCase):
  """Saves and restores many dataset iterators through a single `Saver`."""

  def _build_input_pipeline(self, name, num_outputs):
    """Builds one shuffled, prefetched pipeline inside name scope `name`.

    Args:
      name: Name scope under which the pipeline's ops are created.
      num_outputs: Number of elements in the underlying range dataset.

    Returns:
      A `(initializer, get_next)` pair for the pipeline's iterator.
    """
    with ops.name_scope(name):
      shuffled = dataset_ops.Dataset.range(num_outputs).shuffle(
          10, reshuffle_each_iteration=False)
      iterator = shuffled.prefetch(10).make_initializable_iterator()
      # Register the iterator state as a saveable object; presumably this is
      # how the `Saver()` created in `_build_graph` comes to include it.
      ops.add_to_collection(
          ops.GraphKeys.SAVEABLE_OBJECTS,
          contrib_iterator_ops.make_saveable_from_iterator(iterator))
      return iterator.initializer, iterator.get_next()

  def _build_graph(self, num_pipelines, num_outputs):
    """Builds `num_pipelines` pipelines and a saver in the current graph.

    Returns:
      A `(init_ops, get_next_ops, saver)` tuple; the two lists are ordered by
      pipeline index.
    """
    pipelines = [
        self._build_input_pipeline("input_pipeline_%d" % index, num_outputs)
        for index in range(num_pipelines)
    ]
    init_ops = [init_op for init_op, _ in pipelines]
    get_next_ops = [get_next_op for _, get_next_op in pipelines]
    # The saver is constructed only after every pipeline has been built.
    return init_ops, get_next_ops, saver_lib.Saver()

  def _ckpt_path(self):
    """Returns the checkpoint path prefix inside the test's temp directory."""
    return os.path.join(self.get_temp_dir(), "iterator")

  def testConcurrentSaves(self):
    num_pipelines = 100
    num_outputs = 100
    break_point = 10
    all_outputs = [[] for _ in range(num_pipelines)]

    def record(sess, get_next_ops, steps):
      # Pull `steps` elements from every pipeline, filing each value under
      # the pipeline that produced it.
      for _ in range(steps):
        for collected, value in zip(all_outputs, sess.run(get_next_ops)):
          collected.append(value)

    # Phase 1: consume the first `break_point` elements, then checkpoint.
    with ops.Graph().as_default() as graph:
      init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
                                                        num_outputs)
      with self.session(graph=graph) as sess:
        sess.run(init_ops)
        record(sess, get_next_ops, break_point)
        saver.save(sess, self._ckpt_path())

    # Phase 2: rebuild in a fresh graph, restore, and consume the remainder.
    with ops.Graph().as_default() as graph:
      _, get_next_ops, saver = self._build_graph(num_pipelines, num_outputs)
      with self.session(graph=graph) as sess:
        saver.restore(sess, self._ckpt_path())
        record(sess, get_next_ops, num_outputs - break_point)

    # Across both phases each pipeline must have produced every value once.
    for output in all_outputs:
      self.assertSequenceEqual(sorted(output), range(num_outputs))
|
||||
|
||||
|
||||
# Call `test.main()` only when this file is executed directly, not on import.
if __name__ == "__main__":
  test.main()
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.shuffle_and_repeat()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -12,19 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for experimental sql input op."""
|
||||
"""Tests for `tf.data.experimental.SqlDataset`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
|
||||
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
|
||||
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
|
||||
|
||||
# Test that SqlDataset can read from a database table.
|
||||
def testReadResultSet(self):
|
@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Base class for testing SqlDataset."""
|
||||
|
||||
"""Base class for testing `tf.data.experimental.SqlDataset`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
@ -280,7 +280,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
|
||||
|
||||
class FeatureStatsDatasetTest(
|
||||
stats_dataset_test_base.StatsDatasetTestBase,
|
||||
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
|
||||
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
|
||||
|
||||
def testFeaturesStats(self):
|
||||
num_epochs = 5
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.TFRecordWriter`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
300
tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
Normal file
300
tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
Normal file
@ -0,0 +1,300 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.experimental.unbatch()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Functional tests for the `unbatch()` dataset transformation."""

  def _assertEndOfSequence(self, sess, fetch):
    """Asserts that running `fetch` raises `OutOfRangeError`."""
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(fetch)

  def testUnbatchWithUnknownRankInput(self):
    # The placeholder is created without a shape, so its rank is unknown when
    # the dataset is constructed.
    placeholder = array_ops.placeholder(dtypes.int32)
    unbatched = dataset_ops.Dataset.from_tensors(placeholder).apply(
        batching.unbatch())
    iterator = unbatched.make_initializable_iterator()
    next_elem = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
      for expected in range(4):
        self.assertEqual(expected, sess.run(next_elem))
      self._assertEndOfSequence(sess, next_elem)

  def testUnbatchScalarDataset(self):
    components = tuple(math_ops.range(10) for _ in range(3))
    expected_types = (dtypes.int32,) * 3

    batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
    self.assertEqual(expected_types, batched.output_types)
    unbatched = batched.apply(batching.unbatch())
    self.assertEqual(expected_types, unbatched.output_types)

    op = unbatched.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for i in range(10):
        self.assertEqual((i,) * 3, sess.run(op))

      self._assertEndOfSequence(sess, op)

  def testUnbatchDatasetWithStrings(self):
    components = tuple(math_ops.range(10) for _ in range(3))
    expected_types = (dtypes.int32, dtypes.string, dtypes.int32)

    # Convert only the middle component to a string.
    with_strings = dataset_ops.Dataset.from_tensor_slices(components).map(
        lambda x, y, z: (x, string_ops.as_string(y), z))
    batched = with_strings.batch(2)
    self.assertEqual(expected_types, batched.output_types)
    unbatched = batched.apply(batching.unbatch())
    self.assertEqual(expected_types, unbatched.output_types)

    op = unbatched.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for i in range(10):
        self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))

      self._assertEndOfSequence(sess, op)

  def testUnbatchDatasetWithSparseTensor(self):
    # A 10x10 sparse matrix holding 0..9 on its diagonal.
    diagonal = sparse_tensor.SparseTensorValue(
        indices=[[i, i] for i in range(10)],
        values=list(range(10)),
        dense_shape=[10, 10])
    # Unbatch, re-batch in fives, and unbatch again.
    rows = dataset_ops.Dataset.from_tensors(diagonal).apply(
        batching.unbatch())
    rows = rows.batch(5).apply(batching.unbatch())
    next_element = rows.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for i in range(10):
        st_row = sess.run(next_element)
        self.assertEqual([i], st_row.indices)
        self.assertEqual([i], st_row.values)
        self.assertEqual([10], st_row.dense_shape)
      self._assertEndOfSequence(sess, next_element)

  def testUnbatchDatasetWithDenseAndSparseTensor(self):
    # A 10x10 sparse matrix holding 0..9 on its diagonal.
    diagonal = sparse_tensor.SparseTensorValue(
        indices=[[i, i] for i in range(10)],
        values=list(range(10)),
        dense_shape=[10, 10])
    # Pair the sparse matrix with a dense vector of the same leading size.
    rows = dataset_ops.Dataset.from_tensors((list(range(10)), diagonal))
    rows = rows.apply(batching.unbatch())
    rows = rows.batch(5).apply(batching.unbatch())
    next_element = rows.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for i in range(10):
        dense_elem, st_row = sess.run(next_element)
        self.assertEqual(i, dense_elem)
        self.assertEqual([i], st_row.indices)
        self.assertEqual([i], st_row.values)
        self.assertEqual([10], st_row.dense_shape)
      self._assertEndOfSequence(sess, next_element)

  def testUnbatchSingleElementTupleDataset(self):
    components = tuple((math_ops.range(10),) for _ in range(3))
    expected_types = ((dtypes.int32,),) * 3

    batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
    self.assertEqual(expected_types, batched.output_types)
    unbatched = batched.apply(batching.unbatch())
    self.assertEqual(expected_types, unbatched.output_types)

    op = unbatched.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for i in range(10):
        self.assertEqual(((i,),) * 3, sess.run(op))

      self._assertEndOfSequence(sess, op)

  def testUnbatchMultiElementTupleDataset(self):
    # Three (int range, "hi" strings) pairs; pair k covers 10k .. 10k + 9.
    components = tuple(
        (math_ops.range(10 * k, 10 * k + 10), array_ops.fill([10], "hi"))
        for k in range(3))
    expected_types = ((dtypes.int32, dtypes.string),) * 3

    batched = dataset_ops.Dataset.from_tensor_slices(components).batch(2)
    self.assertAllEqual(expected_types, batched.output_types)
    unbatched = batched.apply(batching.unbatch())
    self.assertAllEqual(expected_types, unbatched.output_types)

    op = unbatched.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for i in range(10):
        self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
                         sess.run(op))

      self._assertEndOfSequence(sess, op)

  def testUnbatchEmpty(self):
    # Every component has a leading dimension of size zero.
    empty_components = (
        constant_op.constant([]),
        constant_op.constant([], shape=[0, 4]),
        constant_op.constant([], shape=[0, 4, 0]),
    )
    unbatched = dataset_ops.Dataset.from_tensors(empty_components).apply(
        batching.unbatch())
    next_element = unbatched.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      self._assertEndOfSequence(sess, next_element)

  def testUnbatchStaticShapeMismatch(self):
    # Leading dimensions 7, 8 and 9 disagree and are known statically, so the
    # error is expected when the transformation is applied.
    mismatched = dataset_ops.Dataset.from_tensors(
        (np.arange(7), np.arange(8), np.arange(9)))
    with self.assertRaises(ValueError):
      mismatched.apply(batching.unbatch())

  def testUnbatchDynamicShapeMismatch(self):
    ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
    ph2 = array_ops.placeholder(dtypes.int32, shape=None)
    unbatched = dataset_ops.Dataset.from_tensors((ph1, ph2)).apply(
        batching.unbatch())
    iterator = unbatched.make_initializable_iterator()
    next_element = iterator.get_next()

    bad_second_components = (
        # Mismatch in the 0th dimension.
        np.arange(8).astype(np.int32),
        # No 0th dimension (i.e. scalar value) for one component.
        7,
    )

    with self.cached_session() as sess:
      for bad_value in bad_second_components:
        sess.run(
            iterator.initializer,
            feed_dict={
                ph1: np.arange(7).astype(np.int32),
                ph2: bad_value
            })
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(next_element)
|
||||
|
||||
|
||||
class UnbatchBenchmark(test.Benchmark):
  """Wall-time benchmarks comparing two ways of unbatching a dataset."""

  def _run_unbatch_trials(self, variant, unbatch_fn):
    """Times one unbatching strategy for several batch sizes.

    For each batch size the pipeline is re-initialized and timed five times,
    and the median per-element wall time is printed and reported.

    Args:
      variant: Short label (e.g. "native" or "unfused") embedded in both the
        printed message and the reported benchmark name.
      unbatch_fn: Callable that takes the batched dataset and returns the
        unbatched dataset to be measured.
    """
    batch_sizes = [1, 2, 5, 10, 20, 50]
    elems_per_trial = 10000
    num_trials = 5
    with ops.Graph().as_default():
      dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
      batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
      dataset = dataset.batch(batch_size_placeholder)
      dataset = unbatch_fn(dataset)
      # Skipping `elems_per_trial` elements means the single `sess.run` below
      # has to produce that many unbatched elements before it can return.
      dataset = dataset.skip(elems_per_trial)
      iterator = dataset.make_initializable_iterator()
      next_element = iterator.get_next()

      with session.Session() as sess:
        for batch_size in batch_sizes:
          deltas = []
          for _ in range(num_trials):
            sess.run(
                iterator.initializer,
                feed_dict={batch_size_placeholder: batch_size})
            start = time.time()
            # Run only the op so the element value is not fetched.
            sess.run(next_element.op)
            end = time.time()
            deltas.append((end - start) / elems_per_trial)

          median_wall_time = np.median(deltas)
          print("Unbatch (%s) batch size: %d Median wall time per element:"
                " %f microseconds" % (variant, batch_size,
                                      median_wall_time * 1e6))
          self.report_benchmark(
              iters=elems_per_trial,
              wall_time=median_wall_time,
              name="benchmark_unbatch_dataset_%s_batch_size_%d" %
              (variant, batch_size))

  def benchmarkNativeUnbatch(self):
    """Benchmarks the `batching.unbatch()` transformation."""
    self._run_unbatch_trials(
        "native", lambda dataset: dataset.apply(batching.unbatch()))

  # Include a benchmark of the previous `unbatch()` implementation that uses
  # a composition of more primitive ops. Eventually we'd hope to generate code
  # that is as good in both cases.
  def benchmarkOldUnbatchImplementation(self):
    """Benchmarks unbatching via `flat_map` over `from_tensor_slices`."""
    self._run_unbatch_trials(
        "unfused",
        lambda dataset: dataset.flat_map(
            dataset_ops.Dataset.from_tensor_slices))
|
||||
|
||||
|
||||
# Script entry point: hand control to `test.main()` when run directly.
if __name__ == "__main__":
|
||||
test.main()
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
"""Tests for `tf.data.experimental.unique()`."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@ -26,7 +26,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class UniqueDatasetTest(test_base.DatasetTestBase):
|
||||
class UniqueTest(test_base.DatasetTestBase):
|
||||
|
||||
def _testSimpleHelper(self, dtype, test_cases):
|
||||
"""Test the `unique()` transformation on a list of test cases.
|
@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
|
||||
if not isinstance(elems, list):
|
||||
raise ValueError("`elems` must be a list of tensors.")
|
||||
if not isinstance(output_dtypes, list):
|
||||
raise ValueError("`output_dtypes` must be a list of tensors.")
|
||||
raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
|
||||
if not isinstance(output_shapes, list):
|
||||
raise ValueError("`output_shapes` must be a list of tensors.")
|
||||
raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
|
||||
"objects.")
|
||||
|
||||
elems = [ops.convert_to_tensor(e) for e in elems]
|
||||
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
|
||||
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
|
||||
return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes,
|
||||
output_shapes, fn)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user