Merge branch 'master' into upstream-staging-terminology-3

This commit is contained in:
Jeff Poznanovic 2018-10-04 16:27:22 -06:00 committed by GitHub
commit 19d836b4fc
111 changed files with 4242 additions and 3911 deletions

View File

@ -894,6 +894,22 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "tensor_list_ops_test",
size = "small",
srcs = ["tensor_list_ops_test.py"],
# TensorList ops are not implemented in the on-demand compilation model yet.
disabled_backends = "cpu_ondemand",
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:list_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:function",
],
)
tf_xla_py_test(
name = "ternary_ops_test",
size = "small",

View File

@ -1820,7 +1820,7 @@ TEST_F(OpTest, Diag) {
do {
dims = RandomDims(1);
size = TensorShape(dims).num_elements();
} while (size * size < tf_xla_max_tensor_size);
} while (size * size > tf_xla_max_tensor_size);
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type));
});

View File

@ -0,0 +1,105 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops which manipulate lists of tensors via bridge."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
def scalar_shape():
return ops.convert_to_tensor([], dtype=dtypes.int32)
class ListOpsTest(xla_test.XLATestCase):
def testElementShape(self):
with self.cached_session() as sess, self.test_scope():
dim = array_ops.placeholder(dtypes.int32)
l = list_ops.tensor_list_reserve(
element_shape=(dim, 15), num_elements=20,
element_dtype=dtypes.float32)
e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64)
self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15))
self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15))
def testPushPop(self):
with self.cached_session() as sess, self.test_scope():
num = array_ops.placeholder(dtypes.int32)
l = list_ops.tensor_list_reserve(
element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32)
l = list_ops.tensor_list_push_back(
l, constant_op.constant(1.0, shape=(7, 15)))
l = list_ops.tensor_list_push_back(
l, constant_op.constant(2.0, shape=(7, 15)))
l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
_, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15)))
self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15)))
def testPushPopSeparateLists(self):
with self.cached_session() as sess, self.test_scope():
num = array_ops.placeholder(dtypes.int32)
l = list_ops.tensor_list_reserve(
element_shape=scalar_shape(),
num_elements=num,
element_dtype=dtypes.float32)
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
_, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20})
self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
def testEmptyTensorList(self):
dim = 7
with self.cached_session() as sess, self.test_scope():
p = array_ops.placeholder(dtypes.int32)
l = list_ops.empty_tensor_list(
element_shape=(p, 15), element_dtype=dtypes.float32)
l = list_ops.tensor_list_push_back(
l, constant_op.constant(1.0, shape=(dim, 15)))
_, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Use TensorListReserve instead"):
self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15)))
if __name__ == "__main__":
test.main()

View File

@ -95,6 +95,7 @@ tf_kernel_library(
"stateless_random_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
"tensor_list_ops.cc",
"tile_ops.cc",
"topk_op.cc",
"training_ops.cc",
@ -158,6 +159,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:pooling_ops",

View File

@ -0,0 +1,226 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// XLA TensorList operators.
#include <limits>
#include <vector>
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
TensorShape* tensor_list_shape) {
auto shape_or_status = builder->GetShape(op);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
xla::Shape shape = shape_or_status.ValueOrDie();
TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
tensor_list_shape);
}
class TensorListReserveOp : public XlaOpKernel {
public:
explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape element_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
int64 num_elements;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
TensorShape tensor_shape;
tensor_shape.AddDim(num_elements);
tensor_shape.AppendShape(element_shape);
xla::XlaBuilder* b = ctx->builder();
ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
tensor_shape.dim_sizes()),
xla::ConstantR0<int32>(b, 0)}));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
};
REGISTER_XLA_OP(Name("TensorListReserve")
.CompileTimeConstInput("element_shape")
.CompileTimeConstInput("num_elements"),
TensorListReserveOp);
class EmptyTensorListOp : public XlaOpKernel {
public:
explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->CtxFailure(
errors::InvalidArgument("XLA compilation requires a fixed tensor list "
"size. Use TensorListReserve instead."));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
};
REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp);
class TensorListElementShapeOp : public XlaOpKernel {
public:
explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
shape.RemoveDim(0);
switch (shape_type_) {
case DT_INT64:
ctx->SetOutput(0, xla::ConstantR1<int64>(b, shape.dim_sizes()));
break;
case DT_INT32: {
std::vector<int32> size;
for (int64 s : shape.dim_sizes()) {
size.push_back(s);
}
ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
break;
}
default:
ctx->CtxFailure(
errors::InvalidArgument("Unsupported shape type requested"));
return;
}
}
private:
DataType shape_type_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
};
REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp);
class TensorListPushBackOp : public XlaOpKernel {
public:
explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp list = ctx->Input(0);
TensorShape elem_shape = ctx->InputShape(1);
xla::XlaOp ta = xla::GetTupleElement(list, 0);
xla::XlaOp index = xla::GetTupleElement(list, 1);
xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
auto update = xla::Reshape(value, slice_shape.dim_sizes());
// TODO(phawkins): We don't check that the index is in bounds; there is no
// error mechanism in XLA.
ctx->SetOutput(
0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
index + xla::ConstantR0<int32>(b, 1)}));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
};
REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp);
class TensorListPopBackOp : public XlaOpKernel {
public:
explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp state = ctx->Input(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
xla::XlaOp ta = xla::GetTupleElement(state, 0);
xla::XlaOp index = xla::GetTupleElement(state, 1);
index = index - xla::ConstantR0<int32>(b, 1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}}));
auto slice_shape = shape.dim_sizes();
slice_shape[0] = 1LL;
// TODO(phawkins): We don't check that the index is in bounds; there is no
// error mechanism in XLA.
xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
ctx->SetOutput(0, xla::Tuple(b, {ta, index}));
ctx->SetOutput(1, xla::Reshape(read, value_shape));
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
};
REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp);
} // anonymous namespace
} // namespace tensorflow
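The kernels above encode a TensorList as an XLA tuple of a fixed-size buffer plus a scalar push index: PushBack is a DynamicUpdateSlice at [index, 0, ..., 0] and PopBack a DynamicSlice at [index - 1, 0, ..., 0]. A minimal NumPy sketch of that representation (illustrative names only, not the actual kernels):

import numpy as np

def list_reserve(element_shape, num_elements, dtype=np.float32):
  # The "list" is just (fixed-size storage, number of pushed elements).
  return np.zeros((num_elements,) + element_shape, dtype=dtype), 0

def list_push_back(lst, value):
  buffer, index = lst
  buffer = buffer.copy()
  buffer[index] = value          # DynamicUpdateSlice at [index, 0, ..., 0]
  return buffer, index + 1       # index is not bounds-checked, as above

def list_pop_back(lst):
  buffer, index = lst
  index = index - 1
  return (buffer, index), buffer[index]  # DynamicSlice at [index, 0, ..., 0]

l = list_reserve((7, 15), num_elements=4)
l = list_push_back(l, np.full((7, 15), 1.0))
l, e = list_pop_back(l)
assert np.all(e == 1.0)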

View File

@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
auto shape = builder()->GetShape(handle);
if (!shape.ok()) {
SetStatus(shape.status());
return;
}
Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
Tensor** output) {
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
TensorShape tensor_shape;
if (expected_output_dtype(index) == DT_VARIANT) {
// tensor_data() is not supported for variant Tensor (i.e.,
// DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
// XlaExpression inside the Tensor's tensor_data() does not work for
// variant. Instead construct a uint8 tensor and store the expression in its
// value.
// TODO(jpienaar): This should be refactored to stop masquerading
// XlaExpressions as Tensors.
*output = new Tensor();
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(
context_->allocate_temp(DT_UINT8, tensor_shape, *output));
context_->set_output(index, **output);
} else {
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
}
return Status::OK();
}
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
auto shape_or = builder()->GetShape(handle);
if (!shape_or.ok()) {
SetStatus(shape_or.status());
return;
}
OP_REQUIRES_OK(context_,
XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
OP_REQUIRES_OK(context_,
context_->allocate_output(index, tensor_shape, &output));
allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.

View File

@ -255,6 +255,11 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);
// Wraps OpKernelContext's allocate_output method while providing special
// behavior for DT_VARIANT: a variant is treated as a DT_UINT8 scalar, which
// allows the variant to be mapped to more generic types.
Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
OpKernelContext* const context_;
};

View File

@ -1945,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
}
} break;
case TUPLE:
LOG(FATAL) << "Should not be called on tuple shapes: "
<< ShapeUtil::HumanString(subshape());
break;
return InvalidArgument("Should not be called on tuple shapes: %s",
ShapeUtil::HumanString(subshape()));
default:
LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
return InvalidArgument("Is called on unsupported shape: %s",
ShapeUtil::HumanString(subshape()));
}
return Status::OK();
}

View File

@ -1841,42 +1841,6 @@ tf_cc_test(
],
)
cc_library(
name = "inliner",
srcs = ["inliner.cc"],
hdrs = ["inliner.h"],
deps = [
":hlo",
":hlo_pass",
":hlo_query",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "inliner_test",
srcs = ["inliner_test.cc"],
deps = [
":cpu_plugin",
":hlo",
":hlo_matchers",
":inliner",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "computation_placer",
srcs = ["computation_placer.cc"],
@ -3492,6 +3456,39 @@ cc_library(
deps = ["//tensorflow/core:lib"],
)
cc_library(
name = "map_inliner",
srcs = ["map_inliner.cc"],
hdrs = ["map_inliner.h"],
deps = [
":hlo",
":hlo_pass",
":hlo_query",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "map_inliner_test",
srcs = ["map_inliner_test.cc"],
deps = [
":hlo",
":hlo_matchers",
":map_inliner",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "hlo_casting_utils_test",
srcs = ["hlo_casting_utils_test.cc"],

View File

@ -94,6 +94,7 @@ cc_library(
":target_machine_features",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
@ -127,7 +128,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
"//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",

View File

@ -86,8 +86,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
@ -249,7 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
pipeline.AddPass<Inliner>();
pipeline.AddPass<MapInliner>();
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
// pass.

View File

@ -465,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kIota:
TF_RET_CHECK(proto.dimensions_size() <= 1)
<< "Iota instruction should have at most 1 dimension but sees "
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;

View File

@ -195,13 +195,15 @@ class ListScheduler {
return entry;
}
// Returns the number of bytes freed if the HLO instruction is scheduled.
// If the instruction calls subcomputations, we count the memory used by the
// subcomputations as memory "defined" by the instruction. This is not
// entirely accurate, because subcomputation memory will be freed after the
// instruction finishes. But it is more accurate than not taking
// subcomputations into account at all. In the future, we may improve
// accounting for subcomputation memory (b/65409243).
// Returns the number of bytes freed *after* the HLO instruction finishes.
// The current List algorithm only considers two states for an instruction:
// right before it runs, and after it finishes. We don't represent memory
// usage during the execution of an instruction. But if the instruction calls
// subcomputations, they are only live during the instruction's execution.
// We end up counting the memory used by subcomputations as memory "defined"
// by the instruction. This is not entirely accurate, but it is more accurate
// than not taking subcomputations into account at all. In the future, we may
// improve accounting for subcomputation memory (b/65409243).
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
int64 freed_bytes = 0;
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
@ -223,7 +225,18 @@ class ListScheduler {
}
}
}
return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
int64 bytes_defined;
if (max_subcomputation_bytes > 0 &&
(entry.instruction->opcode() == HloOpcode::kWhile ||
entry.instruction->opcode() == HloOpcode::kCall ||
entry.instruction->opcode() == HloOpcode::kConditional)) {
// The output buffer of while/call/conditional is always aliased with the
// output buffer of the root instruction in the body. Don't double count.
bytes_defined = max_subcomputation_bytes;
} else {
bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
}
return freed_bytes - bytes_defined;
}
// Constructs the scheduling priority of the given instruction.
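A standalone restatement of the revised accounting, assuming the quantities above have already been computed (hypothetical helper, not scheduler code):

def bytes_freed_if_scheduled(freed_bytes, bytes_defined,
                             max_subcomputation_bytes, opcode):
  if max_subcomputation_bytes > 0 and opcode in ("while", "call",
                                                 "conditional"):
    # The instruction's output buffer aliases the output of its body's root,
    # so counting both would double count.
    defined = max_subcomputation_bytes
  else:
    defined = bytes_defined + max_subcomputation_bytes
  return freed_bytes - defined

# A kWhile freeing 1024 bytes whose 256-byte output aliases a 512-byte body:
assert bytes_freed_if_scheduled(1024, 256, 512, "while") == 512
assert bytes_freed_if_scheduled(1024, 256, 512, "add") == 256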

View File

@ -1304,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) {
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
%constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
%constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
<< "Maximal sharding is expected to have single device assignment, but "
<< proto.tile_assignment_devices().size() << " has provided.";
TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
// Note: the product of the tile assignment tensor dimensions must be
// equal to tile_assignment_devices.size().
int64 product_of_dimensions = 1;
for (auto dimension : proto.tile_assignment_dimensions()) {
TF_RET_CHECK(dimension > 0);
product_of_dimensions =
MultiplyWithoutOverflow(product_of_dimensions, dimension);
TF_RET_CHECK(product_of_dimensions > 0);
}
TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
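A minimal sketch of the new validation, assuming plain Python integers in place of the proto fields (the positivity assertion stands in for MultiplyWithoutOverflow). It also shows why the parser test above changed devices=[2,3] to [2,2] for four devices:

def validate_tile_assignment(dimensions, devices):
  product = 1
  for d in dimensions:
    assert d > 0
    product *= d
    assert product > 0   # stands in for the MultiplyWithoutOverflow check
  assert product == len(devices)

validate_tile_assignment([2, 2], [1, 2, 3, 4])    # ok: 2 * 2 == 4 devices
# validate_tile_assignment([2, 3], [1, 2, 3, 4])  # fails: 6 != 4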

View File

@ -45,8 +45,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",

View File

@ -28,9 +28,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include <memory>
#include <string>
@ -32,10 +32,10 @@ limitations under the License.
namespace xla {
// InlinerVisitor traverses the HLO computation and inlines maps.
class InlinerVisitor : public DfsHloVisitorWithDefault {
// MapInlinerVisitor traverses the HLO computation and inlines maps.
class MapInlinerVisitor : public DfsHloVisitorWithDefault {
public:
explicit InlinerVisitor(HloComputation* computation)
explicit MapInlinerVisitor(HloComputation* computation)
: computation_(computation) {}
// Default visitor action is to do nothing and return OK.
@ -49,24 +49,23 @@ class InlinerVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> Run(HloComputation* computation);
private:
// Current HloComputation instance the InlinerVisitor is traversing.
// Current HloComputation instance the MapInlinerVisitor is traversing.
HloComputation* computation_;
// Whether algebraic simplification has occurred.
bool changed_ = false;
};
StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
changed_ = false;
computation_ = computation;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
return changed_;
}
Status InlinerVisitor::HandleMap(HloInstruction* map) {
Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
HloComputation* function = map->to_apply();
HloInstruction& root = *function->root_instruction();
// TODO(b/29249531): Add DCE pass to remove unused HloComputations.
// Only inline functions that are a single operation, until a better
// profitability model for inlining is defined.
if (hlo_query::AllOperandsAreParameters(root)) {
@ -112,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
return Status::OK();
}
StatusOr<bool> Inliner::Run(HloModule* module) {
InlinerVisitor visitor(/*computation=*/nullptr);
StatusOr<bool> MapInliner::Run(HloModule* module) {
MapInlinerVisitor visitor(/*computation=*/nullptr);
bool changed = false;
for (HloComputation* computation : module->computations()) {
TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));

View File

@ -13,27 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// A pass which performs inlining. Which can result, for example, in functions
// that were previously being mapped by Map instead directly applied to the
// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
class Inliner : public HloModulePass {
// A pass which performs map inlining. This replaces kMap instructions with
// their equivalent sequence of array operations. For example:
// map({X, Y}, add) -> add(X, Y).
class MapInliner : public HloModulePass {
public:
~Inliner() override = default;
absl::string_view name() const override { return "inline"; }
~MapInliner() override = default;
absl::string_view name() const override { return "map-inline"; }
// Run inlining on the given computation. Returns whether the computation was
// changed.
// Run map inlining on the given computation. Returns whether the computation
// was changed.
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include <memory>
#include <utility>
@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
using InlinerTest = HloVerifiedTestBase;
using MapInlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
TEST_F(InlinerTest, MapMax) {
TEST_F(MapInlinerTest, MapMax) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto max_builder = HloComputation::Builder(TestName());
@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) {
}
// Test that `constant` function is changed to `broadcast`.
TEST_F(InlinerTest, MapConstant) {
TEST_F(MapInlinerTest, MapConstant) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto const2_builder = HloComputation::Builder(TestName());
@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEmbeddedComputation(std::move(const2_f32));
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
Inliner inliner;
MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
// Note that the parameter ordinals are in the opposite order to their
@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
@ -146,7 +146,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
TEST_F(InlinerTest, MapParameter) {
TEST_F(MapInlinerTest, MapParameter) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto param_builder = HloComputation::Builder(TestName());
@ -167,7 +167,7 @@ TEST_F(InlinerTest, MapParameter) {
hlo_module->AddEmbeddedComputation(std::move(param_f32));
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);

View File

@ -869,11 +869,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
return Status::OK();
}
if (Rank(shape) != shape.dimensions_size()) {
return InvalidArgument(
"shape's rank is mismatched with dimension count; rank=%d "
"dimensions_size=%d",
Rank(shape), shape.dimensions_size());
if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) {
return InvalidArgument("sparse arrays must have rank > 0");
}
for (int64 i = 0; i < Rank(shape); ++i) {
int64 dimension = shape.dimensions(i);

View File

@ -4,6 +4,7 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
exports_files(glob([
@ -165,10 +166,6 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
defines = select({
":with_tflite_flex": ["TFLITE_FLEX"],
"//conditions:default": [],
}),
linkopts = [
] + select({
"//tensorflow:android": [
@ -276,6 +273,7 @@ cc_test(
"testdata/0_subgraphs.bin",
"testdata/2_subgraphs.bin",
"testdata/empty_model.bin",
"testdata/multi_add_flex.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
],
@ -283,6 +281,26 @@ cc_test(
":framework",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
# Test model framework with the flex library linked into the target.
tf_cc_test(
name = "model_flex_test",
size = "small",
srcs = ["model_flex_test.cc"],
data = [
"testdata/multi_add_flex.bin",
],
tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC.
deps = [
":framework",
"//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/delegates/flex:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],

View File

@ -2,7 +2,7 @@
# This is a TF Lite delegate that is powered by TensorFlow's Eager.
#
package(default_visibility = [
"//visibility:public",
"//visibility:private",
])
licenses(["notice"]) # Apache 2.0
@ -50,6 +50,7 @@ cc_library(
hdrs = [
"delegate.h",
],
visibility = ["//visibility:public"],
deps = [
":buffer_map",
":delegate_data",
@ -66,6 +67,7 @@ cc_library(
"//tensorflow/core:lib",
],
}),
alwayslink = 1,
)
tf_cc_test(

View File

@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
} // namespace delegate
} // namespace flex
// Corresponding weak declaration found in lite/model.cc.
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
AcquireFlexDelegate() {
return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
});
}
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
std::unique_ptr<flex::DelegateData> delegate_data;
if (!flex::DelegateData::Create(&delegate_data).ok()) {

View File

@ -349,6 +349,10 @@ class Interpreter {
return context_.allow_fp32_relax_to_fp16;
}
// Owning handle to a TfLiteDelegate instance.
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
// Allow a delegate to look at the graph and modify it to handle parts of
// the graph itself. After this is called, the graph may contain new nodes
// that replace 1 or more nodes.
@ -574,19 +578,11 @@ class Interpreter {
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
using TfLiteDelegatePtr =
std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
// Variant of the public ModifyGraphWithDelegate method that additionally
// assumes ownership of the provided delegate.
// WARNING: This is an experimental API and subject to change.
template <typename Delegate>
TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate,
TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate,
bool allow_dynamic_tensors = false) {
TfLiteDelegatePtr delegate(typed_delegate.release(),
[](TfLiteDelegate* delegate) {
delete static_cast<Delegate*>(delegate);
});
// Note that we retain ownership of the delegate even if graph modification
// fails, as delegate use will be in an indeterminate state at that point.
owned_delegates_.push_back(std::move(delegate));
@ -676,6 +672,7 @@ class Interpreter {
// List of delegates that have been installed and are owned by this
// interpreter instance. Useful if client delegate ownership is burdensome.
// WARNING: This is an experimental API and subject to change.
// TODO(b/116667551): Use TfLiteExternalContext for storing state.
std::vector<TfLiteDelegatePtr> owned_delegates_;
std::unique_ptr<MemoryPlanner> memory_planner_;

View File

@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test {
template <typename Delegate>
static TfLiteStatus ModifyGraphWithDelegate(
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
return interpreter->ModifyGraphWithDelegate(std::move(delegate));
Interpreter::TfLiteDelegatePtr tflite_delegate(
delegate.release(), [](TfLiteDelegate* delegate) {
delete reinterpret_cast<Delegate*>(delegate);
});
return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
}
protected:

View File

@ -113,6 +113,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// input configuration.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];

View File

@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
if (input1->type == kTfLiteUInt8) { \
auto input1_offset = -input1->params.zero_point; \
auto input2_offset = -input2->params.zero_point; \
const int left_shift = 20; \
const double twice_max_input_scale = \
2 * std::max(input1->params.scale, input2->params.scale); \
const double real_input1_multiplier = \
input1->params.scale / twice_max_input_scale; \
const double real_input2_multiplier = \
input2->params.scale / twice_max_input_scale; \
const int left_shift = 8; \
\
int32 input1_multiplier; \
int input1_shift; \
QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
&input1_multiplier, &input1_shift); \
int32 input2_multiplier; \
int input2_shift; \
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
&input2_multiplier, &input2_shift); \
\
ComparisonParams op_params; \
op_params.left_shift = left_shift; \
op_params.input1_offset = input1_offset; \
op_params.input1_multiplier = input1_multiplier; \
op_params.input1_shift = -input1_shift; \
op_params.input1_shift = input1_shift; \
op_params.input2_offset = input2_offset; \
op_params.input2_multiplier = input2_multiplier; \
op_params.input2_shift = -input2_shift; \
op_params.input2_shift = input2_shift; \
if (requires_broadcast) { \
reference_ops::Broadcast4DSlow##opname##WithScaling( \
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
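The macro now feeds each input's own scale to QuantizeMultiplierSmallerThanOneExp with a smaller left shift. The semantics being approximated are an ordinary comparison of dequantized values; a hedged sketch of that reference behavior (illustrative, not the fixed-point kernel):

def quantized_greater(q1, zp1, scale1, q2, zp2, scale2):
  # Compare the dequantized real values; the kernel approximates each
  # (q - zero_point) * scale with an integer multiplier and shift.
  return (q1 - zp1) * scale1 > (q2 - zp2) * scale2

# Inputs quantized with different scales, as in GreaterQuantizedSmallRange:
assert quantized_greater(250, 0, 2.0 / 255, 100, 0, 1.0 / 255)
assert not quantized_greater(100, 0, 1.0 / 255, 100, 0, 2.0 / 255)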

View File

@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
{TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
TensorType_UINT8, BuiltinOperator_GREATER);
model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
model.Invoke();
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
TEST(ComparisonsTest, GreaterEqualQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;

View File

@ -27,9 +27,6 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
#if defined(TFLITE_FLEX)
#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
@ -43,6 +40,25 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
// we avoid the absl dependency for binary size reasons.
#ifdef __has_attribute
#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
#else
#define TFLITE_HAS_ATTRIBUTE(x) 0
#endif
#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
// Using weak symbols for the flex delegate allows automatic injection of the
// delegate simply by adding it as a dependency. See also the strong override in
// lite/delegates/flex/delegate.cc.
__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
#else
Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
#endif
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@ -450,13 +466,14 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
#if defined(TFLITE_FLEX)
if (auto delegate = FlexDelegate::Create()) {
(**interpreter)
.ModifyGraphWithDelegate(std::move(delegate),
/*allow_dynamic_tensors=*/true);
// TODO(b/116667551): Only create the flex delegate if the model has flex ops.
if (AcquireFlexDelegate != nullptr) {
if (auto flex_delegate = AcquireFlexDelegate()) {
(**interpreter)
.ModifyGraphWithDelegate(std::move(flex_delegate),
/*allow_dynamic_tensors=*/true);
}
}
#endif
return kTfLiteOk;
}
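A loose Python analogy of this weak/strong symbol scheme, with hypothetical names (acquire_flex_delegate, flex_delegate, modify_graph_with_delegate): a do-nothing default that an optional module silently overrides when it is linked in.

def acquire_flex_delegate():
  return None   # "weak" default: no flex delegate linked in

try:
  # "Strong" override: present only when the optional module is available.
  from flex_delegate import acquire_flex_delegate   # hypothetical module
except ImportError:
  pass

def maybe_inject_flex(interpreter):
  delegate = acquire_flex_delegate()
  if delegate is not None:
    # Mirrors ModifyGraphWithDelegate(..., allow_dynamic_tensors=True).
    interpreter.modify_graph_with_delegate(delegate)  # hypothetical API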

View File

@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/model.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
// Ensures that a model with TensorFlow ops can be imported as long as the
// appropriate delegate is linked into the client.
TEST(FlexModel, WithFlexDelegate) {
auto model = FlatBufferModel::BuildFromFile(
"tensorflow/contrib/lite/testdata/multi_add_flex.bin");
ASSERT_TRUE(model);
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(InterpreterBuilder(*model,
ops::builtin::BuiltinOpResolver{})(&interpreter),
kTfLiteOk);
ASSERT_TRUE(interpreter);
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
}
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
@ -193,6 +194,27 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) {
}
}
// Test that loading a model with TensorFlow ops fails when the flex delegate is
// not linked into the target.
TEST(FlexModel, FailureWithoutFlexDelegate) {
auto model = FlatBufferModel::BuildFromFile(
"tensorflow/contrib/lite/testdata/multi_add_flex.bin");
ASSERT_TRUE(model);
// Note that creation will succeed when using the BuiltinOpResolver, but
// unless the appropriate delegate is linked into the target or the client
// explicitly installs the delegate, execution will fail.
std::unique_ptr<Interpreter> interpreter;
ASSERT_EQ(InterpreterBuilder(*model,
ops::builtin::BuiltinOpResolver{})(&interpreter),
kTfLiteOk);
ASSERT_TRUE(interpreter);
// As the flex ops weren't resolved implicitly by the flex delegate, runtime
// allocation and execution will fail.
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
}
// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
// buffer. But the buffer is provided to be only 1 element.
TEST(BasicFlatBufferModel, TestBrokenMmap) {

Binary file not shown.

View File

@ -40,7 +40,7 @@ cc_binary(
srcs = [
"benchmark_main.cc",
],
copts = common_copts + ["-DTFLITE_FLEX"],
copts = common_copts,
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
@ -49,8 +49,9 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
":benchmark_tflite_model_plus_flex_lib",
":benchmark_tflite_model_lib",
":logging",
"//tensorflow/contrib/lite/delegates/flex:delegate",
],
)
@ -110,25 +111,6 @@ cc_library(
],
)
cc_library(
name = "benchmark_tflite_model_plus_flex_lib",
srcs = [
"benchmark_tflite_model.cc",
"logging.h",
],
hdrs = ["benchmark_tflite_model.h"],
copts = common_copts + ["-DTFLITE_FLEX"],
deps = [
":benchmark_model_lib",
":logging",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/delegates/flex:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/profiling:profile_summarizer",
],
)
cc_library(
name = "benchmark_params",
srcs = [

View File

@ -23,9 +23,6 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#ifdef TFLITE_FLEX
#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@ -305,15 +302,6 @@ void BenchmarkTfLiteModel::Init() {
interpreter->UseNNAPI(use_nnapi);
#ifdef TFLITE_FLEX
TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
delegate_ = FlexDelegate::Create();
if (delegate_) {
interpreter->ModifyGraphWithDelegate(delegate_.get(),
/*allow_dynamic_tensors=*/true);
}
#endif // TFLITE_FLEX
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {

View File

@ -20,9 +20,6 @@ limitations under the License.
#include <string>
#include <vector>
#ifdef TFLITE_FLEX
#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@ -73,9 +70,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
void PrepareInputsAndOutputs() override;
private:
#ifdef TFLITE_FLEX
std::unique_ptr<FlexDelegate> delegate_;
#endif // TFLITE_FLEX
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;

View File

@ -4,22 +4,33 @@ op {
in_arg {
name: "arguments"
description: <<END
A list of tensors whose types are Targuments, corresponding to the inputs the
function should be mapped over.
A list of tensors whose types are `Targuments`, corresponding to the inputs
the function should be mapped over.
END
}
in_arg {
name: "captured_inputs"
description: <<END
A list of tensors whose types are `Tcaptured`, corresponding to the captured
inputs of the defun.
END
}
out_arg {
name: "output"
description: <<END
A list of output tensors whose types are output_types and whose dimensions 0
are the same as the dimensions 0 of the tensors in arguments, and whose
remaining dimensions correspond to those in output_shapes.
A list of output tensors whose types are `output_types` and whose dimensions
0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
remaining dimensions correspond to those in `output_shapes`.
END
}
attr {
name: "Targuments"
description: "A list of types."
}
attr {
name: "Tcaptured"
description: "A list of types."
}
attr {
name: "output_types"
description: "A list of types."
@ -29,6 +40,6 @@ END
description: "A list of shapes."
}
summary: <<END
Maps a function on the list of tensors unpacked from inputs on dimension 0.
Maps a function on the list of tensors unpacked from arguments on dimension 0.
END
}
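A sketch of the documented semantics, assuming a single-output function: each argument is unstacked along dimension 0 and mapped over, while captured_inputs are forwarded unchanged to every call (illustrative helper, not the kernel):

import numpy as np

def map_defun(f, arguments, captured_inputs):
  batch = arguments[0].shape[0]
  assert all(a.shape[0] == batch for a in arguments)
  return np.stack([f(*(a[i] for a in arguments), *captured_inputs)
                   for i in range(batch)])

x = np.arange(6.0).reshape(3, 2)   # mapped: unstacked along dimension 0
bias = np.array([10.0, 20.0])      # captured: passed whole to every call
print(map_defun(lambda row, b: row + b, [x], [bias]))
# [[10. 21.] [12. 23.] [14. 25.]]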

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListPushBackBatch"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "EmptyTensorList"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListConcatLists"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListElementShape"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListFromTensor"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListGather"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListGetItem"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListLength"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListPopBack"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListPushBack"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListReserve"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListScatter"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListSetItem"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "TensorListStack"
visibility: HIDDEN
}

View File

@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");

View File

@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
LOG(ERROR) << s;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));

View File

@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
// Validates inputs and gets the size of their leading dimension.
*batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
if (ctx->input(i).dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
} else if (ctx->input(i).dim_size(0) != *batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
", while all previous inputs have leading dimension ", batch_size);
}
}
return Status::OK();
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ComputeOptions* compute_opts = nullptr;
@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel {
// all calls to the function are complete. This struct also encapsulates
// all the components that need to be passed to each MapFunctionCallFrame.
const std::vector<Tensor> args;
OpInputList args;
const std::vector<TensorShape> arg_shapes;
OpInputList captured_inputs;
const int64 batch_size;
// Output of a compute call
@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel {
// Create a copy of output_shapes because every `Compute` may expect a
// different output shape.
ComputeOptions(std::vector<Tensor> args,
ComputeOptions(OpInputList args, OpInputList captured_inputs,
std::vector<TensorShape> arg_shapes, int64 batch_size,
const std::vector<PartialTensorShape>& output_shapes_attr)
: args(std::move(args)),
: args(args),
arg_shapes(std::move(arg_shapes)),
captured_inputs(captured_inputs),
batch_size(batch_size),
output_shapes(output_shapes_attr) {}
};
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
int64 batch_size =
ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
OpInputList arguments;
TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
OpInputList captured_inputs;
TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
if (ctx->input(i).dims() == 0) {
int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
for (size_t i = 0; i < arguments.size(); ++i) {
if (arguments[i].dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
} else if (ctx->input(i).dim_size(0) != batch_size) {
} else if (arguments[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel {
}
}
std::vector<Tensor> args;
std::vector<TensorShape> arg_shapes;
args.reserve(ctx->num_inputs());
arg_shapes.reserve(ctx->num_inputs());
arg_shapes.reserve(arguments.size());
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
args.push_back(ctx->input(i));
arg_shapes.push_back(ctx->input(i).shape());
for (size_t i = 0; i < arguments.size(); ++i) {
arg_shapes.push_back(arguments[i].shape());
arg_shapes.at(i).RemoveDim(0);
}
*compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
batch_size, output_shapes_);
*compute_opts =
new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
batch_size, output_shapes_);
return Status::OK();
}
@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel {
}
Status GetArg(int index, Tensor* val) const override {
if (index < 0 || index >= compute_opts_->args.size()) {
if (index < 0 || index >= compute_opts_->args.size() +
compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
if (index >= compute_opts_->args.size()) {
// The function is calling for a captured input
*val =
compute_opts_->captured_inputs[index - compute_opts_->args.size()];
return Status::OK();
}
bool result =
val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
return Status::OK();
}

View File

@ -97,6 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
OP_REQUIRES_ASYNC(
ctx, args.size() == fbody->arg_nodes.size(),
errors::InvalidArgument(
"Wrong number of arguments to the op; function expects ",
fbody->arg_nodes.size(), " but PartitionedCall received ",
args.size()),
done);
// We need to pass global op_registry as default_registry when creating
// graph. So that graph optimization passes can lookup all possible ops
// by name.

View File

@ -30566,6 +30566,52 @@ op {
type: "func"
}
}
op {
name: "MapDefun"
input_arg {
name: "arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "captured_inputs"
type_list_attr: "Tcaptured"
}
output_arg {
name: "output"
type_list_attr: "output_types"
}
attr {
name: "Targuments"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "Tcaptured"
type: "list(type)"
default_value {
list {
}
}
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "output_shapes"
type: "list(shape)"
has_minimum: true
minimum: 1
}
attr {
name: "f"
type: "func"
}
}
op {
name: "MapIncompleteSize"
output_arg {
@ -71843,6 +71889,48 @@ op {
}
}
}
op {
name: "Substr"
input_arg {
name: "input"
type: DT_STRING
}
input_arg {
name: "pos"
type_attr: "T"
}
input_arg {
name: "len"
type_attr: "T"
}
output_arg {
name: "output"
type: DT_STRING
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "unit"
type: "string"
default_value {
s: "BYTE"
}
allowed_values {
list {
s: "BYTE"
s: "UTF8_CHAR"
}
}
}
}
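A hedged usage sketch of the new `unit` attr, assuming the 1.x Python wrapper tf.strings.substr exposes it (argument names taken from the op definition above):

import tensorflow as tf

s = tf.constant([u"héllo".encode("utf-8")])
by_byte = tf.strings.substr(s, pos=1, len=3)                    # counts bytes
by_char = tf.strings.substr(s, pos=1, len=3, unit="UTF8_CHAR")  # counts code points
# by_byte yields b"\xc3\xa9l" ("é" is two bytes); by_char yields "éll".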
op {
name: "Sum"
input_arg {


@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset")
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
.Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
.Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
DataTypeVector t_args;
TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@ -918,10 +922,11 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
for (size_t i = 0; i < t_args.size(); ++i) {
if (c->Rank(c->input(i)) == 0) {
return errors::InvalidArgument(
"Inputs must have rank at least 1. Input ", i, " has rank of 0");
"Arguments must have rank at least 1. Input ", i,
" has rank of 0.");
}
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
@ -929,7 +934,7 @@ REGISTER_OP("MapDefun")
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
"Inputs must have the same dimension 0.");
"Arguments must have the same dimension 0.");
}
}
}
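In plain Python, a sketch of the rule the shape function above encodes (illustration only; None stands for an unknown dimension):

def map_defun_output_shapes(arg_shapes, output_shapes):
  dim_zero = None  # shared batch dimension, unknown until pinned
  for i, shape in enumerate(arg_shapes):  # one shape per `arguments` input
    if len(shape) == 0:
      raise ValueError("Arguments must have rank at least 1. "
                       "Input %d has rank of 0." % i)
    if shape[0] is not None:
      if dim_zero is not None and shape[0] != dim_zero:
        raise ValueError("Arguments must have the same dimension 0.")
      dim_zero = shape[0]
  # Each output is the batch dimension followed by its declared element shape.
  return [[dim_zero] + list(s) for s in output_shapes]

# e.g. map_defun_output_shapes([(8, 3), (8,)], [(3,), ()]) -> [[8, 3], [8]]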


@ -15262,6 +15262,10 @@ op {
name: "arguments"
type_list_attr: "Targuments"
}
input_arg {
name: "captured_inputs"
type_list_attr: "Tcaptured"
}
output_arg {
name: "output"
type_list_attr: "output_types"
@ -15272,6 +15276,15 @@ op {
has_minimum: true
minimum: 1
}
attr {
name: "Tcaptured"
type: "list(type)"
default_value {
list {
}
}
has_minimum: true
}
attr {
name: "output_types"
type: "list(type)"
@ -33748,6 +33761,19 @@ op {
}
}
}
attr {
name: "unit"
type: "string"
default_value {
s: "BYTE"
}
allowed_values {
list {
s: "BYTE"
s: "UTF8_CHAR"
}
}
}
}
op {
name: "Sum"


@ -0,0 +1,25 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "map_benchmark",
size = "medium",
srcs = ["map_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)


@ -19,7 +19,6 @@ from __future__ import print_function
import hashlib
import itertools
import os
import time
import numpy as np
@ -27,128 +26,15 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
class MapDatasetTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
dataset = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.check_numerics(x, "message")).apply(
error_ops.ignore_errors()))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
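A condensed, hedged restatement of the pattern this relocated test covers: check_numerics raises on the NaN element, and ignore_errors() drops that element instead of failing the pipeline.

from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
import numpy as np

components = np.array([1., 2., np.nan, 4.]).astype(np.float32)
ds = (dataset_ops.Dataset.from_tensor_slices(components)
      .map(lambda x: array_ops.check_numerics(x, "message"))
      .apply(error_ops.ignore_errors()))
# Iteration yields 1., 2., 4.; the NaN element is silently skipped.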
def testParallelMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
dataset = (
dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.check_numerics(x, "message"),
num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testReadFileIgnoreError(self):
def write_string_to_file(value, filename):
with open(filename, "w") as f:
f.write(value)
filenames = [
os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
]
for filename in filenames:
write_string_to_file(filename, filename)
dataset = (
dataset_ops.Dataset.from_tensor_slices(filenames).map(
io_ops.read_file,
num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Delete one of the files.
os.remove(filenames[0])
# Attempting to read filenames[0] will fail, but ignore_errors()
# will catch the error.
sess.run(init_op)
for filename in filenames[1:]:
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testCaptureResourceInMapFn(self):
def _build_ds(iterator):
def _map_fn(x):
get_next = iterator.get_next()
return x * get_next
return dataset_ops.Dataset.range(10).map(_map_fn)
def _build_graph():
captured_iterator = dataset_ops.Dataset.range(
10).make_initializable_iterator()
ds = _build_ds(captured_iterator)
iterator = ds.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
return captured_iterator.initializer, init_op, get_next
with ops.Graph().as_default() as g:
captured_init_op, init_op, get_next = _build_graph()
with self.session(graph=g) as sess:
sess.run(captured_init_op)
sess.run(init_op)
for i in range(10):
self.assertEquals(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
class MapDatasetBenchmark(test.Benchmark):
# The purpose of this benchmark is to compare the performance of chaining vs


@ -8,57 +8,16 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "batch_dataset_op_test",
name = "bucket_by_sequence_length_test",
size = "medium",
srcs = ["batch_dataset_op_test.py"],
srcs = ["bucket_by_sequence_length_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss", # (b/79552534)
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "bucketing_test",
size = "medium",
srcs = ["bucketing_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
@ -67,16 +26,44 @@ py_test(
],
)
py_test(
name = "csv_dataset_op_test",
size = "medium",
srcs = ["csv_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
cuda_py_test(
name = "copy_to_device_test",
size = "small",
srcs = ["copy_to_device_test.py"],
additional_deps = [
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/compat:compat",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
tags = ["no_windows_gpu"],
)
py_test(
name = "counter_test",
size = "small",
srcs = ["counter_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python/data/experimental/ops:counter",
"//tensorflow/python/data/kernel_tests:test_base",
],
)
py_test(
name = "csv_dataset_test",
size = "medium",
srcs = ["csv_dataset_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@ -97,25 +84,18 @@ py_test(
)
py_test(
name = "dataset_constructor_op_test",
size = "medium",
srcs = ["dataset_constructor_op_test.py"],
name = "dense_to_sparse_batch_test",
srcs = ["dense_to_sparse_batch_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual",
"no_oss",
"no_pip",
"no_windows",
"nomac", # b/62040583
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
@ -124,11 +104,6 @@ py_test(
size = "medium",
srcs = ["directed_interleave_dataset_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@ -140,15 +115,68 @@ py_test(
],
)
py_test(
name = "enumerate_dataset_test",
size = "small",
srcs = ["enumerate_dataset_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/experimental/ops:enumerate_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_test(
name = "filter_dataset_op_test",
size = "medium",
srcs = ["filter_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "function_buffering_resource_test",
size = "small",
srcs = ["function_buffering_resource_test.py"],
additional_deps = [
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:function",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
tags = ["no_windows_gpu"],
)
py_test(
name = "get_single_element_test",
size = "small",
srcs = ["get_single_element_test.py"],
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -164,14 +192,69 @@ py_test(
],
)
py_test(
name = "group_by_reducer_test",
size = "medium",
srcs = ["group_by_reducer_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
name = "group_by_window_test",
size = "medium",
srcs = ["group_by_window_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
name = "ignore_errors_test",
srcs = ["ignore_errors_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:io_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:error_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
name = "indexed_dataset_ops_test",
srcs = ["indexed_dataset_ops_test.py"],
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -185,14 +268,134 @@ py_test(
)
py_test(
name = "interleave_dataset_op_test",
name = "make_batched_features_dataset_test",
size = "medium",
srcs = ["interleave_dataset_op_test.py"],
srcs = ["make_batched_features_dataset_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
py_test(
name = "make_csv_dataset_test",
size = "medium",
srcs = ["make_csv_dataset_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
py_test(
name = "make_tf_record_dataset_test",
size = "medium",
srcs = ["make_tf_record_dataset_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/util:nest",
],
)
py_test(
name = "map_and_batch_test",
size = "medium",
srcs = ["map_and_batch_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "map_defun_op_test",
size = "small",
srcs = ["map_defun_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//tensorflow/python/data/experimental/ops:map_defun",
"//tensorflow/python/data/kernel_tests:test_base",
],
)
py_test(
name = "override_threadpool_test",
size = "small",
srcs = ["override_threadpool_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
"//tensorflow/python/data/experimental/ops:threadpool",
"//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "parallel_interleave_test",
size = "medium",
srcs = ["parallel_interleave_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
"notap",
],
deps = [
@ -212,120 +415,10 @@ py_test(
)
py_test(
name = "iterator_ops_test",
name = "parse_example_dataset_test",
size = "small",
srcs = ["iterator_ops_test.py"],
srcs = ["parse_example_dataset_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator:estimator_py",
],
)
py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
"noasan", # times out
"optonly",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:error_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
name = "filter_dataset_op_test",
size = "medium",
srcs = ["filter_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
name = "map_defun_op_test",
size = "small",
srcs = ["map_defun_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//tensorflow/python/data/experimental/ops:map_defun",
"//tensorflow/python/data/kernel_tests:test_base",
],
)
py_test(
name = "parsing_ops_test",
size = "small",
srcs = ["parsing_ops_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
@ -344,53 +437,20 @@ py_test(
)
cuda_py_test(
name = "prefetching_ops_test",
name = "prefetch_to_device_test",
size = "small",
srcs = ["prefetching_ops_test.py"],
srcs = ["prefetch_to_device_test.py"],
additional_deps = [
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:function",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/compat:compat",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
tags = [
"no_oss",
"no_pip",
"no_windows",
"no_windows_gpu",
],
)
py_test(
name = "range_dataset_op_test",
size = "small",
srcs = ["range_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/experimental/ops:counter",
"//tensorflow/python/data/experimental/ops:enumerate_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
tags = ["no_windows_gpu"],
)
py_library(
@ -421,41 +481,12 @@ py_library(
)
py_test(
name = "reader_dataset_ops_test",
name = "rejection_resample_test",
size = "medium",
srcs = ["reader_dataset_ops_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
":reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
py_test(
name = "resample_test",
size = "medium",
srcs = ["resample_test.py"],
srcs = ["rejection_resample_test.py"],
shard_count = 2,
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
"noasan",
"optonly",
],
@ -477,15 +508,27 @@ py_test(
)
py_test(
name = "scan_dataset_op_test",
size = "small",
srcs = ["scan_dataset_op_test.py"],
name = "restructured_dataset_test",
size = "medium",
srcs = ["restructured_dataset_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
],
)
py_test(
name = "scan_test",
size = "small",
srcs = ["scan_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -503,14 +546,12 @@ py_test(
)
py_test(
name = "shuffle_dataset_op_test",
name = "shuffle_and_repeat_test",
size = "medium",
srcs = ["shuffle_dataset_op_test.py"],
srcs = ["shuffle_and_repeat_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
"optonly",
],
deps = [
@ -525,8 +566,8 @@ py_test(
)
py_library(
name = "sql_dataset_op_test_base",
srcs = ["sql_dataset_op_test_base.py"],
name = "sql_dataset_test_base",
srcs = ["sql_dataset_test_base.py"],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow/python/data/experimental/kernel_tests:__pkg__",
@ -543,17 +584,13 @@ py_library(
)
py_test(
name = "sql_dataset_op_test",
name = "sql_dataset_test",
size = "small",
srcs = ["sql_dataset_op_test.py"],
srcs = ["sql_dataset_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
tags = ["no_pip"],
deps = [
":sql_dataset_op_test_base",
":sql_dataset_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@ -565,11 +602,7 @@ py_test(
size = "medium",
srcs = ["stats_dataset_ops_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
":stats_dataset_test_base",
@ -595,59 +628,9 @@ py_library(
)
py_test(
name = "threadpool_dataset_ops_test",
name = "tf_record_writer_test",
size = "small",
srcs = ["threadpool_dataset_ops_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
"//tensorflow/python/data/experimental/ops:threadpool",
"//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "unique_dataset_op_test",
size = "small",
srcs = ["unique_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
"no_windows",
],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_test(
name = "writer_ops_test",
size = "small",
srcs = ["writer_ops_test.py"],
tags = [
"no_oss",
"no_pip",
"no_windows",
],
srcs = ["tf_record_writer_test.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -660,3 +643,45 @@ py_test(
"//tensorflow/python/data/ops:readers",
],
)
py_test(
name = "unbatch_test",
size = "medium",
srcs = ["unbatch_test.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "unique_test",
size = "small",
srcs = ["unique_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)


@ -1,686 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import time
from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x], x)).apply(
batching.dense_to_sparse_batch(4, [12]))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
self.assertAllEqual([[i, j]
for i, c in enumerate(components[start:start + 4])
for j in range(c)], results.indices)
self.assertAllEqual(
[c for c in components[start:start + 4] for _ in range(c)],
results.values)
self.assertAllEqual([min(4,
len(components) - start), 12],
results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x, x], x)).apply(
batching.dense_to_sparse_batch(
4, [5, None])).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
self.assertAllEqual([[i, j, z]
for i, c in enumerate(components[start:start + 4])
for j in range(c)
for z in range(c)], results.indices)
self.assertAllEqual([
c
for c in components[start:start + 4] for _ in range(c)
for _ in range(c)
], results.values)
self.assertAllEqual([
min(4,
len(components) - start), 5,
np.max(components[start:start + 4])
], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
iterator = (
dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [12]))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"incompatible with the row shape"):
sess.run(get_next)
# Initialize with an input tensor that is larger than `row_shape`.
sess.run(init_op, feed_dict={input_tensor: range(13)})
with self.assertRaisesRegexp(errors.DataLossError,
"larger than the row shape"):
sess.run(get_next)
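A compact, hedged restatement of the behavior the removed tests above exercise (using the same imports as this file):

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops

# dense_to_sparse_batch(batch_size, row_shape) batches variable-length dense
# rows into one SparseTensor per batch with dense_shape [batch] + row_shape;
# rows that exceed row_shape trigger the errors asserted above.
ds = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
ds = ds.map(lambda x: array_ops.fill([x], x))          # rows of length 1, 2, 3
ds = ds.apply(batching.dense_to_sparse_batch(2, [3]))
# First batch: dense_shape [2, 3], values [1, 2, 2] at indices [0,0],[1,0],[1,1].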
def testUnbatchWithUnknownRankInput(self):
placeholder = array_ops.placeholder(dtypes.int32)
dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
batching.unbatch())
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
for i in range(4):
self.assertEqual(i, sess.run(next_elem))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_elem)
def testUnbatchScalarDataset(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = (dtypes.int32,) * 3
data = data.batch(2)
self.assertEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testUnbatchDatasetWithStrings(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
data = data.batch(2)
self.assertEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testUnbatchDatasetWithSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
indices=[[i, i] for i in range(10)],
values=list(range(10)),
dense_shape=[10, 10])
data = dataset_ops.Dataset.from_tensors(st)
data = data.apply(batching.unbatch())
data = data.batch(5)
data = data.apply(batching.unbatch())
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testUnbatchDatasetWithDenseAndSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
indices=[[i, i] for i in range(10)],
values=list(range(10)),
dense_shape=[10, 10])
data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
data = data.apply(batching.unbatch())
data = data.batch(5)
data = data.apply(batching.unbatch())
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testUnbatchSingleElementTupleDataset(self):
data = tuple([(math_ops.range(10),) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = ((dtypes.int32,),) * 3
data = data.batch(2)
self.assertEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testUnbatchMultiElementTupleDataset(self):
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
array_ops.fill([10], "hi")) for i in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = ((dtypes.int32, dtypes.string),) * 3
data = data.batch(2)
self.assertAllEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertAllEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testUnbatchEmpty(self):
data = dataset_ops.Dataset.from_tensors(
(constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
constant_op.constant([], shape=[0, 4, 0])))
data = data.apply(batching.unbatch())
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testUnbatchStaticShapeMismatch(self):
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
np.arange(9)))
with self.assertRaises(ValueError):
data.apply(batching.unbatch())
def testUnbatchDynamicShapeMismatch(self):
ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
ph2 = array_ops.placeholder(dtypes.int32, shape=None)
data = dataset_ops.Dataset.from_tensors((ph1, ph2))
data = data.apply(batching.unbatch())
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
feed_dict={
ph1: np.arange(7).astype(np.int32),
ph2: np.arange(8).astype(np.int32)
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
# No 0th dimension (i.e. scalar value) for one component.
sess.run(
iterator.initializer,
feed_dict={
ph1: np.arange(7).astype(np.int32),
ph2: 7
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
@parameterized.named_parameters(
("Default", None, None),
("SequentialCalls", 1, None),
("ParallelCalls", 2, None),
("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))
count = array_ops.placeholder(dtypes.int64, shape=[])
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
batching.map_and_batch(
map_func=_map_fn,
batch_size=batch_size,
num_parallel_calls=num_parallel_calls,
num_parallel_batches=num_parallel_batches))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
num_batches = (28 * 7) // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Batch of a finite input, where the batch_size does not
# divide the total number of elements.
sess.run(init_op, feed_dict={count: 14, batch_size: 8})
# We expect (num_batches - 1) full-sized batches.
num_batches = int(math.ceil((14 * 7) / 8))
for i in range(num_batches - 1):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Batch of an empty input should fail straight away.
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Empty batch should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
("Even", False),
("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(
lambda x: array_ops.reshape(x * x, [1]),
batch_size=4,
drop_remainder=drop_remainder)).make_one_shot_iterator())
if drop_remainder:
self.assertEqual([4, 1], iterator.output_shapes.as_list())
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testMapAndBatchYieldsPartialBatch(self):
iterator = (dataset_ops.Dataset.range(10)
.apply(batching.map_and_batch(
lambda x: array_ops.reshape(x * x, [1]), 4))
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testMapAndBatchParallelGetNext(self):
iterator = (dataset_ops.Dataset.range(50000)
.apply(batching.map_and_batch(lambda x: x, batch_size=100))
.make_one_shot_iterator())
elements = []
for _ in range(100):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
def testMapAndBatchParallelGetNextDropRemainder(self):
iterator = (
dataset_ops.Dataset.range(49999).apply(
batching.map_and_batch(
lambda x: x, batch_size=100, drop_remainder=True))
.make_one_shot_iterator())
elements = []
for _ in range(100):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
def testMapAndBatchSparse(self):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])
iterator = dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
dense_shape=[5, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testMapAndBatchFails(self):
"""Test a dataset that maps a TF function across its input elements."""
dataset = dataset_ops.Dataset.from_tensors(
array_ops.check_numerics(
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
def testMapAndBatchShapeMismatch(self):
"""Test a dataset that maps a TF function across its input elements."""
def generator():
yield [1]
yield [2]
yield [3]
yield [[4, 5, 6]]
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int32)
batch_size = 4
iterator = (
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
sess.run(get_next)
def testMapAndBatchImplicitDispose(self):
# Tests whether a map and batch dataset will be cleaned up correctly when
# the pipeline does not run it until exhaustion.
# The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
# MapAndBatchDataset(f=square_3, batch_size=100).
components = (np.arange(1000),
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
np.array(37.0) * np.arange(1000))
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
dataset = dataset.prefetch(5)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
@parameterized.named_parameters(
("1", 0),
("2", 5),
("3", 10),
("4", 90),
("5", 95),
("6", 99),
)
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
if i >= threshold:
raise StopIteration()
else:
return i
iterator = (
dataset_ops.Dataset.range(100).apply(
batching.map_and_batch(
lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
self.assertAllEqual(
[threshold // 10 * 10 + j for j in range(threshold % 10)],
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@parameterized.named_parameters(
("1", False, dtypes.bool),
("2", -42, dtypes.int8),
("3", -42, dtypes.int16),
("4", -42, dtypes.int32),
("5", -42, dtypes.int64),
("6", 42, dtypes.uint8),
("7", 42, dtypes.uint16),
("8", 42.0, dtypes.float16),
("9", 42.0, dtypes.float32),
("10", 42.0, dtypes.float64),
("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
yield element
dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
batching.map_and_batch(lambda x: x, batch_size=10))
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
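The fusion contract these tests pin down, stated directly as a hedged sketch:

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

# map_and_batch(fn, batch_size) is semantically map(fn) followed by
# batch(batch_size), fused into a single parallel kernel.
fused = dataset_ops.Dataset.range(10).apply(
    batching.map_and_batch(lambda x: x * x, batch_size=4))
unfused = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(4)
# Both yield [0, 1, 4, 9], [16, 25, 36, 49], then the partial batch [64, 81].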
class UnbatchDatasetBenchmark(test.Benchmark):
def benchmarkNativeUnbatch(self):
batch_sizes = [1, 2, 5, 10, 20, 50]
elems_per_trial = 10000
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset.batch(batch_size_placeholder)
dataset = dataset.apply(batching.unbatch())
dataset = dataset.skip(elems_per_trial)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with session.Session() as sess:
for batch_size in batch_sizes:
deltas = []
for _ in range(5):
sess.run(
iterator.initializer,
feed_dict={batch_size_placeholder: batch_size})
start = time.time()
sess.run(next_element.op)
end = time.time()
deltas.append((end - start) / elems_per_trial)
median_wall_time = np.median(deltas)
print("Unbatch (native) batch size: %d Median wall time per element:"
" %f microseconds" % (batch_size, median_wall_time * 1e6))
self.report_benchmark(
iters=10000,
wall_time=median_wall_time,
name="benchmark_unbatch_dataset_native_batch_size_%d" %
batch_size)
# Include a benchmark of the previous `unbatch()` implementation that uses
# a composition of more primitive ops. Eventually we'd hope to generate code
# that is as good in both cases.
def benchmarkOldUnbatchImplementation(self):
batch_sizes = [1, 2, 5, 10, 20, 50]
elems_per_trial = 10000
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset.batch(batch_size_placeholder)
dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
dataset = dataset.skip(elems_per_trial)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with session.Session() as sess:
for batch_size in batch_sizes:
deltas = []
for _ in range(5):
sess.run(
iterator.initializer,
feed_dict={batch_size_placeholder: batch_size})
start = time.time()
sess.run(next_element.op)
end = time.time()
deltas.append((end - start) / elems_per_trial)
median_wall_time = np.median(deltas)
print("Unbatch (unfused) batch size: %d Median wall time per element:"
" %f microseconds" % (batch_size, median_wall_time * 1e6))
self.report_benchmark(
iters=10000,
wall_time=median_wall_time,
name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
batch_size)
if __name__ == "__main__":
test.main()
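The pair of benchmarks above compares two routes to the same semantics; stated directly, a hedged sketch:

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops

ds = dataset_ops.Dataset.from_tensors([1, 2, 3])
fused = ds.apply(batching.unbatch())                           # fused kernel
unfused = ds.flat_map(dataset_ops.Dataset.from_tensor_slices)  # old composition
# Both yield the scalar elements 1, 2, 3.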


@ -0,0 +1,322 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.bucket_by_sequence_length()."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
def _element_length_fn(x, y=None):
del y
return array_ops.shape(x)[0]
def _to_sparse_tensor(record):
return sparse_tensor.SparseTensor(**record)
def _format_record(array, sparse):
if sparse:
return {
"values": array,
"indices": [[i] for i in range(len(array))],
"dense_shape": (len(array),)
}
return array
def _get_record_type(sparse):
if sparse:
return {
"values": dtypes.int64,
"indices": dtypes.int64,
"dense_shape": dtypes.int64
}
return dtypes.int32
def _get_record_shape(sparse):
if sparse:
return {
"values": tensor_shape.TensorShape([None,]),
"indices": tensor_shape.TensorShape([None, 1]),
"dense_shape": tensor_shape.TensorShape([1,])
}
return tensor_shape.TensorShape([None])
class BucketBySequenceLengthTest(test_base.DatasetTestBase):
def testBucket(self):
boundaries = [10, 20, 30]
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25, 35]
def build_dataset(sparse):
def _generator():
# Produce 1 batch for each bucket
elements = []
for batch_size, length in zip(batch_sizes, lengths):
record_len = length - 1
for _ in range(batch_size):
elements.append([1] * record_len)
record_len = length
random.shuffle(elements)
for el in elements:
yield (_format_record(el, sparse),)
dataset = dataset_ops.Dataset.from_generator(
_generator,
(_get_record_type(sparse),),
(_get_record_shape(sparse),))
if sparse:
dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
return dataset
def _test_bucket_by_padding(no_padding):
dataset = build_dataset(sparse=no_padding)
dataset = dataset.apply(
grouping.bucket_by_sequence_length(
_element_length_fn,
boundaries,
batch_sizes,
no_padding=no_padding))
batch, = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
with self.assertRaises(errors.OutOfRangeError):
sess.run(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
shape = batch.dense_shape if no_padding else batch.shape
batch_size = shape[0]
length = shape[1]
batch_sizes_val.append(batch_size)
lengths_val.append(length)
sum_check = batch.values.sum() if no_padding else batch.sum()
self.assertEqual(sum_check, batch_size * length - 1)
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
self.assertEqual(sorted(lengths), sorted(lengths_val))
for no_padding in (True, False):
_test_bucket_by_padding(no_padding)
def testPadToBoundary(self):
boundaries = [10, 20, 30]
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25]
def element_gen():
# Produce 1 batch for each bucket
elements = []
for batch_size, length in zip(batch_sizes[:-1], lengths):
for _ in range(batch_size):
elements.append([1] * length)
random.shuffle(elements)
for el in elements:
yield (el,)
for _ in range(batch_sizes[-1]):
el = [1] * (boundaries[-1] + 5)
yield (el,)
element_len = lambda el: array_ops.shape(el)[0]
dataset = dataset_ops.Dataset.from_generator(
element_gen, (dtypes.int64,), ([None],)).apply(
grouping.bucket_by_sequence_length(
element_len, boundaries, batch_sizes,
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
with self.assertRaisesOpError("bucket_boundaries"):
sess.run(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
batch_size = batch.shape[0]
length = batch.shape[1]
batch_sizes_val.append(batch_size)
lengths_val.append(length)
batch_sizes = batch_sizes[:-1]
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
sorted(lengths_val))
def testPadToBoundaryNoExtraneousPadding(self):
boundaries = [3, 7, 11]
batch_sizes = [2, 2, 2, 2]
lengths = range(1, 11)
def element_gen():
for length in lengths:
yield ([1] * length,)
element_len = lambda element: array_ops.shape(element)[0]
dataset = dataset_ops.Dataset.from_generator(
element_gen, (dtypes.int64,), ([None],)).apply(
grouping.bucket_by_sequence_length(
element_len, boundaries, batch_sizes,
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
with self.assertRaises(errors.OutOfRangeError):
sess.run(batch)
self.assertAllEqual(batches[0], [[1, 0],
[1, 1]])
self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0]])
self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]])
self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
def testTupleElements(self):
def build_dataset(sparse):
def _generator():
text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
label = [1, 2, 1, 2]
for x, y in zip(text, label):
yield (_format_record(x, sparse), y)
dataset = dataset_ops.Dataset.from_generator(
generator=_generator,
output_types=(_get_record_type(sparse), dtypes.int32),
output_shapes=(_get_record_shape(sparse),
tensor_shape.TensorShape([])))
if sparse:
dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
return dataset
def _test_tuple_elements_by_padding(no_padding):
dataset = build_dataset(sparse=no_padding)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
element_length_func=_element_length_fn,
bucket_batch_sizes=[2, 2, 2],
bucket_boundaries=[0, 8],
no_padding=no_padding))
shapes = dataset.output_shapes
self.assertEqual([None, None], shapes[0].as_list())
self.assertEqual([None], shapes[1].as_list())
for no_padding in (True, False):
_test_tuple_elements_by_padding(no_padding)
def testBucketSparse(self):
"""Tests bucketing of sparse tensors (case where `no_padding` == True).
    The test runs on the following dataset:
[
[0],
[0, 1],
[0, 1, 2]
...
[0, ..., max_len - 1]
]
Sequences are bucketed by length and batched with
`batch_size` < `bucket_size`.
"""
min_len = 0
max_len = 100
batch_size = 7
bucket_size = 10
def _build_dataset():
input_data = [range(i+1) for i in range(min_len, max_len)]
def generator_fn():
for record in input_data:
yield _format_record(record, sparse=True)
dataset = dataset_ops.Dataset.from_generator(
generator=generator_fn,
output_types=_get_record_type(sparse=True))
dataset = dataset.map(_to_sparse_tensor)
return dataset
def _compute_expected_batches():
"""Computes expected batch outputs and stores in a set."""
all_expected_sparse_tensors = set()
for bucket_start_len in range(min_len, max_len, bucket_size):
for batch_offset in range(0, bucket_size, batch_size):
batch_start_len = bucket_start_len + batch_offset
batch_end_len = min(batch_start_len + batch_size,
bucket_start_len + bucket_size)
expected_indices = []
expected_values = []
for length in range(batch_start_len, batch_end_len):
for val in range(length + 1):
expected_indices.append((length - batch_start_len, val))
expected_values.append(val)
expected_sprs_tensor = (tuple(expected_indices),
tuple(expected_values))
all_expected_sparse_tensors.add(expected_sprs_tensor)
return all_expected_sparse_tensors
def _compute_batches(dataset):
"""Computes actual batch outputs of dataset and stores in a set."""
batch = dataset.make_one_shot_iterator().get_next()
all_sparse_tensors = set()
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
while True:
output = sess.run(batch)
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
tuple(output.values))
all_sparse_tensors.add(sprs_tensor)
return all_sparse_tensors
dataset = _build_dataset()
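    # Boundaries 11, 21, ..., 91 bucket the sequences into width-10 groups;
    # with batch_size 7, each bucket yields one batch of 7 and one of 3.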
boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
_element_length_fn,
boundaries,
[batch_size] * (len(boundaries) + 1),
no_padding=True))
batches = _compute_batches(dataset)
expected_batches = _compute_expected_batches()
self.assertEqual(batches, expected_batches)
if __name__ == "__main__":
test.main()

View File

@ -1,824 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSum(self):
reducer = grouping.Reducer(
init_func=lambda _: np.int64(0),
reduce_func=lambda x, y: x + y,
finalize_func=lambda x: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).apply(
grouping.group_by_reducer(lambda x: x % 2, reducer))
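      # The evens 0, 2, ..., 2i-2 sum to (i - 1) * i and the odds
      # 1, 3, ..., 2i-1 sum to i * i, matching the two group results.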
self.checkResults(
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
def testAverage(self):
def reduce_fn(x, y):
return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
x[1] + 1), x[1] + 1
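    # The reducer state is a (running_mean, count) pair; each step folds the
    # next value into the mean and increments the count.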
reducer = grouping.Reducer(
init_func=lambda _: (0.0, 0.0),
reduce_func=reduce_fn,
finalize_func=lambda x, _: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).apply(
grouping.group_by_reducer(
lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
self.checkResults(
dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
def testConcat(self):
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
reducer = grouping.Reducer(
init_func=lambda x: "",
reduce_func=lambda x, y: x + y[0],
finalize_func=lambda x: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensor_slices(components),
dataset_ops.Dataset.range(2 * i))).apply(
grouping.group_by_reducer(lambda x, y: y % 2, reducer))
self.checkResults(
dataset,
shapes=tensor_shape.scalar(),
values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
def testSparseSum(self):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=np.array([[0, 0]]),
values=(i * np.array([1], dtype=np.int64)),
dense_shape=np.array([1, 1]))
reducer = grouping.Reducer(
init_func=lambda _: _sparse(np.int64(0)),
reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
finalize_func=lambda x: x.values[0])
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
self.checkResults(
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
def testChangingStateShape(self):
def reduce_fn(x, _):
# Statically known rank, but dynamic length.
larger_dim = array_ops.concat([x[0], x[0]], 0)
# Statically unknown rank.
larger_rank = array_ops.expand_dims(x[1], 0)
return larger_dim, larger_rank
reducer = grouping.Reducer(
init_func=lambda x: ([0], 1),
reduce_func=reduce_fn,
finalize_func=lambda x, y: (x, y))
for i in range(1, 11):
dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
grouping.group_by_reducer(lambda x: x, reducer))
self.assertEqual([None], dataset.output_shapes[0].as_list())
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testTypeMismatch(self):
reducer = grouping.Reducer(
init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
finalize_func=lambda x: x)
dataset = dataset_ops.Dataset.range(10)
with self.assertRaisesRegexp(
TypeError,
"The element types for the new state must match the initial state."):
dataset.apply(
grouping.group_by_reducer(lambda _: np.int64(0), reducer))
# TODO(b/78665031): Remove once non-scalar keys are supported.
def testInvalidKeyShape(self):
reducer = grouping.Reducer(
init_func=lambda x: np.int64(0),
reduce_func=lambda x, y: x + y,
finalize_func=lambda x: x)
dataset = dataset_ops.Dataset.range(10)
with self.assertRaisesRegexp(
ValueError, "`key_func` must return a single tf.int64 tensor."):
dataset.apply(
grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
# TODO(b/78665031): Remove once non-int64 keys are supported.
def testInvalidKeyType(self):
reducer = grouping.Reducer(
init_func=lambda x: np.int64(0),
reduce_func=lambda x, y: x + y,
finalize_func=lambda x: x)
dataset = dataset_ops.Dataset.range(10)
with self.assertRaisesRegexp(
ValueError, "`key_func` must return a single tf.int64 tensor."):
dataset.apply(
grouping.group_by_reducer(lambda _: "wrong", reducer))
def testTuple(self):
def init_fn(_):
return np.array([], dtype=np.int64), np.int64(0)
def reduce_fn(state, value):
s1, s2 = state
v1, v2 = value
return array_ops.concat([s1, [v1]], 0), s2 + v2
def finalize_fn(s1, s2):
return s1, s2
reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
class GroupByWindowTest(test_base.DatasetTestBase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
.apply(
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
4)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
result = sess.run(get_next)
          self.assertTrue(
              all(x % 2 == 0 for x in result) or
              all(x % 2 == 1 for x in result))
counts.append(result.shape[0])
self.assertEqual(len(components), sum(counts))
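      # The 200 random elements split by parity into two groups of roughly
      # 100 each, so nearly all size-4 windows should form full batches;
      # 24 is a conservative lower bound.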
num_full_batches = len([c for c in counts if c == 4])
self.assertGreaterEqual(num_full_batches, 24)
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
def testImmediateOutput(self):
components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
4)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for _ in range(3):
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).apply(
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
4)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
# The small outputs at the end are deterministically produced in key
# order.
self.assertAllEqual([0, 0, 0], sess.run(get_next))
self.assertAllEqual([1], sess.run(get_next))
def testEmpty(self):
iterator = (
dataset_ops.Dataset.range(4).apply(
grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Window size must be greater than zero, but got 0."):
print(sess.run(get_next))
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
def reduce_func(_, xs):
# Introduce an incorrect padded shape that cannot (currently) be
# detected at graph construction time.
return xs.padded_batch(
4,
padded_shapes=(tensor_shape.TensorShape([]),
constant_op.constant([5], dtype=dtypes.int64) * -1))
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
grouping.group_by_window(lambda x, _: x % 2, reduce_func,
32)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64)
def reduce_func(key, window):
# Apply two different kinds of padding to the input: tight
# padding, and quantized (to a multiple of 10) padding.
return dataset_ops.Dataset.zip((
window.padded_batch(
4, padded_shapes=tensor_shape.TensorShape([None])),
window.padded_batch(
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
))
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
.apply(grouping.group_by_window(
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
tight_result, multiple_of_10_result = sess.run(get_next)
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
self.assertAllEqual(tight_result,
multiple_of_10_result[:, :tight_result.shape[1]])
counts.append(tight_result.shape[0])
self.assertEqual(len(components), sum(counts))
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
class BucketTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
# generic form of padded_batch that pads every component
# dynamically and does not rely on static shape information about
# the arguments.
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket),
window.padded_batch(
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
[None]), tensor_shape.TensorShape([3])))))
def testSingleBucket(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
bucketed_dataset = input_dataset.apply(
grouping.group_by_window(
lambda x, y, z: 0,
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = bucketed_dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
self.assertEqual(0, which_bucket)
expected_scalar_int = np.arange(32, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
for i in range(32):
expected_unk_int64[i, :i] = i
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values[0])
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
def testEvenOddBuckets(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
bucketed_dataset = input_dataset.apply(
grouping.group_by_window(
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = bucketed_dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even, bucketed_values_even = sess.run(get_next)
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
# Count number of bucket_tensors.
self.assertEqual(3, len(bucketed_values_even))
self.assertEqual(3, len(bucketed_values_odd))
# Ensure bucket 0 was used for all minibatch entries.
self.assertAllEqual(0, which_bucket_even)
self.assertAllEqual(1, which_bucket_odd)
      # Test the first bucket outputted, the evens starting at 0
expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
for i in range(0, 32):
expected_unk_int64[i, :2 * i] = 2 * i
expected_vec3_str = np.vstack(
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
# Test the second bucket outputted, the odds starting at 1
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
for i in range(0, 32):
expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
expected_vec3_str = np.vstack(
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
return {
"x": v,
"y": array_ops.fill([v], v),
"z": array_ops.fill([3], string_ops.as_string(v))
}
def _dynamic_pad_fn(bucket, window, _):
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket),
window.padded_batch(
32, {
"x": tensor_shape.TensorShape([]),
"y": tensor_shape.TensorShape([None]),
"z": tensor_shape.TensorShape([3])
})))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
.filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
bucketed_dataset = input_dataset.apply(
grouping.group_by_window(
lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
iterator = bucketed_dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0, bucketed_values_even0 = sess.run(get_next)
which_bucket1, bucketed_values_even1 = sess.run(get_next)
# Ensure that bucket 1 was completely filtered out
self.assertAllEqual(0, which_bucket0)
self.assertAllEqual(0, which_bucket1)
self.assertAllEqual(
np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
def testDynamicWindowSize(self):
components = np.arange(100).astype(np.int64)
# Key fn: even/odd
# Reduce fn: batches of 5
# Window size fn: even=5, odd=10
def window_size_func(key):
window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
return window_sizes[key]
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
None, window_size_func))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
while True:
result = sess.run(get_next)
is_even = all(x % 2 == 0 for x in result)
is_odd = all(x % 2 == 1 for x in result)
self.assertTrue(is_even or is_odd)
expected_batch_size = 5 if is_even else 10
self.assertEqual(expected_batch_size, result.shape[0])
batches += 1
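      # 50 even values in windows of 5 give 10 batches and 50 odd values in
      # windows of 10 give 5 batches, i.e. 15 batches in total.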
self.assertEqual(batches, 15)
def _element_length_fn(x, y=None):
del y
return array_ops.shape(x)[0]
def _to_sparse_tensor(record):
return sparse_tensor.SparseTensor(**record)
def _format_record(array, sparse):
if sparse:
return {
"values": array,
"indices": [[i] for i in range(len(array))],
"dense_shape": (len(array),)
}
return array
def _get_record_type(sparse):
if sparse:
return {
"values": dtypes.int64,
"indices": dtypes.int64,
"dense_shape": dtypes.int64
}
return dtypes.int32
def _get_record_shape(sparse):
if sparse:
return {
"values": tensor_shape.TensorShape([None,]),
"indices": tensor_shape.TensorShape([None, 1]),
"dense_shape": tensor_shape.TensorShape([1,])
}
return tensor_shape.TensorShape([None])
class BucketBySequenceLength(test_base.DatasetTestBase):
def testBucket(self):
boundaries = [10, 20, 30]
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25, 35]
def build_dataset(sparse):
def _generator():
# Produce 1 batch for each bucket
elements = []
for batch_size, length in zip(batch_sizes, lengths):
record_len = length - 1
for _ in range(batch_size):
elements.append([1] * record_len)
record_len = length
random.shuffle(elements)
for el in elements:
yield (_format_record(el, sparse),)
dataset = dataset_ops.Dataset.from_generator(
_generator,
(_get_record_type(sparse),),
(_get_record_shape(sparse),))
if sparse:
dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
return dataset
def _test_bucket_by_padding(no_padding):
dataset = build_dataset(sparse=no_padding)
dataset = dataset.apply(
grouping.bucket_by_sequence_length(
_element_length_fn,
boundaries,
batch_sizes,
no_padding=no_padding))
batch, = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
with self.assertRaises(errors.OutOfRangeError):
sess.run(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
shape = batch.dense_shape if no_padding else batch.shape
batch_size = shape[0]
length = shape[1]
batch_sizes_val.append(batch_size)
lengths_val.append(length)
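        # The generator makes one record in each bucket one element shorter
        # than the rest, so every batch of ones sums to
        # batch_size * length - 1.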
sum_check = batch.values.sum() if no_padding else batch.sum()
self.assertEqual(sum_check, batch_size * length - 1)
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
self.assertEqual(sorted(lengths), sorted(lengths_val))
for no_padding in (True, False):
_test_bucket_by_padding(no_padding)
def testPadToBoundary(self):
boundaries = [10, 20, 30]
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25]
def element_gen():
# Produce 1 batch for each bucket
elements = []
for batch_size, length in zip(batch_sizes[:-1], lengths):
for _ in range(batch_size):
elements.append([1] * length)
random.shuffle(elements)
for el in elements:
yield (el,)
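      # Produce one batch of elements longer than the last boundary; with
      # pad_to_bucket_boundary=True these cannot be padded to a bucket
      # boundary, which triggers the "bucket_boundaries" op error asserted
      # below.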
for _ in range(batch_sizes[-1]):
el = [1] * (boundaries[-1] + 5)
yield (el,)
element_len = lambda el: array_ops.shape(el)[0]
dataset = dataset_ops.Dataset.from_generator(
element_gen, (dtypes.int64,), ([None],)).apply(
grouping.bucket_by_sequence_length(
element_len, boundaries, batch_sizes,
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
with self.assertRaisesOpError("bucket_boundaries"):
sess.run(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
batch_size = batch.shape[0]
length = batch.shape[1]
batch_sizes_val.append(batch_size)
lengths_val.append(length)
batch_sizes = batch_sizes[:-1]
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
sorted(lengths_val))
def testPadToBoundaryNoExtraneousPadding(self):
boundaries = [3, 7, 11]
batch_sizes = [2, 2, 2, 2]
lengths = range(1, 11)
def element_gen():
for length in lengths:
yield ([1] * length,)
element_len = lambda element: array_ops.shape(element)[0]
dataset = dataset_ops.Dataset.from_generator(
element_gen, (dtypes.int64,), ([None],)).apply(
grouping.bucket_by_sequence_length(
element_len, boundaries, batch_sizes,
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
with self.assertRaises(errors.OutOfRangeError):
sess.run(batch)
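      # With boundaries [3, 7, 11], batches are padded only to
      # boundary - 1 (2, 6, or 10 columns), never beyond their bucket.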
self.assertAllEqual(batches[0], [[1, 0],
[1, 1]])
self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0]])
self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]])
self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
def testTupleElements(self):
def build_dataset(sparse):
def _generator():
text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
label = [1, 2, 1, 2]
for x, y in zip(text, label):
yield (_format_record(x, sparse), y)
dataset = dataset_ops.Dataset.from_generator(
generator=_generator,
output_types=(_get_record_type(sparse), dtypes.int32),
output_shapes=(_get_record_shape(sparse),
tensor_shape.TensorShape([])))
if sparse:
dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
return dataset
def _test_tuple_elements_by_padding(no_padding):
dataset = build_dataset(sparse=no_padding)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
element_length_func=_element_length_fn,
bucket_batch_sizes=[2, 2, 2],
bucket_boundaries=[0, 8],
no_padding=no_padding))
shapes = dataset.output_shapes
self.assertEqual([None, None], shapes[0].as_list())
self.assertEqual([None], shapes[1].as_list())
for no_padding in (True, False):
_test_tuple_elements_by_padding(no_padding)
def testBucketSparse(self):
"""Tests bucketing of sparse tensors (case where `no_padding` == True).
    The test runs on the following dataset:
[
[0],
[0, 1],
[0, 1, 2]
...
[0, ..., max_len - 1]
]
Sequences are bucketed by length and batched with
`batch_size` < `bucket_size`.
"""
min_len = 0
max_len = 100
batch_size = 7
bucket_size = 10
def _build_dataset():
input_data = [range(i+1) for i in range(min_len, max_len)]
def generator_fn():
for record in input_data:
yield _format_record(record, sparse=True)
dataset = dataset_ops.Dataset.from_generator(
generator=generator_fn,
output_types=_get_record_type(sparse=True))
dataset = dataset.map(_to_sparse_tensor)
return dataset
def _compute_expected_batches():
"""Computes expected batch outputs and stores in a set."""
all_expected_sparse_tensors = set()
for bucket_start_len in range(min_len, max_len, bucket_size):
for batch_offset in range(0, bucket_size, batch_size):
batch_start_len = bucket_start_len + batch_offset
batch_end_len = min(batch_start_len + batch_size,
bucket_start_len + bucket_size)
expected_indices = []
expected_values = []
for length in range(batch_start_len, batch_end_len):
for val in range(length + 1):
expected_indices.append((length - batch_start_len, val))
expected_values.append(val)
expected_sprs_tensor = (tuple(expected_indices),
tuple(expected_values))
all_expected_sparse_tensors.add(expected_sprs_tensor)
return all_expected_sparse_tensors
def _compute_batches(dataset):
"""Computes actual batch outputs of dataset and stores in a set."""
batch = dataset.make_one_shot_iterator().get_next()
all_sparse_tensors = set()
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
while True:
output = sess.run(batch)
sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
tuple(output.values))
all_sparse_tensors.add(sprs_tensor)
return all_sparse_tensors
dataset = _build_dataset()
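    # Boundaries 11, 21, ..., 91 bucket the sequences into width-10 groups;
    # with batch_size 7, each bucket yields one batch of 7 and one of 3.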
boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
_element_length_fn,
boundaries,
[batch_size] * (len(boundaries) + 1),
no_padding=True))
batches = _compute_batches(dataset)
expected_batches = _compute_expected_batches()
self.assertEqual(batches, expected_batches)
if __name__ == "__main__":
test.main()

View File

@ -12,440 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for prefetching_ops."""
"""Tests for `tf.data.experimental.copy_to_device()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
def _create_ds_and_iterator(self, device0, initializable=False):
def gen():
for i in range(1, 10):
yield [float(i)]
if i == 6:
self._event.set()
with ops.device(device0):
ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
if initializable:
ds_iterator = ds.make_initializable_iterator()
else:
ds_iterator = ds.make_one_shot_iterator()
return (ds, ds_iterator)
def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
ds_iterator_handle = ds_iterator.string_handle()
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, ds.output_types, ds.output_shapes)
return remote_iterator.get_next()
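    # The buffering resource below lives on `device1` and repeatedly invokes
    # `_remote_fn` on `device0`, prefetching up to `buffer_size` elements
    # across the device boundary.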
target = constant_op.constant(device0)
with ops.device(device1):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_remote_fn,
output_types=[dtypes.float32],
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
shared_name=buffer_name)
with ops.device(device1):
prefetch_op = prefetching_ops.function_buffering_resource_get_next(
function_buffer_resource=buffer_resource_handle,
output_types=[dtypes.float32])
reset_op = prefetching_ops.function_buffering_resource_reset(
function_buffer_resource=buffer_resource_handle)
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
return (prefetch_op, reset_op, destroy_op)
def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
device0, device1)
with self.test_session(config=worker_config) as sess:
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
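      # The generator sets the event after producing element 6, so waiting
      # here guarantees the prefetch buffer has run ahead of the four
      # elements consumed so far.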
self._event.wait()
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
sess.run(destroy_op)
def testSameDeviceCPU(self):
self._prefetch_fn_helper_one_shot("same_device_cpu",
"/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/cpu:0")
def testDifferentDeviceCPU(self):
self._prefetch_fn_helper_one_shot("diff_device_cpu",
"/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/cpu:1")
def testDifferentDeviceCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
self._prefetch_fn_helper_one_shot("cpu_gpu",
"/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/gpu:0")
def testReinitialization(self):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
prefetch_op, reset_op, destroy_op = self._create_ops(
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
      # Let's reset the function buffering resource and reinitialize the
      # iterator. We should be able to go through the sequence again.
self._event.clear()
sess.run(reset_op)
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
sess.run(destroy_op)
def testReinitializationOutOfRange(self):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
prefetch_op, reset_op, destroy_op = self._create_ops(
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
sess.run(ds_iterator.initializer)
for i in range(1, 10):
elem = sess.run(prefetch_op)
self.assertEqual(elem, [float(i)])
      # Try fetching twice after it's over to test end-of-sequence behavior.
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
# Now reset everything and try it out again.
self._event.clear()
sess.run(reset_op)
sess.run(ds_iterator.initializer)
for i in range(1, 10):
elem = sess.run(prefetch_op)
self.assertEqual(elem, [float(i)])
      # Try fetching twice after it's over to test end-of-sequence behavior.
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
sess.run(destroy_op)
def testStringsGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/gpu:0"
ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
ds_iterator = ds.make_one_shot_iterator()
ds_iterator_handle = ds_iterator.string_handle()
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, ds.output_types, ds.output_shapes)
return remote_iterator.get_next()
target = constant_op.constant(device0)
with ops.device(device1):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_remote_fn,
output_types=[dtypes.string],
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
shared_name="strings")
with ops.device(device1):
prefetch_op = prefetching_ops.function_buffering_resource_get_next(
function_buffer_resource=buffer_resource_handle,
output_types=[dtypes.string])
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
sess.run(destroy_op)
class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device(
"/job:localhost/replica:0/task:0/device:CPU:0"))
    # NOTE(mrry): This device block creates the "host" dataset and iterator on
    # /cpu:0; in this test the prefetch target is the same device, so no
    # cross-device copy is exercised. In typical use this block would not be
    # necessary, because the GPU device would not support any of the
    # dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element["a"].dtype)
self.assertEqual([], next_element["a"].shape)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchSparseTensorsToDevice(self):
def make_tensor(i):
return sparse_tensor.SparseTensorValue(
indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_initializable_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):

View File

@ -0,0 +1,51 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.Counter`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
class CounterTest(test_base.DatasetTestBase):
def testCounter(self):
"""Test dataset construction using `count`."""
iterator = (counter.Counter(start=3, step=4)
.make_one_shot_iterator())
get_next = iterator.get_next()
self.assertEqual([], get_next.shape.as_list())
self.assertEqual(dtypes.int64, get_next.dtype)
negative_iterator = (counter.Counter(start=0, step=-1)
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()
with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
self.assertEqual(0, sess.run(negative_get_next))
self.assertEqual(-1, sess.run(negative_get_next))
self.assertEqual(-2, sess.run(negative_get_next))
if __name__ == "__main__":
test.main()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CsvDatasetOp."""
"""Tests for `tf.data.experimental.CsvDataset`."""
from __future__ import absolute_import
from __future__ import division
@ -44,7 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class CsvDatasetOpTest(test_base.DatasetTestBase):
class CsvDatasetTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []

View File

@ -1,692 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing serializable datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
def remove_variants(get_next_op):
# TODO(b/72408568): Remove this once session.run can get
# variant tensors.
"""Remove variants from a nest structure, so sess.run will execute."""
def _remove_variant(x):
if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
return ()
else:
return x
return nest.map_structure(_remove_variant, get_next_op)
class DatasetSerializationTestBase(test.TestCase):
"""Base class for testing serializable datasets."""
def tearDown(self):
self._delete_ckpt()
# TODO(b/72657739): Remove sparse_tensor argument, which is to test the
# (deprecated) saveable `SparseTensorSliceDataset`, once the API
# `from_sparse_tensor_slices()`and related tests are deleted.
def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
"""Runs the core tests.
Args:
ds_fn1: 0-argument function that returns a Dataset.
ds_fn2: 0-argument function that returns a Dataset different from
ds_fn1. If None, verify_restore_in_modified_graph test is not run.
num_outputs: Total number of outputs expected from this Dataset.
sparse_tensors: Whether dataset is built from SparseTensor(s).
Raises:
AssertionError if any test fails.
"""
self.verify_unused_iterator(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
self.verify_fully_used_iterator(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
self.verify_exhausted_iterator(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
self.verify_init_before_restore(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
self.verify_multiple_breaks(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
self.verify_reset_restored_iterator(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
self.verify_restore_in_empty_graph(
ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
if ds_fn2:
self.verify_restore_in_modified_graph(
ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
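  # A minimal, hypothetical usage sketch: a subclass supplies zero-argument
  # dataset constructors and the expected output count, e.g.
  #
  #   class RangeDatasetSerializationTest(DatasetSerializationTestBase):
  #
  #     def testCore(self):
  #       self.run_core_tests(lambda: dataset_ops.Dataset.range(10),
  #                           lambda: dataset_ops.Dataset.range(5),
  #                           num_outputs=10)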
def verify_unused_iterator(self,
ds_fn,
num_outputs,
sparse_tensors=False,
verify_exhausted=True):
"""Verifies that saving and restoring an unused iterator works.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
self.verify_run_with_breaks(
ds_fn, [0],
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
def verify_fully_used_iterator(self, ds_fn, num_outputs,
sparse_tensors=False):
"""Verifies that saving and restoring a fully used iterator works.
Note that this only checks saving and restoring an iterator from which
`num_outputs` items have been produced but does not check for an
exhausted iterator, i.e., one from which an OutOfRange error has been
returned.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
sparse_tensors: See `run_core_tests`.
Raises:
AssertionError if test fails.
"""
self.verify_run_with_breaks(
ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
"""Verifies that saving and restoring an exhausted iterator works.
An exhausted iterator is one which has returned an OutOfRange error.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
sparse_tensors: See `run_core_tests`.
Raises:
AssertionError if any test fails.
"""
self.gen_outputs(
ds_fn, [],
num_outputs,
verify_exhausted=True,
sparse_tensors=sparse_tensors)
actual = self.gen_outputs(
ds_fn, [],
0,
ckpt_saved=True,
verify_exhausted=True,
sparse_tensors=sparse_tensors)
self.assertEqual(len(actual), 0)
def verify_init_before_restore(self,
ds_fn,
num_outputs,
sparse_tensors=False,
verify_exhausted=True):
"""Verifies that restoring into an already initialized iterator works.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
self.verify_run_with_breaks(
ds_fn,
self.gen_break_points(num_outputs),
num_outputs,
init_before_restore=True,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
def verify_multiple_breaks(self,
ds_fn,
num_outputs,
num_breaks=10,
sparse_tensors=False,
verify_exhausted=True):
"""Attempts to save/restore at multiple break points.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
num_breaks: The number of break points. These are uniformly spread in
[0, num_outputs] both inclusive.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
self.verify_run_with_breaks(
ds_fn,
self.gen_break_points(num_outputs, num_breaks),
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
def verify_reset_restored_iterator(self,
ds_fn,
num_outputs,
break_point=None,
sparse_tensors=False,
verify_exhausted=True):
"""Attempts to re-initialize a restored iterator.
This is useful when restoring a training checkpoint during validation.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
break_point = num_outputs // 2 if not break_point else break_point
# Collect ground truth containing all outputs.
expected = self.gen_outputs(
ds_fn, [],
num_outputs,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
# Skip some items and save checkpoint.
self.gen_outputs(
ds_fn, [],
break_point,
sparse_tensors=sparse_tensors,
verify_exhausted=False)
actual = []
# Restore from checkpoint and then run init_op.
with ops.Graph().as_default() as g:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
self._restore(saver, sess)
self._initialize(init_op, sess)
for _ in range(num_outputs):
actual.append(sess.run(get_next_op))
if verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.match(expected, actual)
def verify_restore_in_modified_graph(self,
ds_fn1,
ds_fn2,
num_outputs,
break_point=None,
sparse_tensors=False,
verify_exhausted=True):
"""Attempts to restore an iterator in a modified graph.
Builds an input pipeline using ds_fn1, runs it for `break_point` steps
and saves a checkpoint. Then builds a new graph using ds_fn2, restores
the checkpoint from ds_fn1 and verifies that the restore is successful.
Args:
ds_fn1: See `run_core_tests`.
ds_fn2: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
break_point = num_outputs // 2 if not break_point else break_point
# Skip `break_point` items and store the remaining produced from ds_fn1
# in `expected`.
self.gen_outputs(
ds_fn1, [],
break_point,
sparse_tensors=sparse_tensors,
verify_exhausted=False)
expected = self.gen_outputs(
ds_fn1, [],
num_outputs - break_point,
ckpt_saved=True,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
# Generate `break_point` items from ds_fn1 and save checkpoint.
self.gen_outputs(
ds_fn1, [],
break_point,
sparse_tensors=sparse_tensors,
verify_exhausted=False)
actual = []
# Build graph for ds_fn2 but load checkpoint for ds_fn1.
with ops.Graph().as_default() as g:
_, get_next_op, saver = self._build_graph(
ds_fn2, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
if verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.match(expected, actual)
def verify_restore_in_empty_graph(self,
ds_fn,
num_outputs,
break_point=None,
sparse_tensors=False,
verify_exhausted=True):
"""Attempts to restore an iterator in an empty graph.
Builds an input pipeline using ds_fn, runs it for `break_point` steps
and saves a checkpoint. Then builds a new empty graph, restores
the checkpoint from ds_fn and verifies that the restore is successful.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
break_point: Break point. Optional. Defaults to num_outputs/2.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
break_point = num_outputs // 2 if not break_point else break_point
# Skip `break_point` items and store the remaining produced from ds_fn
# in `expected`.
self.gen_outputs(
ds_fn, [],
break_point,
sparse_tensors=sparse_tensors,
verify_exhausted=False)
expected = self.gen_outputs(
ds_fn, [],
num_outputs - break_point,
ckpt_saved=True,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
# Generate `break_point` items from ds_fn and save checkpoint.
self.gen_outputs(
ds_fn, [],
break_point,
sparse_tensors=sparse_tensors,
verify_exhausted=False)
actual = []
# Build an empty graph but load checkpoint for ds_fn.
with ops.Graph().as_default() as g:
get_next_op, saver = self._build_empty_graph(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
self._restore(saver, sess)
for _ in range(num_outputs - break_point):
actual.append(sess.run(get_next_op))
if verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
self.match(expected, actual)
def verify_error_on_save(self,
ds_fn,
num_outputs,
error,
break_point=None,
sparse_tensors=False):
"""Attempts to save a non-saveable iterator.
Args:
ds_fn: See `run_core_tests`.
num_outputs: See `run_core_tests`.
error: The error expected to be raised when saving the iterator.
break_point: Break point. Optional. Defaults to num_outputs/2.
sparse_tensors: See `run_core_tests`.
Raises:
AssertionError if any test fails.
"""
break_point = num_outputs // 2 if not break_point else break_point
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
self._initialize(init_op, sess)
for _ in range(break_point):
sess.run(get_next_op)
with self.assertRaises(error):
self._save(sess, saver)
def verify_run_with_breaks(self,
ds_fn,
break_points,
num_outputs,
init_before_restore=False,
sparse_tensors=False,
verify_exhausted=True):
"""Verifies that ds_fn() produces the same outputs with and without breaks.
1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
*without* stopping at break points.
2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
with stopping at break points.
Then deep-matches the outputs from steps 1 and 2.
Args:
ds_fn: See `gen_outputs`.
break_points: See `gen_outputs`.
num_outputs: See `gen_outputs`.
init_before_restore: See `gen_outputs`.
sparse_tensors: See `run_core_tests`.
verify_exhausted: See `gen_outputs`.
Raises:
AssertionError if any test fails.
"""
expected = self.gen_outputs(
ds_fn, [],
num_outputs,
init_before_restore=init_before_restore,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
actual = self.gen_outputs(
ds_fn,
break_points,
num_outputs,
init_before_restore=init_before_restore,
sparse_tensors=sparse_tensors,
verify_exhausted=verify_exhausted)
self.match(expected, actual)
def gen_outputs(self,
ds_fn,
break_points,
num_outputs,
ckpt_saved=False,
init_before_restore=False,
sparse_tensors=False,
verify_exhausted=True,
save_checkpoint_at_end=True):
"""Generates elements from input dataset while stopping at break points.
Produces `num_outputs` outputs and saves the state of the iterator in the
Saver checkpoint.
Args:
ds_fn: 0-argument function that returns the dataset.
break_points: A list of integers. For each `break_point` in
`break_points`, outputs are produced until `break_point` items have been
generated, at which point the iterator state is checkpointed. The
current graph and session are then destroyed, and a new graph and
session produce outputs until the next break point or until
`num_outputs` elements have been produced. Each `break_point` must be
<= `num_outputs`.
num_outputs: The total number of outputs to produce from the iterator.
ckpt_saved: Whether a checkpoint already exists. If False, we build the
graph from ds_fn.
init_before_restore: Whether init should be called before saver.restore.
This is just so that we can verify that restoring an already initialized
iterator works.
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
save_checkpoint_at_end: Whether to save a checkpoint after producing all
outputs. If False, checkpoints are saved at each break point but not at
the end. Note that checkpoints overwrite each other, so only a single
checkpoint is ever available. Defaults to True.
Returns:
A list of `num_outputs` items.
"""
outputs = []
def get_ops():
if ckpt_saved:
saver = self._import_meta_graph()
init_op, get_next_op = self._get_iterator_ops_from_collection(
ds_fn, sparse_tensors=sparse_tensors)
else:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
return init_op, get_next_op, saver
for i in range(len(break_points) + 1):
with ops.Graph().as_default() as g:
init_op, get_next_op, saver = get_ops()
get_next_op = remove_variants(get_next_op)
with self.session(graph=g) as sess:
if ckpt_saved:
if init_before_restore:
self._initialize(init_op, sess)
self._restore(saver, sess)
else:
self._initialize(init_op, sess)
start = break_points[i - 1] if i > 0 else 0
end = break_points[i] if i < len(break_points) else num_outputs
num_iters = end - start
for _ in range(num_iters):
outputs.append(sess.run(get_next_op))
if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
if save_checkpoint_at_end or i < len(break_points):
self._save(sess, saver)
ckpt_saved = True
return outputs
def match(self, expected, actual):
"""Matches nested structures.
Recursively matches shape and values of `expected` and `actual`.
Handles scalars, numpy arrays, and other Python sequence containers,
e.g. list and dict.
Args:
expected: Nested structure 1.
actual: Nested structure 2.
Raises:
AssertionError if matching fails.
"""
if isinstance(expected, np.ndarray):
expected = expected.tolist()
if isinstance(actual, np.ndarray):
actual = actual.tolist()
self.assertEqual(type(expected), type(actual))
if nest.is_sequence(expected):
self.assertEqual(len(expected), len(actual))
if isinstance(expected, dict):
for key1, key2 in zip(sorted(expected), sorted(actual)):
self.assertEqual(key1, key2)
self.match(expected[key1], actual[key2])
else:
for item1, item2 in zip(expected, actual):
self.match(item1, item2)
else:
self.assertEqual(expected, actual)
def does_not_match(self, expected, actual):
with self.assertRaises(AssertionError):
self.match(expected, actual)
def gen_break_points(self, num_outputs, num_samples=10):
"""Generates `num_samples` breaks points in [0, num_outputs]."""
return np.linspace(0, num_outputs, num_samples, dtype=int)
def _build_graph(self, ds_fn, sparse_tensors=False):
iterator = ds_fn().make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
init_op = iterator.initializer
if sparse_tensors:
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
else:
get_next = iterator.get_next()
self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
sparse_tensors)
saver = saver_lib.Saver(allow_empty=True)
return init_op, get_next, saver
def _build_empty_graph(self, ds_fn, sparse_tensors=False):
iterator = iterator_ops.Iterator.from_structure(
self._get_output_types(ds_fn),
output_shapes=self._get_output_shapes(ds_fn),
output_classes=self._get_output_classes(ds_fn))
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
if sparse_tensors:
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
else:
get_next = iterator.get_next()
saver = saver_lib.Saver(allow_empty=True)
return get_next, saver
def _add_iterator_ops_to_collection(self,
init_op,
get_next,
ds_fn,
sparse_tensors=False):
ops.add_to_collection("iterator_ops", init_op)
# `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
# do not support tuples we flatten the tensors and restore the shape in
# `_get_iterator_ops_from_collection`.
if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
ops.add_to_collection("iterator_ops", get_next.indices)
ops.add_to_collection("iterator_ops", get_next.values)
ops.add_to_collection("iterator_ops", get_next.dense_shape)
return
get_next_list = nest.flatten(get_next)
for i, output_class in enumerate(
nest.flatten(self._get_output_classes(ds_fn))):
if output_class is sparse_tensor.SparseTensor:
ops.add_to_collection("iterator_ops", get_next_list[i].indices)
ops.add_to_collection("iterator_ops", get_next_list[i].values)
ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
else:
ops.add_to_collection("iterator_ops", get_next_list[i])
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
all_ops = ops.get_collection("iterator_ops")
if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
init_op, indices, values, dense_shape = all_ops
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
get_next_list = []
i = 1
for output_class in nest.flatten(self._get_output_classes(ds_fn)):
if output_class is sparse_tensor.SparseTensor:
indices, values, dense_shape = all_ops[i:i + 3]
i += 3
get_next_list.append(
sparse_tensor.SparseTensor(indices, values, dense_shape))
else:
get_next_list.append(all_ops[i])
i += 1
return all_ops[0], nest.pack_sequence_as(
self._get_output_types(ds_fn), get_next_list)
def _get_output_types(self, ds_fn):
with ops.Graph().as_default():
return ds_fn().output_types
def _get_output_shapes(self, ds_fn):
with ops.Graph().as_default():
return ds_fn().output_shapes
def _get_output_classes(self, ds_fn):
with ops.Graph().as_default():
return ds_fn().output_classes
def _ckpt_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())
def _restore(self, saver, sess):
sess.run(lookup_ops.tables_initializer())
saver.restore(sess, self._latest_ckpt())
def _initialize(self, init_op, sess):
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
sess.run(init_op)
def _import_meta_graph(self):
meta_file_path = self._ckpt_path() + ".meta"
return saver_lib.import_meta_graph(meta_file_path)
def _delete_ckpt(self):
# Remove all checkpoint files.
prefix = self._ckpt_path()
pattern = prefix + "*"
files = gfile.Glob(pattern)
# `map` is lazy in Python 3, so iterate explicitly to delete each file.
for f in files:
  gfile.Remove(f)

View File

@@ -0,0 +1,124 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.dense_to_sparse_batch()."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class DenseToSparseBatchTest(test_base.DatasetTestBase):
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
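# Each value x below is mapped to a length-x vector filled with x;
# dense_to_sparse_batch(4, [12]) then packs four such ragged rows into a
# single SparseTensor with dense_shape [4, 12] (fewer rows in the final
# batch), which the assertions below verify element by element.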
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x], x)).apply(
batching.dense_to_sparse_batch(4, [12]))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
self.assertAllEqual([[i, j]
for i, c in enumerate(components[start:start + 4])
for j in range(c)], results.indices)
self.assertAllEqual(
[c for c in components[start:start + 4] for _ in range(c)],
results.values)
self.assertAllEqual([min(4,
len(components) - start), 12],
results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x, x], x)).apply(
batching.dense_to_sparse_batch(
4, [5, None])).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
self.assertAllEqual([[i, j, z]
for i, c in enumerate(components[start:start + 4])
for j in range(c)
for z in range(c)], results.indices)
self.assertAllEqual([
c
for c in components[start:start + 4] for _ in range(c)
for _ in range(c)
], results.values)
self.assertAllEqual([
min(4,
len(components) - start), 5,
np.max(components[start:start + 4])
], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
iterator = (
dataset_ops.Dataset.from_tensors(input_tensor).apply(
batching.dense_to_sparse_batch(4, [12]))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"incompatible with the row shape"):
sess.run(get_next)
# Initialize with an input tensor that is larger than `row_shape`.
sess.run(init_op, feed_dict={input_tensor: range(13)})
with self.assertRaisesRegexp(errors.DataLossError,
"larger than the row shape"):
sess.run(get_next)
if __name__ == "__main__":
test.main()

View File

@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test RangeDataset."""
"""Tests for `tf.data.experimental.enumerate_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.experimental.ops import enumerate_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -28,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
class RangeDatasetTest(test_base.DatasetTestBase):
class EnumerateDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
@@ -52,27 +51,6 @@ class RangeDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testCounter(self):
"""Test dataset construction using `count`."""
iterator = (counter.Counter(start=3, step=4)
.make_one_shot_iterator())
get_next = iterator.get_next()
self.assertEqual([], get_next.shape.as_list())
self.assertEqual(dtypes.int64, get_next.dtype)
negative_iterator = (counter.Counter(start=0, step=-1)
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()
with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
self.assertEqual(0, sess.run(negative_get_next))
self.assertEqual(-1, sess.run(negative_get_next))
self.assertEqual(-2, sess.run(negative_get_next))
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,247 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the private `FunctionBufferingResource` used in prefetching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class FunctionBufferingResourceTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
def _create_ds_and_iterator(self, device0, initializable=False):
def gen():
for i in range(1, 10):
yield [float(i)]
if i == 6:
self._event.set()
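# Setting the event once the generator reaches i == 6 lets tests wait
# until the producer has run ahead of the consumer, which indicates that
# elements were indeed prefetched into the buffer.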
with ops.device(device0):
ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
if initializable:
ds_iterator = ds.make_initializable_iterator()
else:
ds_iterator = ds.make_one_shot_iterator()
return (ds, ds_iterator)
def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
ds_iterator_handle = ds_iterator.string_handle()
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, ds.output_types, ds.output_shapes)
return remote_iterator.get_next()
target = constant_op.constant(device0)
with ops.device(device1):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_remote_fn,
output_types=[dtypes.float32],
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
shared_name=buffer_name)
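# The buffering resource lives on device1 but pulls elements from the
# iterator on device0 via _remote_fn, keeping up to buffer_size=3
# elements buffered ahead of the consumer.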
with ops.device(device1):
prefetch_op = prefetching_ops.function_buffering_resource_get_next(
function_buffer_resource=buffer_resource_handle,
output_types=[dtypes.float32])
reset_op = prefetching_ops.function_buffering_resource_reset(
function_buffer_resource=buffer_resource_handle)
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
return (prefetch_op, reset_op, destroy_op)
def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
device0, device1)
with self.test_session(config=worker_config) as sess:
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
sess.run(destroy_op)
def testSameDeviceCPU(self):
self._prefetch_fn_helper_one_shot("same_device_cpu",
"/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/cpu:0")
def testDifferentDeviceCPU(self):
self._prefetch_fn_helper_one_shot("diff_device_cpu",
"/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/cpu:1")
def testDifferentDeviceCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
self._prefetch_fn_helper_one_shot("cpu_gpu",
"/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/gpu:0")
def testReinitialization(self):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
prefetch_op, reset_op, destroy_op = self._create_ops(
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
# Reset the function buffering resource and reinitialize the iterator.
# We should be able to go through the whole sequence again.
self._event.clear()
sess.run(reset_op)
sess.run(ds_iterator.initializer)
elem = sess.run(prefetch_op)
self.assertEqual(elem, [1.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [2.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [3.0])
elem = sess.run(prefetch_op)
self.assertEqual(elem, [4.0])
self._event.wait()
elem = sess.run(prefetch_op)
self.assertEqual(elem, [5.0])
sess.run(destroy_op)
def testReinitializationOutOfRange(self):
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
prefetch_op, reset_op, destroy_op = self._create_ops(
ds, ds_iterator, "reinit", device0, device1)
with self.test_session(config=worker_config) as sess:
sess.run(ds_iterator.initializer)
for i in range(1, 10):
elem = sess.run(prefetch_op)
self.assertEqual(elem, [float(i)])
# Try fetching twice after the iterator is exhausted to test end-of-sequence behavior.
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
# Now reset everything and try it out again.
self._event.clear()
sess.run(reset_op)
sess.run(ds_iterator.initializer)
for i in range(1, 10):
elem = sess.run(prefetch_op)
self.assertEqual(elem, [float(i)])
# Try fetching twice after the iterator is exhausted to test end-of-sequence behavior.
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
sess.run(destroy_op)
def testStringsGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/gpu:0"
ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
ds_iterator = ds.make_one_shot_iterator()
ds_iterator_handle = ds_iterator.string_handle()
@function.Defun(dtypes.string)
def _remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
h, ds.output_types, ds.output_shapes)
return remote_iterator.get_next()
target = constant_op.constant(device0)
with ops.device(device1):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_remote_fn,
output_types=[dtypes.string],
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
shared_name="strings")
with ops.device(device1):
prefetch_op = prefetching_ops.function_buffering_resource_get_next(
function_buffer_resource=buffer_resource_handle,
output_types=[dtypes.string])
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(prefetch_op)
sess.run(destroy_op)
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,199 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.group_by_reducer()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSum(self):
reducer = grouping.Reducer(
init_func=lambda _: np.int64(0),
reduce_func=lambda x, y: x + y,
finalize_func=lambda x: x)
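# Grouping range(2 * i) by parity: the evens 0, 2, ..., 2i-2 sum to
# i * (i - 1) and the odds 1, 3, ..., 2i-1 sum to i * i, which are the
# two values asserted below.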
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).apply(
grouping.group_by_reducer(lambda x: x % 2, reducer))
self.checkResults(
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
def testAverage(self):
def reduce_fn(x, y):
return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
x[1] + 1), x[1] + 1
reducer = grouping.Reducer(
init_func=lambda _: (0.0, 0.0),
reduce_func=reduce_fn,
finalize_func=lambda x, _: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).apply(
grouping.group_by_reducer(
lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
self.checkResults(
dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
def testConcat(self):
components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
reducer = grouping.Reducer(
init_func=lambda x: "",
reduce_func=lambda x, y: x + y[0],
finalize_func=lambda x: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensor_slices(components),
dataset_ops.Dataset.range(2 * i))).apply(
grouping.group_by_reducer(lambda x, y: y % 2, reducer))
self.checkResults(
dataset,
shapes=tensor_shape.scalar(),
values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
def testSparseSum(self):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=np.array([[0, 0]]),
values=(i * np.array([1], dtype=np.int64)),
dense_shape=np.array([1, 1]))
reducer = grouping.Reducer(
init_func=lambda _: _sparse(np.int64(0)),
reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
finalize_func=lambda x: x.values[0])
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
self.checkResults(
dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
def testChangingStateShape(self):
def reduce_fn(x, _):
# Statically known rank, but dynamic length.
larger_dim = array_ops.concat([x[0], x[0]], 0)
# Statically unknown rank.
larger_rank = array_ops.expand_dims(x[1], 0)
return larger_dim, larger_rank
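# After i reduce steps the vector state doubles each time to length 2**i,
# and the scalar state gains one dimension per step, ending at rank i;
# both are asserted below.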
reducer = grouping.Reducer(
init_func=lambda x: ([0], 1),
reduce_func=reduce_fn,
finalize_func=lambda x, y: (x, y))
for i in range(1, 11):
dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
grouping.group_by_reducer(lambda x: x, reducer))
self.assertEqual([None], dataset.output_shapes[0].as_list())
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testTypeMismatch(self):
reducer = grouping.Reducer(
init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
finalize_func=lambda x: x)
dataset = dataset_ops.Dataset.range(10)
with self.assertRaisesRegexp(
TypeError,
"The element types for the new state must match the initial state."):
dataset.apply(
grouping.group_by_reducer(lambda _: np.int64(0), reducer))
# TODO(b/78665031): Remove once non-scalar keys are supported.
def testInvalidKeyShape(self):
reducer = grouping.Reducer(
init_func=lambda x: np.int64(0),
reduce_func=lambda x, y: x + y,
finalize_func=lambda x: x)
dataset = dataset_ops.Dataset.range(10)
with self.assertRaisesRegexp(
ValueError, "`key_func` must return a single tf.int64 tensor."):
dataset.apply(
grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
# TODO(b/78665031): Remove once non-int64 keys are supported.
def testInvalidKeyType(self):
reducer = grouping.Reducer(
init_func=lambda x: np.int64(0),
reduce_func=lambda x, y: x + y,
finalize_func=lambda x: x)
dataset = dataset_ops.Dataset.range(10)
with self.assertRaisesRegexp(
ValueError, "`key_func` must return a single tf.int64 tensor."):
dataset.apply(
grouping.group_by_reducer(lambda _: "wrong", reducer))
def testTuple(self):
def init_fn(_):
return np.array([], dtype=np.int64), np.int64(0)
def reduce_fn(state, value):
s1, s2 = state
v1, v2 = value
return array_ops.concat([s1, [v1]], 0), s2 + v2
def finalize_fn(s1, s2):
return s1, s2
reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
dataset = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,367 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.group_by_window()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though they should be made
# to use a different batch size per key.
class GroupByWindowTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
# generic form of padded_batch that pads every component
# dynamically and does not rely on static shape information about
# the arguments.
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket),
window.padded_batch(
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
[None]), tensor_shape.TensorShape([3])))))
def testSingleBucket(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
bucketed_dataset = input_dataset.apply(
grouping.group_by_window(
lambda x, y, z: 0,
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = bucketed_dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
self.assertEqual(0, which_bucket)
expected_scalar_int = np.arange(32, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
for i in range(32):
expected_unk_int64[i, :i] = i
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values[0])
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
def testEvenOddBuckets(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
bucketed_dataset = input_dataset.apply(
grouping.group_by_window(
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
iterator = bucketed_dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even, bucketed_values_even = sess.run(get_next)
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
# Count number of bucket_tensors.
self.assertEqual(3, len(bucketed_values_even))
self.assertEqual(3, len(bucketed_values_odd))
# Ensure the even values went to bucket 0 and the odd values to bucket 1.
self.assertAllEqual(0, which_bucket_even)
self.assertAllEqual(1, which_bucket_odd)
# Test the first bucket output: the evens, starting at 0.
expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
for i in range(0, 32):
expected_unk_int64[i, :2 * i] = 2 * i
expected_vec3_str = np.vstack(
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
# Test the second bucket output: the odds, starting at 1.
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
for i in range(0, 32):
expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
expected_vec3_str = np.vstack(
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
return {
"x": v,
"y": array_ops.fill([v], v),
"z": array_ops.fill([3], string_ops.as_string(v))
}
def _dynamic_pad_fn(bucket, window, _):
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket),
window.padded_batch(
32, {
"x": tensor_shape.TensorShape([]),
"y": tensor_shape.TensorShape([None]),
"z": tensor_shape.TensorShape([3])
})))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
.filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
bucketed_dataset = input_dataset.apply(
grouping.group_by_window(
lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
iterator = bucketed_dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0, bucketed_values_even0 = sess.run(get_next)
which_bucket1, bucketed_values_even1 = sess.run(get_next)
# Ensure that bucket 1 was completely filtered out
self.assertAllEqual(0, which_bucket0)
self.assertAllEqual(0, which_bucket1)
self.assertAllEqual(
np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
def testDynamicWindowSize(self):
components = np.arange(100).astype(np.int64)
# Key fn: even/odd
# Reduce fn: batches of 5
# Window size fn: even=5, odd=10
def window_size_func(key):
window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
return window_sizes[key]
dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
None, window_size_func))
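# The 50 even elements fill windows of 5 (10 batches) and the 50 odd
# elements fill windows of 10 (5 batches), giving the 15 batches
# asserted below.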
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
while True:
result = sess.run(get_next)
is_even = all(x % 2 == 0 for x in result)
is_odd = all(x % 2 == 1 for x in result)
self.assertTrue(is_even or is_odd)
expected_batch_size = 5 if is_even else 10
self.assertEqual(expected_batch_size, result.shape[0])
batches += 1
self.assertEqual(batches, 15)
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
.apply(
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
4)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
result = sess.run(get_next)
self.assertTrue(
    all(x % 2 == 0 for x in result) or
    all(x % 2 == 1 for x in result))
counts.append(result.shape[0])
self.assertEqual(len(components), sum(counts))
num_full_batches = len([c for c in counts if c == 4])
self.assertGreaterEqual(num_full_batches, 24)
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
def testImmediateOutput(self):
components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
4)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for _ in range(3):
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).apply(
grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
4)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
# The small outputs at the end are deterministically produced in key
# order.
self.assertAllEqual([0, 0, 0], sess.run(get_next))
self.assertAllEqual([1], sess.run(get_next))
def testEmpty(self):
iterator = (
dataset_ops.Dataset.range(4).apply(
grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Window size must be greater than zero, but got 0."):
print(sess.run(get_next))
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
def reduce_func(_, xs):
# Introduce an incorrect padded shape that cannot (currently) be
# detected at graph construction time.
return xs.padded_batch(
4,
padded_shapes=(tensor_shape.TensorShape([]),
constant_op.constant([5], dtype=dtypes.int64) * -1))
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
grouping.group_by_window(lambda x, _: x % 2, reduce_func,
32)).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64)
def reduce_func(key, window):
# Apply two different kinds of padding to the input: tight
# padding, and quantized (to a multiple of 10) padding.
return dataset_ops.Dataset.zip((
window.padded_batch(
4, padded_shapes=tensor_shape.TensorShape([None])),
window.padded_batch(
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
))
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
.apply(grouping.group_by_window(
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
tight_result, multiple_of_10_result = sess.run(get_next)
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
self.assertAllEqual(tight_result,
multiple_of_10_result[:, :tight_result.shape[1]])
counts.append(tight_result.shape[0])
self.assertEqual(len(components), sum(counts))
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,115 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.ignore_errors()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
class IgnoreErrorsTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
dataset = (
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.check_numerics(x, "message")).apply(
error_ops.ignore_errors()))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testParallelMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
dataset = (
dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.check_numerics(x, "message"),
num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testReadFileIgnoreError(self):
def write_string_to_file(value, filename):
with open(filename, "w") as f:
f.write(value)
filenames = [
os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
]
for filename in filenames:
write_string_to_file(filename, filename)
dataset = (
dataset_ops.Dataset.from_tensor_slices(filenames).map(
io_ops.read_file,
num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Delete one of the files.
os.remove(filenames[0])
# Attempting to read filenames[0] will fail, but ignore_errors()
# will catch the error.
sess.run(init_op)
for filename in filenames[1:]:
self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()

View File

@@ -0,0 +1,239 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
class MakeBatchedFeaturesDatasetTest(
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 0.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(
sess,
batch_size,
0,
num_epochs=num_epochs,
label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 1.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(
sess,
batch_size,
1,
num_epochs=num_epochs,
label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from both files.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from both files.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(sess, batch_size, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
def testReadWithEquivalentDataset(self):
features = {
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
}
dataset = (
core_readers.TFRecordDataset(self.test_filenames)
.map(lambda x: parsing_ops.parse_single_example(x, features))
.repeat(10).batch(2))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
self.assertAllEqual(record_batch, actual_batch["record"])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testReadWithFusedShuffleRepeatDataset(self):
num_epochs = 5
total_records = num_epochs * self._num_records
for batch_size in [1, 2]:
# Test that shuffling with same seed produces the same result.
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=5).make_one_shot_iterator().get_next()
outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=5).make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
batch2 = self._run_actual_batch(outputs2, sess)
for i in range(len(batch1)):
self.assertAllEqual(batch1[i], batch2[i])
# Test that shuffling with different seeds produces a different order.
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=5).make_one_shot_iterator().get_next()
outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=15).make_one_shot_iterator().get_next()
all_equal = True
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
batch2 = self._run_actual_batch(outputs2, sess)
for i in range(len(batch1)):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
def testParallelReadersAndParsers(self):
num_epochs = 5
for batch_size in [1, 2]:
for reader_num_threads in [2, 4]:
for parser_num_threads in [2, 4]:
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads).make_one_shot_iterator(
).get_next()
self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
label_key_provided=True,
interleave_cycle_length=reader_num_threads)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads).make_one_shot_iterator(
).get_next()
self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
interleave_cycle_length=reader_num_threads)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
def testDropFinalBatch(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
with ops.Graph().as_default():
# Basic test: read from file 0.
outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
for tensor in nest.flatten(outputs):
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
filenames=self.test_filenames[0],
label_key="label",
num_epochs=None,
batch_size=32)
for shape, clazz in zip(nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_classes)):
if issubclass(clazz, ops.Tensor):
self.assertEqual(32, shape[0])
if __name__ == "__main__":
test.main()

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.make_csv_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,226 +23,16 @@ import zlib
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class ReadBatchFeaturesTest(
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 0.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(
sess,
batch_size,
0,
num_epochs=num_epochs,
label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 1.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(
sess,
batch_size,
1,
num_epochs=num_epochs,
label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from both files.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from both files.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
self.verify_records(sess, batch_size, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
def testReadWithEquivalentDataset(self):
features = {
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
}
dataset = (
core_readers.TFRecordDataset(self.test_filenames)
.map(lambda x: parsing_ops.parse_single_example(x, features))
.repeat(10).batch(2))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
self.assertAllEqual(record_batch, actual_batch["record"])
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testReadWithFusedShuffleRepeatDataset(self):
num_epochs = 5
total_records = num_epochs * self._num_records
for batch_size in [1, 2]:
# Test that shuffling with same seed produces the same result.
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=5).make_one_shot_iterator().get_next()
outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=5).make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
batch2 = self._run_actual_batch(outputs2, sess)
for i in range(len(batch1)):
self.assertAllEqual(batch1[i], batch2[i])
# Test that shuffling with different seeds produces a different order.
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=5).make_one_shot_iterator().get_next()
outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
shuffle_seed=15).make_one_shot_iterator().get_next()
all_equal = True
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
batch2 = self._run_actual_batch(outputs2, sess)
for i in range(len(batch1)):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
def testParallelReadersAndParsers(self):
num_epochs = 5
for batch_size in [1, 2]:
for reader_num_threads in [2, 4]:
for parser_num_threads in [2, 4]:
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads).make_one_shot_iterator(
).get_next()
self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
label_key_provided=True,
interleave_cycle_length=reader_num_threads)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads).make_one_shot_iterator(
).get_next()
self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
interleave_cycle_length=reader_num_threads)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
def testDropFinalBatch(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
with ops.Graph().as_default():
# Basic test: read from file 0.
outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
for tensor in nest.flatten(outputs):
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
filenames=self.test_filenames[0],
label_key="label",
num_epochs=None,
batch_size=32)
for shape, clazz in zip(nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_classes)):
if issubclass(clazz, ops.Tensor):
self.assertEqual(32, shape[0])
class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
@ -866,218 +656,5 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
self.assertEqual(32, shape[0])
class MakeTFRecordDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase):
def _interleave(self, iterators, cycle_length):
pending_iterators = iterators
open_iterators = []
num_open = 0
for i in range(cycle_length):
if pending_iterators:
open_iterators.append(pending_iterators.pop(0))
num_open += 1
while num_open:
for i in range(min(cycle_length, len(open_iterators))):
if open_iterators[i] is None:
continue
try:
yield next(open_iterators[i])
except StopIteration:
if pending_iterators:
open_iterators[i] = pending_iterators.pop(0)
else:
open_iterators[i] = None
num_open -= 1
def _next_expected_batch(self,
file_indices,
batch_size,
num_epochs,
cycle_length,
drop_final_batch,
use_parser_fn):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
yield j, i
def _next_record_interleaved(file_indices, cycle_length):
return self._interleave([_next_record([i]) for i in file_indices],
cycle_length)
record_batch = []
batch_index = 0
for _ in range(num_epochs):
if cycle_length == 1:
next_records = _next_record(file_indices)
else:
next_records = _next_record_interleaved(file_indices, cycle_length)
for f, r in next_records:
record = self._record(f, r)
if use_parser_fn:
record = record[1:]
record_batch.append(record)
batch_index += 1
if len(record_batch) == batch_size:
yield record_batch
record_batch = []
batch_index = 0
if record_batch and not drop_final_batch:
yield record_batch
def _verify_records(self,
sess,
outputs,
batch_size,
file_index,
num_epochs,
interleave_cycle_length,
drop_final_batch,
use_parser_fn):
if file_index is not None:
file_indices = [file_index]
else:
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(
file_indices, batch_size, num_epochs, interleave_cycle_length,
drop_final_batch, use_parser_fn):
actual_batch = sess.run(outputs)
self.assertAllEqual(expected_batch, actual_batch)
def _read_test(self, batch_size, num_epochs, file_index=None,
num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
if file_index is None:
file_pattern = self.test_filenames
else:
file_pattern = self.test_filenames[file_index]
if parser_fn:
fn = lambda x: string_ops.substr(x, 1, 999)
else:
fn = None
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
outputs = readers.make_tf_record_dataset(
file_pattern=file_pattern,
num_epochs=num_epochs,
batch_size=batch_size,
parser_fn=fn,
num_parallel_reads=num_parallel_reads,
drop_final_batch=drop_final_batch,
shuffle=False).make_one_shot_iterator().get_next()
self._verify_records(
sess, outputs, batch_size, file_index, num_epochs=num_epochs,
interleave_cycle_length=num_parallel_reads,
drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
with self.assertRaises(errors.OutOfRangeError):
sess.run(outputs)
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
# Basic test: read from file 0.
self._read_test(batch_size, num_epochs, 0)
# Basic test: read from file 1.
self._read_test(batch_size, num_epochs, 1)
# Basic test: read from both files.
self._read_test(batch_size, num_epochs)
# Basic test: read from both files, with parallel reads.
self._read_test(batch_size, num_epochs, num_parallel_reads=8)
def testDropFinalBatch(self):
for batch_size in [1, 2, 10]:
for num_epochs in [1, 3]:
# Read from file 0.
self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
# Read from both files.
self._read_test(batch_size, num_epochs, drop_final_batch=True)
# Read from both files, with parallel reads.
self._read_test(batch_size, num_epochs, num_parallel_reads=8,
drop_final_batch=True)
def testParserFn(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
for drop_final_batch in [False, True]:
self._read_test(batch_size, num_epochs, parser_fn=True,
drop_final_batch=drop_final_batch)
self._read_test(batch_size, num_epochs, num_parallel_reads=8,
parser_fn=True, drop_final_batch=drop_final_batch)
def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
seed=None):
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
num_parallel_reads=num_parallel_reads,
shuffle=True,
shuffle_seed=seed)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer)
first_batches = []
try:
while True:
first_batches.append(sess.run(next_element))
except errors.OutOfRangeError:
pass
sess.run(iterator.initializer)
second_batches = []
try:
while True:
second_batches.append(sess.run(next_element))
except errors.OutOfRangeError:
pass
self.assertEqual(len(first_batches), len(second_batches))
if seed is not None:
        # If a seed was set, the two passes should produce identical results.
for i in range(len(first_batches)):
self.assertAllEqual(first_batches[i], second_batches[i])
expected = []
for f in range(self._num_files):
for r in range(self._num_records):
expected.extend([self._record(f, r)] * num_epochs)
for batches in (first_batches, second_batches):
actual = []
for b in batches:
actual.extend(b)
self.assertAllEqual(sorted(expected), sorted(actual))
def testShuffle(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
for num_parallel_reads in [1, 2]:
# Test that all expected elements are produced
self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
# Test that elements are produced in a consistent order if
# you specify a seed.
self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
seed=21345)
def testIndefiniteRepeatShapeInference(self):
dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
for shape in nest.flatten(dataset.output_shapes):
self.assertEqual(32, shape[0])
if __name__ == "__main__":
test.main()
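
The renamed test class above exercises `make_csv_dataset`. A minimal usage sketch under a hypothetical CSV layout (files with a header row naming a "label" column); the keyword names are those of the experimental API at this point:

from tensorflow.python.data.experimental.ops import readers

# Hypothetical CSV files with a header row: "feature_a,feature_b,label".
dataset = readers.make_csv_dataset(
    file_pattern="/tmp/data/*.csv",  # hypothetical path
    batch_size=32,
    column_names=None,    # None: infer column names from the header row
    label_name="label",   # yields (features_dict, label) tuples
    num_epochs=1,
    shuffle=True,
    shuffle_seed=42)
features, label = dataset.make_one_shot_iterator().get_next()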

View File

@ -0,0 +1,243 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.make_tf_record_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class MakeTFRecordDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase):
def _interleave(self, iterators, cycle_length):
pending_iterators = iterators
open_iterators = []
num_open = 0
for i in range(cycle_length):
if pending_iterators:
open_iterators.append(pending_iterators.pop(0))
num_open += 1
while num_open:
for i in range(min(cycle_length, len(open_iterators))):
if open_iterators[i] is None:
continue
try:
yield next(open_iterators[i])
except StopIteration:
if pending_iterators:
open_iterators[i] = pending_iterators.pop(0)
else:
open_iterators[i] = None
num_open -= 1
def _next_expected_batch(self,
file_indices,
batch_size,
num_epochs,
cycle_length,
drop_final_batch,
use_parser_fn):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
yield j, i
def _next_record_interleaved(file_indices, cycle_length):
return self._interleave([_next_record([i]) for i in file_indices],
cycle_length)
record_batch = []
batch_index = 0
for _ in range(num_epochs):
if cycle_length == 1:
next_records = _next_record(file_indices)
else:
next_records = _next_record_interleaved(file_indices, cycle_length)
for f, r in next_records:
record = self._record(f, r)
if use_parser_fn:
record = record[1:]
record_batch.append(record)
batch_index += 1
if len(record_batch) == batch_size:
yield record_batch
record_batch = []
batch_index = 0
if record_batch and not drop_final_batch:
yield record_batch
def _verify_records(self,
sess,
outputs,
batch_size,
file_index,
num_epochs,
interleave_cycle_length,
drop_final_batch,
use_parser_fn):
if file_index is not None:
file_indices = [file_index]
else:
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(
file_indices, batch_size, num_epochs, interleave_cycle_length,
drop_final_batch, use_parser_fn):
actual_batch = sess.run(outputs)
self.assertAllEqual(expected_batch, actual_batch)
def _read_test(self, batch_size, num_epochs, file_index=None,
num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
if file_index is None:
file_pattern = self.test_filenames
else:
file_pattern = self.test_filenames[file_index]
if parser_fn:
fn = lambda x: string_ops.substr(x, 1, 999)
else:
fn = None
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
outputs = readers.make_tf_record_dataset(
file_pattern=file_pattern,
num_epochs=num_epochs,
batch_size=batch_size,
parser_fn=fn,
num_parallel_reads=num_parallel_reads,
drop_final_batch=drop_final_batch,
shuffle=False).make_one_shot_iterator().get_next()
self._verify_records(
sess, outputs, batch_size, file_index, num_epochs=num_epochs,
interleave_cycle_length=num_parallel_reads,
drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
with self.assertRaises(errors.OutOfRangeError):
sess.run(outputs)
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
# Basic test: read from file 0.
self._read_test(batch_size, num_epochs, 0)
# Basic test: read from file 1.
self._read_test(batch_size, num_epochs, 1)
# Basic test: read from both files.
self._read_test(batch_size, num_epochs)
# Basic test: read from both files, with parallel reads.
self._read_test(batch_size, num_epochs, num_parallel_reads=8)
def testDropFinalBatch(self):
for batch_size in [1, 2, 10]:
for num_epochs in [1, 3]:
# Read from file 0.
self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
# Read from both files.
self._read_test(batch_size, num_epochs, drop_final_batch=True)
# Read from both files, with parallel reads.
self._read_test(batch_size, num_epochs, num_parallel_reads=8,
drop_final_batch=True)
def testParserFn(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
for drop_final_batch in [False, True]:
self._read_test(batch_size, num_epochs, parser_fn=True,
drop_final_batch=drop_final_batch)
self._read_test(batch_size, num_epochs, num_parallel_reads=8,
parser_fn=True, drop_final_batch=drop_final_batch)
def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
seed=None):
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
num_parallel_reads=num_parallel_reads,
shuffle=True,
shuffle_seed=seed)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess.run(iterator.initializer)
first_batches = []
try:
while True:
first_batches.append(sess.run(next_element))
except errors.OutOfRangeError:
pass
sess.run(iterator.initializer)
second_batches = []
try:
while True:
second_batches.append(sess.run(next_element))
except errors.OutOfRangeError:
pass
self.assertEqual(len(first_batches), len(second_batches))
if seed is not None:
        # If a seed was set, the two passes should produce identical results.
for i in range(len(first_batches)):
self.assertAllEqual(first_batches[i], second_batches[i])
expected = []
for f in range(self._num_files):
for r in range(self._num_records):
expected.extend([self._record(f, r)] * num_epochs)
for batches in (first_batches, second_batches):
actual = []
for b in batches:
actual.extend(b)
self.assertAllEqual(sorted(expected), sorted(actual))
def testShuffle(self):
for batch_size in [1, 2]:
for num_epochs in [1, 3]:
for num_parallel_reads in [1, 2]:
# Test that all expected elements are produced
self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
# Test that elements are produced in a consistent order if
# you specify a seed.
self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
seed=21345)
def testIndefiniteRepeatShapeInference(self):
dataset = readers.make_tf_record_dataset(
file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
for shape in nest.flatten(dataset.output_shapes):
self.assertEqual(32, shape[0])
if __name__ == "__main__":
test.main()
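
For reference, a minimal sketch of `make_tf_record_dataset` as `_read_test` above drives it (the file pattern is hypothetical; every keyword shown appears in the test calls):

from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.ops import string_ops

dataset = readers.make_tf_record_dataset(
    file_pattern="/tmp/data/*.tfrecord",  # hypothetical path
    batch_size=2,
    num_epochs=3,
    num_parallel_reads=8,  # interleaves records across this many files
    # Optional per-record transformation, applied before batching.
    parser_fn=lambda x: string_ops.substr(x, 1, 999),
    shuffle=False,
    drop_final_batch=True)
next_batch = dataset.make_one_shot_iterator().get_next()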

View File

@ -0,0 +1,337 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.map_and_batch()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Default", None, None),
("SequentialCalls", 1, None),
("ParallelCalls", 2, None),
("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))
count = array_ops.placeholder(dtypes.int64, shape=[])
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
batching.map_and_batch(
map_func=_map_fn,
batch_size=batch_size,
num_parallel_calls=num_parallel_calls,
num_parallel_batches=num_parallel_batches))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
num_batches = (28 * 7) // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Batch of a finite input, where the batch_size does not
# divide the total number of elements.
sess.run(init_op, feed_dict={count: 14, batch_size: 8})
# We expect (num_batches - 1) full-sized batches.
num_batches = int(math.ceil((14 * 7) / 8))
for i in range(num_batches - 1):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
self.assertAllEqual(component[(i * 8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Batch of an empty input should fail straight away.
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
      # A batch size of 0 should be rejected at initialization time.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
("Even", False),
("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(
lambda x: array_ops.reshape(x * x, [1]),
batch_size=4,
drop_remainder=drop_remainder)).make_one_shot_iterator())
if drop_remainder:
self.assertEqual([4, 1], iterator.output_shapes.as_list())
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testMapAndBatchYieldsPartialBatch(self):
iterator = (dataset_ops.Dataset.range(10)
.apply(batching.map_and_batch(
lambda x: array_ops.reshape(x * x, [1]), 4))
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testMapAndBatchParallelGetNext(self):
iterator = (dataset_ops.Dataset.range(50000)
.apply(batching.map_and_batch(lambda x: x, batch_size=100))
.make_one_shot_iterator())
elements = []
for _ in range(100):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
def testMapAndBatchParallelGetNextDropRemainder(self):
iterator = (
dataset_ops.Dataset.range(49999).apply(
batching.map_and_batch(
lambda x: x, batch_size=100, drop_remainder=True))
.make_one_shot_iterator())
elements = []
for _ in range(100):
elements.append(iterator.get_next())
with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
for j in range(100):
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(elements)
def testMapAndBatchSparse(self):
def _sparse(i):
return sparse_tensor.SparseTensorValue(
indices=[[0]], values=(i * [1]), dense_shape=[1])
iterator = dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
dense_shape=[5, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testMapAndBatchFails(self):
"""Test a dataset that maps a TF function across its input elements."""
dataset = dataset_ops.Dataset.from_tensors(
array_ops.check_numerics(
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
def testMapAndBatchShapeMismatch(self):
"""Test a dataset that maps a TF function across its input elements."""
def generator():
yield [1]
yield [2]
yield [3]
yield [[4, 5, 6]]
dataset = dataset_ops.Dataset.from_generator(
generator, output_types=dtypes.int32)
batch_size = 4
iterator = (
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
sess.run(get_next)
def testMapAndBatchImplicitDispose(self):
# Tests whether a map and batch dataset will be cleaned up correctly when
# the pipeline does not run it until exhaustion.
# The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
# MapAndBatchDataset(f=square_3, batch_size=100).
components = (np.arange(1000),
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
np.array(37.0) * np.arange(1000))
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
dataset = dataset.prefetch(5)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
@parameterized.named_parameters(
("1", 0),
("2", 5),
("3", 10),
("4", 90),
("5", 95),
("6", 99),
)
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
if i >= threshold:
raise StopIteration()
else:
return i
iterator = (
dataset_ops.Dataset.range(100).apply(
batching.map_and_batch(
lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
self.assertAllEqual(
[threshold // 10 * 10 + j for j in range(threshold % 10)],
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@parameterized.named_parameters(
("1", False, dtypes.bool),
("2", -42, dtypes.int8),
("3", -42, dtypes.int16),
("4", -42, dtypes.int32),
("5", -42, dtypes.int64),
("6", 42, dtypes.uint8),
("7", 42, dtypes.uint16),
("8", 42.0, dtypes.float16),
("9", 42.0, dtypes.float32),
("10", 42.0, dtypes.float64),
("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
yield element
dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
batching.map_and_batch(lambda x: x, batch_size=10))
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
if __name__ == "__main__":
test.main()
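
A minimal sketch of the fusion these tests cover: `map_and_batch` behaves like a `.map()` followed by a `.batch()` but overlaps the two stages (the values here are arbitrary):

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

dataset = dataset_ops.Dataset.range(10).apply(
    batching.map_and_batch(
        map_func=math_ops.square,
        batch_size=4,
        # num_parallel_calls and num_parallel_batches are mutually
        # exclusive knobs for parallelism.
        num_parallel_calls=2,
        drop_remainder=False))
get_next = dataset.make_one_shot_iterator().get_next()
# Yields [0, 1, 4, 9], [16, 25, 36, 49], then the partial batch [64, 81].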

View File

@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase):
sess.close()
thread.join()
def testMapDefunWithCapturedInputs(self):
c = constant_op.constant(2)
@function.Defun(dtypes.int32)
def fn(x):
return x + c
x = constant_op.constant([1, 2, 3, 4])
map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]
expected = x + c
self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
class MapDefunBenchmark(test.Benchmark):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline statistics gathering ops."""
"""Tests for the private `override_threadpool()` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -32,8 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
parameterized.TestCase):
class OverrideThreadpoolTest(test_base.DatasetTestBase,
parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.parallel_interleave()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -37,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
class ParallelInterleaveTest(test_base.DatasetTestBase):
def setUp(self):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.parsing_ops."""
"""Tests for `tf.data.experimental.parse_example_dataset()."""
from __future__ import absolute_import
from __future__ import division
@ -73,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
class ParseExampleTest(test_base.DatasetTestBase):
class ParseExampleDatasetTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,

View File

@ -0,0 +1,234 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.prefetch_to_device()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device(
"/job:localhost/replica:0/task:0/device:CPU:0"))
    # NOTE(mrry): This device block creates the "host" dataset and iterator on
    # /cpu:0. Unlike the other tests, the prefetch target here is the same
    # device, so this exercises the degenerate same-device path.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element["a"].dtype)
self.assertEqual([], next_element["a"].shape)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchSparseTensorsToDevice(self):
def make_tensor(i):
return sparse_tensor.SparseTensorValue(
indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
self.assertAllEqual([i], actual.values)
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_initializable_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
if __name__ == "__main__":
test.main()
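
A condensed sketch of the transformation under test; as the NOTE comments above explain, the iterator must be created outside the target device so the prefetch actually crosses a device boundary (the device string is hypothetical):

from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops

host_dataset = dataset_ops.Dataset.range(10)
# prefetch_to_device must be the final transformation in the pipeline.
device_dataset = host_dataset.apply(
    prefetching_ops.prefetch_to_device("/gpu:0"))
with ops.device("/cpu:0"):
  iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()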

View File

@ -63,11 +63,11 @@ class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
return filenames
class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
super(ReadBatchFeaturesTestBase, self).setUp()
super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
self._num_files = 2
self._num_records = 7
self.test_filenames = self._createFiles()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.rejection_resample()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -58,7 +58,7 @@ def _time_resampling(
return end_time - start_time
class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for the private `_RestructuredDataset` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -26,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class DatasetConstructorTest(test_base.DatasetTestBase):
class RestructuredDatasetTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.scan()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -34,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ScanDatasetTest(test_base.DatasetTestBase):
class ScanTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(

View File

@ -69,6 +69,26 @@ py_test(
],
)
py_test(
name = "checkpoint_input_pipeline_hook_test",
size = "small",
srcs = ["checkpoint_input_pipeline_hook_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator:estimator_py",
],
)
py_test(
name = "concatenate_dataset_serialization_test",
size = "small",
@ -580,7 +600,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
"//tensorflow/python/data/experimental/kernel_tests:sql_dataset_test_base",
"//tensorflow/python/data/experimental/ops:readers",
],
)

View File

@ -23,7 +23,7 @@ from tensorflow.python.platform import test
class ParseExampleDatasetSerializationTest(
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
dataset_serialization_test_base.DatasetSerializationTestBase):
def ParseExampleDataset(self, num_repeat, batch_size):

View File

@ -19,7 +19,7 @@ from __future__ import print_function
import os
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import dtypes
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class SqlDatasetSerializationTest(
sql_dataset_op_test_base.SqlDatasetTestBase,
sql_dataset_test_base.SqlDatasetTestBase,
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_repeats):

View File

@ -1,85 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration test for dataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class SerializationIntegrationTest(test.TestCase):
def _build_input_pipeline(self, name, num_outputs):
with ops.name_scope(name):
ds = dataset_ops.Dataset.range(num_outputs).shuffle(
10, reshuffle_each_iteration=False).prefetch(10)
iterator = ds.make_initializable_iterator()
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
return iterator.initializer, iterator.get_next()
def _build_graph(self, num_pipelines, num_outputs):
init_ops = []
get_next_ops = []
for i in range(num_pipelines):
name = "input_pipeline_%d" % i
init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
init_ops.append(init_op)
get_next_ops.append(get_next_op)
saver = saver_lib.Saver()
return init_ops, get_next_ops, saver
def _ckpt_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
def testConcurrentSaves(self):
num_pipelines = 100
num_outputs = 100
break_point = 10
all_outputs = [[] for _ in range(num_pipelines)]
with ops.Graph().as_default() as g:
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
num_outputs)
with self.session(graph=g) as sess:
sess.run(init_ops)
for _ in range(break_point):
output = sess.run(get_next_ops)
for i in range(num_pipelines):
all_outputs[i].append(output[i])
saver.save(sess, self._ckpt_path())
with ops.Graph().as_default() as g:
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
num_outputs)
with self.session(graph=g) as sess:
saver.restore(sess, self._ckpt_path())
for _ in range(num_outputs - break_point):
output = sess.run(get_next_ops)
for i in range(num_pipelines):
all_outputs[i].append(output[i])
for output in all_outputs:
self.assertSequenceEqual(sorted(output), range(num_outputs))
if __name__ == "__main__":
test.main()
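
The deleted integration test above is built on the iterator-saveable pattern, which the `checkpoint_input_pipeline_hook_test` target added earlier in this commit covers in its estimator-hook form. A condensed sketch of the underlying pattern, lifted from the deleted code (checkpoint path handling elided):

from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as saver_lib

ds = dataset_ops.Dataset.range(100).shuffle(10, reshuffle_each_iteration=False)
iterator = ds.make_initializable_iterator()
# Registering the saveable lets a plain Saver checkpoint and restore the
# iterator's position along with the model variables.
saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = saver_lib.Saver()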

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.shuffle_and_repeat()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for experimental sql input op."""
"""Tests for `tf.data.experimental.SqlDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for testing SqlDataset."""
"""Base class for testing `tf.data.experimental.SqlDataset`."""
from __future__ import absolute_import
from __future__ import division

View File

@ -280,7 +280,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
def testFeaturesStats(self):
num_epochs = 5

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.TFRecordWriter`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@ -0,0 +1,300 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.unbatch()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
def testUnbatchWithUnknownRankInput(self):
placeholder = array_ops.placeholder(dtypes.int32)
dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
batching.unbatch())
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
for i in range(4):
self.assertEqual(i, sess.run(next_elem))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_elem)
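  # For reference, a minimal sketch of the transformation under test:
  # `batching.unbatch()` is the inverse of `batch()`, splitting each input
  # element along its first dimension (example values hypothetical):
  #
  #   ds = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
  #   ds = ds.apply(batching.unbatch())  # now yields 0, 1, 2, 3 in turn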
def testUnbatchScalarDataset(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = (dtypes.int32,) * 3
data = data.batch(2)
self.assertEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testUnbatchDatasetWithStrings(self):
data = tuple([math_ops.range(10) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
data = data.batch(2)
self.assertEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
def testUnbatchDatasetWithSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
indices=[[i, i] for i in range(10)],
values=list(range(10)),
dense_shape=[10, 10])
data = dataset_ops.Dataset.from_tensors(st)
data = data.apply(batching.unbatch())
data = data.batch(5)
data = data.apply(batching.unbatch())
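    # After batch(5) + unbatch(), each element should round-trip back to a
    # sparse vector of dense_shape [10] with the single value i at index [i].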
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

  def testUnbatchDatasetWithDenseAndSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
indices=[[i, i] for i in range(10)],
values=list(range(10)),
dense_shape=[10, 10])
data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
data = data.apply(batching.unbatch())
data = data.batch(5)
data = data.apply(batching.unbatch())
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
self.assertEqual([i], st_row.indices)
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

  def testUnbatchSingleElementTupleDataset(self):
data = tuple([(math_ops.range(10),) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = ((dtypes.int32,),) * 3
data = data.batch(2)
self.assertEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)

  def testUnbatchMultiElementTupleDataset(self):
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
array_ops.fill([10], "hi")) for i in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
expected_types = ((dtypes.int32, dtypes.string),) * 3
data = data.batch(2)
self.assertAllEqual(expected_types, data.output_types)
data = data.apply(batching.unbatch())
self.assertAllEqual(expected_types, data.output_types)
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)

  def testUnbatchEmpty(self):
data = dataset_ops.Dataset.from_tensors(
(constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
constant_op.constant([], shape=[0, 4, 0])))
data = data.apply(batching.unbatch())
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)

  def testUnbatchStaticShapeMismatch(self):
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
np.arange(9)))
with self.assertRaises(ValueError):
data.apply(batching.unbatch())

  def testUnbatchDynamicShapeMismatch(self):
ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
ph2 = array_ops.placeholder(dtypes.int32, shape=None)
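    # Note: `shape=None` gives ph2 unknown rank, so shape mismatches can only
    # be detected when the pipeline runs, not at graph-construction time.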
data = dataset_ops.Dataset.from_tensors((ph1, ph2))
data = data.apply(batching.unbatch())
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
feed_dict={
ph1: np.arange(7).astype(np.int32),
ph2: np.arange(8).astype(np.int32)
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
# No 0th dimension (i.e. scalar value) for one component.
sess.run(
iterator.initializer,
feed_dict={
ph1: np.arange(7).astype(np.int32),
ph2: 7
})
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)


class UnbatchBenchmark(test.Benchmark):

def benchmarkNativeUnbatch(self):
batch_sizes = [1, 2, 5, 10, 20, 50]
elems_per_trial = 10000
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset.batch(batch_size_placeholder)
dataset = dataset.apply(batching.unbatch())
dataset = dataset.skip(elems_per_trial)
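      # The skip() above forces the pipeline to produce `elems_per_trial`
      # elements, so the single get_next() per trial below times all of them.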
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with session.Session() as sess:
for batch_size in batch_sizes:
deltas = []
for _ in range(5):
sess.run(
iterator.initializer,
feed_dict={batch_size_placeholder: batch_size})
start = time.time()
sess.run(next_element.op)
end = time.time()
deltas.append((end - start) / elems_per_trial)
median_wall_time = np.median(deltas)
print("Unbatch (native) batch size: %d Median wall time per element:"
" %f microseconds" % (batch_size, median_wall_time * 1e6))
self.report_benchmark(
iters=10000,
wall_time=median_wall_time,
name="benchmark_unbatch_dataset_native_batch_size_%d" %
batch_size)

  # Include a benchmark of the previous `unbatch()` implementation that uses
# a composition of more primitive ops. Eventually we'd hope to generate code
# that is as good in both cases.
def benchmarkOldUnbatchImplementation(self):
batch_sizes = [1, 2, 5, 10, 20, 50]
elems_per_trial = 10000
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
dataset = dataset.batch(batch_size_placeholder)
dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
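      # flat_map(from_tensor_slices) is the unfused composition of primitive
      # ops that the fused unbatch() kernel replaces.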
dataset = dataset.skip(elems_per_trial)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with session.Session() as sess:
for batch_size in batch_sizes:
deltas = []
for _ in range(5):
sess.run(
iterator.initializer,
feed_dict={batch_size_placeholder: batch_size})
start = time.time()
sess.run(next_element.op)
end = time.time()
deltas.append((end - start) / elems_per_trial)
median_wall_time = np.median(deltas)
print("Unbatch (unfused) batch size: %d Median wall time per element:"
" %f microseconds" % (batch_size, median_wall_time * 1e6))
self.report_benchmark(
iters=10000,
wall_time=median_wall_time,
name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
batch_size)


if __name__ == "__main__":
test.main()
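
For readers skimming the diff, a minimal sketch of the semantics the new test file covers (illustrative only, reusing the 1.x-era APIs imported above; not part of the commit):

    from tensorflow.python.data.experimental.ops import batching
    from tensorflow.python.data.ops import dataset_ops

    # batch(2) groups consecutive elements; unbatch() splits them back out
    # along dimension 0, restoring the original element stream.
    dataset = dataset_ops.Dataset.range(4).batch(2)
    dataset = dataset.apply(batching.unbatch())
    iterator = dataset.make_one_shot_iterator()  # yields 0, 1, 2, 3 again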

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
"""Tests for `tf.data.experimental.unique()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -26,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat


class UniqueDatasetTest(test_base.DatasetTestBase):
class UniqueTest(test_base.DatasetTestBase):

  def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.

View File

@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
if not isinstance(elems, list):
raise ValueError("`elems` must be a list of tensors.")
if not isinstance(output_dtypes, list):
raise ValueError("`output_dtypes` must be a list of tensors.")
raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
if not isinstance(output_shapes, list):
raise ValueError("`output_shapes` must be a list of tensors.")
raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
"objects.")
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes,
output_shapes, fn)
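
For context, a hedged sketch of the call pattern this change supports (the `Defun` usage, import paths, and names here are illustrative assumptions, not taken from the diff). The key point is that tensors `fn` captures from the enclosing graph now reach the generated op through `fn.captured_inputs`:

    from tensorflow.python.data.experimental.ops import map_defun
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import function
    from tensorflow.python.framework import tensor_shape

    bias = constant_op.constant(10)

    @function.Defun(dtypes.int32)
    def fn(x):
      return x + bias  # `bias` becomes a captured input of the defun

    # map_defun forwards fn.captured_inputs so the MapDefun op can feed
    # `bias` to `fn`; evaluating `out` in a session yields [11, 12, 13].
    out = map_defun.map_defun(
        fn, [constant_op.constant([1, 2, 3])],
        output_dtypes=[dtypes.int32],
        output_shapes=[tensor_shape.TensorShape([])])[0]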

Some files were not shown because too many files have changed in this diff