Merge pull request #9957 from av8ramit/branch_155393864

Branch 155393864
Amit Patankar 2017-05-17 08:25:46 -07:00 committed by GitHub
commit ffd1ed2df7
163 changed files with 15797 additions and 2107 deletions


@ -226,6 +226,10 @@ filegroup(
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/crf:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
"//tensorflow/contrib/data:all_files",
"//tensorflow/contrib/data/python/framework:all_files",
"//tensorflow/contrib/data/python/kernel_tests:all_files",
"//tensorflow/contrib/data/python/ops:all_files",
"//tensorflow/contrib/distributions:all_files",
"//tensorflow/contrib/factorization:all_files",
"//tensorflow/contrib/factorization/kernels:all_files",


@ -738,7 +738,8 @@ tensorflow::string OutputName(const TF_Output& output) {
const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
const char* attr_name,
TF_Status* status) {
const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
const tensorflow::AttrValue* attr =
tensorflow::AttrSlice(oper->node.def()).Find(attr_name);
if (attr == nullptr) {
status->status =
InvalidArgument("Operation has no attr named '", attr_name, "'.");
@ -1134,7 +1135,7 @@ const char* TF_OperationOpType(TF_Operation* oper) {
}
const char* TF_OperationDevice(TF_Operation* oper) {
return oper->node.requested_device().c_str();
return oper->node.def().device().c_str();
}
int TF_OperationNumOutputs(TF_Operation* oper) {
@ -1149,8 +1150,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) {
int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status) {
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
nullptr, &name_ranges);
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
@ -1171,8 +1172,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) {
int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status) {
NameRangeMap name_ranges;
status->status =
NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
&name_ranges, nullptr);
if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name);
if (iter == name_ranges.end()) {
@ -1410,27 +1411,26 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
}
}
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
void func(TF_Operation* oper, const char* attr_name, c_type* value, \
TF_Status* status) { \
cpp_type v; \
status->status = \
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
const auto len = std::min(max_values, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
values[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
void func(TF_Operation* oper, const char* attr_name, c_type* value, \
TF_Status* status) { \
cpp_type v; \
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
int max_values, TF_Status* status) { \
const auto* attr = GetAttrValue(oper, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
const auto len = std::min(max_values, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
values[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
}
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
@ -1441,8 +1441,7 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
int64_t* value, int num_dims, TF_Status* status) {
PartialTensorShape shape;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape);
if (!status->status.ok()) return;
auto len = std::min(shape.dims(), num_dims);
for (int i = 0; i < len; ++i) {
@ -1456,7 +1455,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
int storage_size, TF_Status* status) {
std::vector<PartialTensorShape> shapes;
status->status =
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes);
if (!status->status.ok()) return;
auto len = std::min(static_cast<int>(shapes.size()), max_values);
int64_t* p = storage;
@ -1523,7 +1522,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
TF_Tensor** value, TF_Status* status) {
*value = nullptr;
Tensor t;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t);
if (!status->status.ok()) return;
*value = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(),
tensorflow::TensorCApi::Buffer(t)};
@ -1534,7 +1533,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
TF_Tensor** values, int max_values,
TF_Status* status) {
std::vector<Tensor> ts;
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts);
if (!status->status.ok()) return;
const auto len = std::min(max_values, static_cast<int>(ts.size()));
for (int i = 0; i < len; ++i) {


@ -740,10 +740,11 @@ void OpInfo::GetOutput(string* out) const {
return;
}
strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
strings::StrAppend(out,
" ::tensorflow::Status _status_ = "
"::tensorflow::NameRangesForNode(*ret, ret->op_def(), "
"nullptr, &_outputs_range);\n");
strings::StrAppend(
out,
" ::tensorflow::Status _status_ = "
"::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), "
"nullptr, &_outputs_range);\n");
strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");


@ -35,8 +35,8 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) {
void GetColocationConstraints(const Output& tensor,
std::vector<string>* constraints) {
constraints->clear();
TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName,
constraints));
TF_EXPECT_OK(
GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints));
}
} // namespace
@ -159,11 +159,11 @@ TEST(CCOpTest, KernelLabel) {
Scope root = Scope::NewRootScope();
auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f);
TF_EXPECT_OK(root.status());
AttrSlice attrs = add.z.op().node()->attrs();
const auto* kernel_attr = attrs.Find("_kernel");
ASSERT_TRUE(kernel_attr);
TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string"));
EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel");
const auto& attrs = add.z.op().node()->def().attr();
ASSERT_TRUE(attrs.find("_kernel") != attrs.end());
auto kernel_attr = attrs.find("_kernel")->second;
TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string"));
EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel");
}
TEST(CCOpTest, ColocateWith) {
@ -190,7 +190,8 @@ TEST(CCOpTest, ColocateWith) {
Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4);
auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7);
EXPECT_FALSE(c6.op().node()->attrs().Find("_class"));
const auto& attrs = c6.op().node()->def().attr();
EXPECT_TRUE(attrs.find("_class") == attrs.end());
}
TEST(CCOpTest, TemplatedConst) {


@ -271,9 +271,9 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
const AttrSlice attrs = colocate_with_op.node()->attrs();
const NodeDef& node_def = colocate_with_op.node()->def();
std::vector<string> node_constraints;
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) {
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (s.Consume(kColocationGroupPrefix)) {


@ -43,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int N;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
int axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
grad_outputs->reserve(N);
auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
@ -60,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
return scope.status();
}
@ -162,7 +162,7 @@ Status CheckNumericsGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
string message;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "message", &message));
string err_msg = strings::StrCat(
"Not a number (NaN) or infinity (Inf) values detected in gradient. ",
message);
@ -215,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
std::vector<Output>* grad_outputs) {
auto seq_lengths = op.input(1);
int batch_dim;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
int seq_dim;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
grad_outputs->push_back(
ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
ReverseSequence::BatchDim(batch_dim)));
@ -267,8 +267,7 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
grad_outputs->push_back(
BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
grad_outputs->push_back(NoGradient());
@ -291,8 +290,7 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
grad_outputs->push_back(
SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
grad_outputs->push_back(NoGradient());
@ -315,8 +313,7 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
return scope.status();
}
@ -326,8 +323,7 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
int block_size;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
return scope.status();
}
@ -337,7 +333,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
string mode;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
scope, grad_inputs[0], op.input(1), mode));
grad_outputs->push_back(NoGradient());
@ -350,7 +346,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
string mode;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
grad_outputs->push_back(NoGradient());
return scope.status();


@ -350,7 +350,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
const string& attr_adj_x, const string& attr_adj_y,
std::vector<Output>* grad_outputs) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
return errors::Unimplemented(
"MatMul gradient for complex data type is not supported yet.");
@ -358,10 +358,8 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
bool ta;
bool tb;
TF_RETURN_IF_ERROR(
GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
TF_RETURN_IF_ERROR(
GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
if (!ta && !tb) {
return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),


@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
TensorShape shape) {
EXPECT_TRUE(n->IsConstant());
Tensor tensor;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor));
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
EXPECT_EQ(tensor.dtype(), dtype);
test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape));
}
@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype,
TensorShape expected_shape) {
EXPECT_TRUE(n->IsConstant());
Tensor tensor;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor));
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
EXPECT_EQ(dtype, expected_dtype);
EXPECT_EQ(expected_shape, TensorShape(tensor.shape()));
}


@ -203,14 +203,14 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
for (const Node* n : graph->nodes()) {
if (n->type_string() == kArgOp) {
string feed_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
if (missing_feeds.erase(feed_id) == 0) {
return errors::Aborted(kArgOp,
" node found with unknown feed id: ", feed_id);
}
} else if (n->type_string() == kRetvalOp) {
string fetch_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
if (missing_fetches.erase(fetch_id) == 0) {
return errors::Aborted(kRetvalOp,
" node found with unknown fetch id: ", fetch_id);
@ -234,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
for (Node* n : graph.nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
auto insert_result = indexed_arg_nodes.insert({index, n});
if (!insert_result.second) {
const Node* dup = insert_result.first->second;
@ -264,9 +264,9 @@ Status CreateXlaArgs(const Graph& graph,
for (const Node* node : arg_nodes) {
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
return Status::OK();


@ -66,9 +66,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
int num_constant_args, num_resource_args;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
TF_RETURN_IF_ERROR(
GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
GetNodeAttr(node->def(), kXlaNumResourceArgsAttr, &num_resource_args));
if (num_constant_args < 0 || num_resource_args < 0 ||
num_constant_args + num_resource_args > node->num_inputs()) {
@ -88,7 +88,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
Node* launch_node;
TF_RETURN_IF_ERROR(BuildLaunchNode(
graph->NewName(node->name()), node->type_string(), node->def().attr(),
node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
node->def().device(), const_dtypes, num_resource_args, arg_dtypes,
node->output_types(), graph, &launch_node));
launch_node->set_assigned_device_name(node->assigned_device_name());
@ -173,8 +173,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
FunctionLibraryRuntime::Handle handle;
// If ndef is not instantiable, e.g., the function does not exist,
// simply bail out.
TF_RETURN_IF_ERROR(
flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK(fbody); // Can't be nullptr since we just instantiated it.
std::vector<bool> const_args(fbody->arg_types.size());


@ -165,7 +165,7 @@ static const char* const kRetValOp = "_Retval";
// none.
string Encapsulator::GetFunctionNameAttr(Node const* node) const {
string attr;
if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) {
if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) {
attr.clear();
}
return attr;
@ -195,7 +195,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
// Check the device matches any existing device.
string device = node->assigned_device_name().empty()
? node->requested_device()
? node->def().device()
: node->assigned_device_name();
if (subgraph.device.empty()) {
@ -593,7 +593,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
for (Node* n : graph.nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
if (index < 0 || index >= types->size()) {
return errors::InvalidArgument("Invalid argument number");
}
@ -610,7 +610,7 @@ static Status RenumberArguments(Graph* graph,
for (Node* n : graph->nodes()) {
if (n->type_string() == kArgOp) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
if (index < 0 || index >= permutation.size()) {
return errors::InvalidArgument("Invalid argument number");
}
@ -713,7 +713,7 @@ Status EncapsulateSubgraphsPass::Run(
bool IsXlaCompiledKernel(const Node& node) {
bool is_compiled = false;
bool has_compilation_attr =
GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
is_compiled;
return has_compilation_attr ? is_compiled : false;
}


@ -126,8 +126,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
if (node->type_string() == kArgOp) {
int index;
DataType type;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index));
while (fdef->signature().input_arg_size() <= index) {
fdef->mutable_signature()->add_input_arg();
}
@ -143,8 +143,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
if (node->type_string() == kRetValOp) {
int index;
DataType type;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type));
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index));
while (fdef->signature().output_arg_size() <= index) {
fdef->mutable_signature()->add_output_arg();
}
@ -161,7 +161,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
}
NodeDef* node_def = fdef->add_node_def();
*node_def = node->def();
node_def->CopyFrom(node->def());
node_def->set_name(node_names.Uniquify(node->name()));
// Reset input names based on graph rather than the NodeDef.
@ -203,8 +203,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
// Populate tensor_renaming.
NameRangeMap output_ranges;
TF_RETURN_IF_ERROR(
NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr,
&output_ranges));
for (const auto& output : output_ranges) {
for (int i = output.second.first; i < output.second.second; ++i) {
const string tensor_name = strings::StrCat(


@ -55,7 +55,7 @@ XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx)
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
function_ = *func;
VLOG(1) << "XlaDeviceLaunch created function="
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
<< Canonicalize(function_.name(), function_.attr());
DataTypeVector constant_types;
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
num_constant_args_ = constant_types.size();
@ -81,7 +81,7 @@ std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaDeviceLaunch::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
<< Canonicalize(function_.name(), function_.attr());
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();


@ -186,7 +186,7 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOp::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
<< Canonicalize(function_.name(), function_.attr());
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
ResourceMgr* rm = ctx->resource_manager();


@ -56,18 +56,18 @@ bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime);
// Tests whether 'while_node' is a completely compilable loop.
// Tests whether 'while_def' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
bool IsCompilableWhile(const NodeDef& while_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
VLOG(2) << "Loop marking: " << while_node.type_string();
VLOG(2) << "Loop marking: " << while_def.op();
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
status = GetNodeAttr(while_def, "cond", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'cond' attribute on While node.";
return false;
@ -80,7 +80,7 @@ bool IsCompilableWhile(const Node& while_node,
VLOG(2) << "Can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
status = GetNodeAttr(while_def, "body", &name_attr);
if (!status.ok()) {
VLOG(2) << "Missing 'body' attribute on While node.";
return false;
@ -112,7 +112,7 @@ bool IsCompilableCall(const NodeDef& call_def,
FunctionLibraryRuntime::Handle handle;
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle);
if (!status.ok()) {
VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
return false;
@ -134,11 +134,11 @@ bool IsCompilableCall(const NodeDef& call_def,
for (Node* node : fbody->graph->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
continue;
if (node->type_string() == "While") {
if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue;
if (node->def().op() == "While") {
// Handle functional While loop (not in open source build).
return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime);
return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
lib_runtime);
}
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
@ -192,16 +192,17 @@ Status FindCompilationCandidates(
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
<< ": " << node->type_string();
<< ": " << node->def().op();
continue;
}
if (!registration->compile_resource_ops && HasResourceArgument(*node)) {
VLOG(2) << "Compilation rejected node: resource argument " << node->name()
<< ": " << node->type_string();
<< ": " << node->def().op();
continue;
}
if (node->type_string() == "While" &&
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) {
if (node->def().op() == "While" &&
!IsCompilableWhile(node->def(), jit_device_type, 0,
lib_runtime.get())) {
continue;
}
candidates->insert(node);
@ -318,10 +319,10 @@ Status MarkForCompilationPass::Run(
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile);
if (status.ok()) return compile;
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile);
if (status.ok()) return compile;
// Otherwise use the value of global_jit_level.
@ -484,8 +485,8 @@ Status MarkForCompilationPass::RunImpl(
// all nodes marked with _XlaCompile=true to also have a
// _XlaScope property set (and raise an error otherwise); but
// for now we don't do this.
if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
if (GetNodeAttr(node_from->def(), kXlaScopeAttr, &from_scope).ok() &&
GetNodeAttr(node_to->def(), kXlaScopeAttr, &to_scope).ok() &&
from_scope != to_scope) {
continue;
}
@ -540,9 +541,10 @@ Status MarkForCompilationPass::RunImpl(
// Compile if the user marked this node _XlaCompile=true
bool compile_attr = false;
bool marked_for_compilation = false;
if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) {
marked_for_compilation = compile_attr;
} else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
} else if (options.flib_def
->GetAttr(n->def(), kXlaCompileAttr, &compile_attr)
.ok()) {
marked_for_compilation = compile_attr;
}


@ -57,7 +57,7 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
string cluster;
if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node->name()] = cluster;
}


@ -95,7 +95,7 @@ Status XlaCompilationCache::BuildSignature(
const NameAttrList& function, int num_constant_args,
const std::vector<OptionalTensor>& variable_args, OpKernelContext* ctx,
Signature* signature) {
signature->name = Canonicalize(function.name(), AttrSlice(&function.attr()));
signature->name = Canonicalize(function.name(), function.attr());
signature->arg_values.resize(num_constant_args);
signature->arg_types.reserve(ctx->num_inputs() - num_constant_args);


@ -108,7 +108,7 @@ Status BackwardsConstAnalysis(const Graph& g,
if (must_be_const.find(node) != must_be_const.end()) {
if (node->type_string() == "_Arg") {
int index;
status = GetNodeAttr(node->attrs(), "index", &index);
status = GetNodeAttr(node->def(), "index", &index);
if (!status.ok()) return;
compile_time_const_args->at(index) = true;
return;
@ -124,8 +124,8 @@ Status BackwardsConstAnalysis(const Graph& g,
if (range.first == range.second) return;
NameRangeMap input_name_ranges;
status =
NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr);
status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges,
nullptr);
if (!status.ok()) return;
for (auto it = range.first; it != range.second; ++it) {


@ -68,8 +68,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
done);
OP_REQUIRES_OK_ASYNC(
ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_),
done);
ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done);
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();


@ -106,8 +106,7 @@ Status XlaCompiler::CompileFunction(
const XlaCompiler::CompileOptions& options, const NameAttrList& function,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult* result) {
const string function_id =
Canonicalize(function.name(), AttrSlice(&function.attr()));
const string function_id = Canonicalize(function.name(), function.attr());
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
auto it = cache_.find({function_id, args});
@ -117,8 +116,8 @@ Status XlaCompiler::CompileFunction(
}
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flib_runtime_->Instantiate(
function.name(), AttrSlice(&function.attr()), &handle));
TF_RETURN_IF_ERROR(
flib_runtime_->Instantiate(function.name(), function.attr(), &handle));
const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle);
CHECK(fbody);


@ -1,4 +1,3 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@ -168,6 +167,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleReverse(HloInstruction* reverse,
HloInstruction* operand) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleDynamicSlice(HloInstruction* slice, HloInstruction* operand,
HloInstruction* start_indices) override;
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* operand,
HloInstruction* update,
@ -1025,6 +1026,15 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
HloInstruction* dynamic_slice, HloInstruction* operand,
HloInstruction* start_indices) {
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
return ReplaceInstruction(dynamic_slice, operand);
}
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice, HloInstruction* operand,
HloInstruction* update, HloInstruction* start_indices) {


@ -196,9 +196,9 @@ class DfsHloVisitor {
tensorflow::StringPiece custom_call_target) = 0;
virtual Status HandleSlice(HloInstruction* slice,
HloInstruction* operand) = 0;
virtual Status HandleDynamicSlice(
HloInstruction* slice,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
virtual Status HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* operand,
HloInstruction* start_indices) = 0;
virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* operand,
HloInstruction* update,


@ -134,10 +134,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
HloInstruction* /*operand*/) override {
return DefaultAction(slice);
}
Status HandleDynamicSlice(
HloInstruction* slice,
tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override {
return DefaultAction(slice);
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* /*operand*/,
HloInstruction* /*start_indices*/) override {
return DefaultAction(dynamic_slice);
}
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* /*operand*/,


@ -136,9 +136,9 @@ Status HloCostAnalysis::HandleSlice(HloInstruction* slice,
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicSlice(
HloInstruction* slice,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
Status HloCostAnalysis::HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* operand,
HloInstruction* start_indices) {
return Status::OK();
}


@ -89,9 +89,9 @@ class HloCostAnalysis : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleDynamicSlice(
HloInstruction* slice,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* operand,
HloInstruction* start_indices) override;
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* operand,
HloInstruction* update,


@ -1782,7 +1782,7 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
case HloOpcode::kSlice:
return visitor->HandleSlice(this, operands_[0]);
case HloOpcode::kDynamicSlice:
return visitor->HandleDynamicSlice(this, operands_);
return visitor->HandleDynamicSlice(this, operands_[0], operands_[1]);
case HloOpcode::kDynamicUpdateSlice:
return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1],
operands_[2]);


@ -20,6 +20,7 @@ py_library(
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
"//tensorflow/contrib/data",
"//tensorflow/contrib/deprecated:deprecated_py",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/factorization:factorization_py",


@ -25,6 +25,7 @@ from tensorflow.contrib import compiler
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
from tensorflow.contrib import cudnn_rnn
from tensorflow.contrib import data
from tensorflow.contrib import deprecated
from tensorflow.contrib import distributions
from tensorflow.contrib import factorization


@ -14,6 +14,23 @@ For prebuilt libraries, see the
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
page for a recent build.
The TensorFlow Inference Interface is also available as a
[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and
can be included quite simply in your android project with a couple of lines in
the project's `build.gradle` file:
```
allprojects {
    repositories {
        jcenter()
    }
}

dependencies {
    compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
}
```
To build the libraries yourself (if, for example, you want to support custom
TensorFlow operators), pick your preferred approach below:


@ -18,6 +18,7 @@ set(tf_op_lib_names
"control_flow_ops"
"ctc_ops"
"data_flow_ops"
"dataset_ops"
"functional_ops"
"image_ops"
"io_ops"


@ -265,6 +265,11 @@ add_python_module("tensorflow/contrib/cudnn_rnn/ops")
add_python_module("tensorflow/contrib/cudnn_rnn/python")
add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests")
add_python_module("tensorflow/contrib/cudnn_rnn/python/ops")
add_python_module("tensorflow/contrib/data")
add_python_module("tensorflow/contrib/data/python")
add_python_module("tensorflow/contrib/data/python/framework")
add_python_module("tensorflow/contrib/data/python/kernel_tests")
add_python_module("tensorflow/contrib/data/python/ops")
add_python_module("tensorflow/contrib/deprecated")
add_python_module("tensorflow/contrib/distributions")
add_python_module("tensorflow/contrib/distributions/python")
@ -592,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("control_flow_ops"
ADDITIONAL_LIBRARIES $<TARGET_OBJECTS:tf_no_op>)
GENERATE_PYTHON_OP_LIB("ctc_ops")
GENERATE_PYTHON_OP_LIB("data_flow_ops")
GENERATE_PYTHON_OP_LIB("dataset_ops")
GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")


@ -144,6 +144,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/training/*_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py"
@ -204,6 +205,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.


@ -0,0 +1,27 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "data",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/python:util",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -0,0 +1,630 @@
# Using the `Dataset` API for TensorFlow Input Pipelines
The `Dataset` API is designed to let you build complex input pipelines from
simple, reusable pieces. For example, the pipeline for an image model might
aggregate data from files in a distributed file system, apply random
perturbations to each image, and merge randomly selected images into a batch
for training. The pipeline for a text model might involve extracting symbols
from raw text data, converting them to embedding identifiers with a lookup
table, and batching together sequences of different lengths. The `Dataset` API
makes it easy to deal with large amounts of data, different data formats, and
complicated transformations.
The `Dataset` API introduces two new abstractions to TensorFlow:
* A `tf.contrib.data.Dataset` represents a sequence of elements, in which
each element contains one or more `Tensor` objects. For example, in an image
pipeline, an element might be a single training example, with a pair of
tensors representing the image data and a label. A `Dataset` can either be a
*source* (e.g. `Dataset.from_tensor_slices()` constructs a dataset from one
or more `tf.Tensor` objects), or a *transformation* (e.g. `Dataset.batch()`
constructs a dataset by stacking consecutive elements of another dataset into
a single element).
* A `tf.contrib.data.Iterator` provides the main way to extract elements from a
dataset. The `Iterator.get_next()` operation yields the next element of a
`Dataset`, and typically acts as the interface between input pipeline code and
your model. The simplest iterator is a "one-shot iterator", which is
associated with a particular `Dataset` and iterates through it once. For more
sophisticated uses, the `Iterator.initializer` operation enables you to
reinitialize and parameterize an iterator with different datasets, so that
you can, for example, iterate over training and validation data multiple times
in the same program.
## Tutorial
This programmers' guide includes step-by-step instructions for a variety of
input data use cases. Also see the `Dataset` and `Iterator` class references
for more detailed information about the API.
### Basic mechanics
This section of the guide describes the fundamentals of creating different kinds
of `Dataset` and `Iterator` objects, and how to extract data from them.
#### Defining a source dataset
You can build a `Dataset` using one of the following *source* dataset
constructors:
* From in-memory data:
* `tf.contrib.data.Dataset.from_tensors()`
* `tf.contrib.data.Dataset.from_tensor_slices()`
* From on-disk data:
* `tf.contrib.data.FixedLengthRecordDataset()`
* `tf.contrib.data.TextLineDataset()`
* `tf.contrib.data.TFRecordDataset()`
* From parameters:
* `tf.contrib.data.Dataset.range()`
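As a quick illustration (a minimal sketch, not part of the API reference, assuming as elsewhere in this guide that TensorFlow has been imported as `tf`), the two in-memory constructors differ in how they interpret their argument: `Dataset.from_tensors()` produces a dataset containing a single element, whereas `Dataset.from_tensor_slices()` slices its argument along the first dimension to produce one element per slice.
```python
# A single constant with shape (3, 2).
t = tf.constant([[1, 2], [3, 4], [5, 6]])

# One element whose shape is (3, 2).
as_single_element = tf.contrib.data.Dataset.from_tensors(t)
print(as_single_element.output_shapes)  # ==> "(3, 2)"

# Three elements, each with shape (2,).
as_slices = tf.contrib.data.Dataset.from_tensor_slices(t)
print(as_slices.output_shapes)  # ==> "(2,)"
```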
#### Transforming a dataset
The `tf.contrib.data.Dataset` class has many methods that can be chained
together to *transform* one dataset into another:
* Per-element transformations:
* `Dataset.filter()`
* `Dataset.flat_map()`
* `Dataset.map()`
* `Dataset.zip()`
* Multi-element transformations:
* `Dataset.batch()`
* `Dataset.dense_to_sparse_batch()`
* `Dataset.group_by_window()`
* `Dataset.padded_batch()`
* `Dataset.repeat()`
* `Dataset.shuffle()`
* `Dataset.skip()`
* `Dataset.take()`
The following sections contain examples of how to use these transformations to
solve common problems.
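As a preview, here is a minimal sketch that chains one per-element and two multi-element transformations (it assumes, as elsewhere in this guide, that a `tf.Session` named `sess` is available):
```python
dataset = tf.contrib.data.Dataset.range(10)
dataset = dataset.map(lambda x: x * x)  # per-element: square each value
dataset = dataset.batch(5)              # multi-element: stack 5 values per element
dataset = dataset.repeat(2)             # multi-element: iterate over the data twice

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element))  # ==> "[0, 1, 4, 9, 16]"
```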
#### Dataset structure
A dataset comprises elements that each have the same structure. An element
contains one or more `tf.Tensor` objects, called *components*. Each component
has a `tf.DType` representing the type of elements in the tensor, and a
`tf.TensorShape` representing the (possibly partially specified) static shape of
each element. The `Dataset.output_types` and `Dataset.output_shapes` properties
allow you to inspect the inferred types and shapes of each component of a
dataset element. The *nested structure* of these properties maps to the structure
of an element, which may be a single tensor, a tuple of tensors, or a nested
tuple of tensors. For example:
```python
dataset1 = tf.contrib.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.contrib.data.Dataset.from_tensor_slices(
    (tf.random_uniform([4]), tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.contrib.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes)  # ==> "(10, ((), (100,)))"
```
The `Dataset` transformations support datasets of any structure. When using the
`Dataset.map()`, `Dataset.flat_map()` and `Dataset.filter()` transformations,
which apply a function to each element, the element structure determines the
arguments of the function:
```python
dataset1 = dataset1.map(lambda x: ...)
dataset2 = dataset2.flat_map(lambda x, y: ...)
# *N.B.* Lambda argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)
```
#### Creating an iterator
Once you have built a `Dataset` to represent your input data, the next step is to
create an `Iterator` to access elements from that dataset. The `Dataset` API
currently supports three kinds of iterator, in increasing level of
sophistication:
A *one-shot* iterator is the simplest form of iterator, which only supports
iterating once through a dataset, with no need for explicit initialization.
One-shot iterators handle almost all of the cases that the existing queue-based
input pipelines support, but they do not support parameterization. Using the
example of `Dataset.range()`:
```python
dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
  value = sess.run(next_element)
  assert i == value
```
An *initializable* iterator requires you to run an explicit
`iterator.initializer` operation before using it. In exchange for this
inconvenience, it enables you to *parameterize* the definition of the dataset,
using one or more `tf.placeholder()` tensors that can be fed when you
initialize the iterator. Continuing the `Dataset.range()` example:
```python
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.contrib.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value
```
A *reinitializable* iterator can be initialized from multiple different
`Dataset` objects. For example, you might have a training input pipeline that
uses random perturbations to the input images to improve generalization, and
a validation input pipeline that evaluates predictions on unmodified data. These
pipelines will typically use different `Dataset` objects that have the same
structure (i.e. the same types and compatible shapes for each component).
```python
training_dataset = tf.contrib.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.contrib.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)
```
#### Consuming values from an iterator
The `Iterator.get_next()` method returns one or more `tf.Tensor` objects that
correspond to the symbolic next element of an iterator. Each time these tensors
are evaluated, they take the value of the next element in the underlying
dataset. (Note that, like other stateful objects in TensorFlow, calling
`Iterator.get_next()` does not immediately advance the iterator. Instead you
must use the returned `tf.Tensor` objects in a TensorFlow expression, and pass
the result of that expression to `tf.Session.run()` to get the next elements and
advance the iterator.)
If the iterator reaches the end of the dataset, executing
the `Iterator.get_next()` operation will raise a `tf.errors.OutOfRangeError`.
After this point the iterator will be in an unusable state, and you must
initialize it again if you want to use it further.
```python
dataset = tf.contrib.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"
```
A common pattern is to wrap the "training loop" in a `try`-`except` block:
```python
sess.run(iterator.initializer)
while True:
  try:
    sess.run(result)
  except tf.errors.OutOfRangeError:
    break
```
If each element of the dataset has a nested structure, the return value of
`Iterator.get_next()` will be one or more `tf.Tensor` objects in the same
nested structure:
```python
dataset1 = tf.contrib.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.contrib.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.contrib.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
```
Note that evaluating *any* of `next1`, `next2`, or `next3` will advance the
iterator for all components. A typical consumer of an iterator will include all
components in a single expression.
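For example, the following minimal sketch fetches all three components in a single `sess.run()` call, so the iterator advances exactly once and the returned values come from the same element:
```python
# Fetching next1, next2 and next3 together advances the iterator once;
# three separate sess.run() calls would advance it three times.
value1, value2, value3 = sess.run([next1, next2, next3])
```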
### Reading input data
#### Consuming NumPy arrays
If all of your input data fit in memory, the simplest way to create a `Dataset`
from them is to convert them to `tf.Tensor` objects and use
`Dataset.from_tensor_slices()`.
```python
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.contrib.data.Dataset.from_tensor_slices((features, labels))
```
Note that the above code snippet will embed the `features` and `labels` arrays
in your TensorFlow graph as constants. This works well for a small dataset, but
wastes memory, and can run into the 2GB limit for the `tf.GraphDef` protocol
buffer.
As an alternative, you can define the `Dataset` in terms of `tf.placeholder()`
tensors, and *feed* the NumPy arrays when you initialize an `Iterator` over the
dataset.
```python
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.contrib.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
```
#### Consuming TFRecord data
The `Dataset` API supports a variety of file formats so that you can process
large datasets that do not fit in memory. The TFRecord file format is a
simple record-oriented binary format that many TensorFlow applications use for
training data. The `tf.contrib.data.TFRecordDataset` class enables you to
stream over the contents of one or more TFRecord files as part of an input
pipeline.
```python
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
```
The `filenames` argument to the `TFRecordDataset` initializer can be a
`tf.Tensor` of strings. Therefore if you have two sets of files for training
and validation purposes, you can use a `tf.placeholder(tf.string)` to represent
the filenames, and initialize an iterator from the appropriate filenames:
```python
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.
# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
```
#### Consuming text data
Many datasets are distributed as one or more text files. The
`tf.contrib.data.TextLineDataset` provides an easy way to extract lines from
one or more text files. Given one or more filenames, a `TextLineDataset` will
produce one string-valued element per line of those files. Like a
`TFRecordDataset`, `TextLineDataset` accepts `filenames` as a `tf.Tensor`, so
you can parameterize it by passing a `tf.placeholder(tf.string)`.
```python
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.contrib.data.TextLineDataset(filenames)
```
By default, a `TextLineDataset` yields *every* line of each file, which may
not be desirable, for example if the file starts with a header line, or contains
comments. These lines can be removed using the `Dataset.skip()` and
`Dataset.filter()` transformations. To apply these transformations to each
file separately, we use `Dataset.flat_map()` to create a nested `Dataset` for
each file.
```python
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
# Use `Dataset.flat_map()` to transform each file separately.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
    lambda filename: (
        tf.contrib.data.TextLineDataset(filename)
        .skip(1)
        .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
```
<!--
TODO(mrry): Add these sections.
#### Consuming from a Python generator
#### Consuming from an index file and images
-->
### Preprocessing data with `Dataset.map()`
The `Dataset.map(f)` transformation produces a new dataset by applying a given
function `f` to each element of the input dataset. It is based on
the
[`map()` function](https://en.wikipedia.org/wiki/Map_(higher-order_function))
that is commonly applied to lists (and other structures) in functional
programming languages. The function `f` takes the `tf.Tensor` objects that
represent a single element in the input, and returns the `tf.Tensor` objects
that will represent a single element in the new dataset. Its implementation uses
standard TensorFlow operations to transform one element into another.
This section covers common examples of how to use `Dataset.map()`.
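As a minimal sketch of the general pattern, the function passed to `Dataset.map()` can also change the structure of each element, for example mapping a single scalar component to a pair of components:
```python
dataset = tf.contrib.data.Dataset.range(3)   # elements: 0, 1, 2
dataset = dataset.map(lambda x: (x, x * x))  # elements: (0, 0), (1, 1), (2, 4)

print(dataset.output_types)   # ==> "(tf.int64, tf.int64)"
print(dataset.output_shapes)  # ==> "((), ())"
```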
#### Parsing `tf.Example` protocol buffer messages
Many input pipelines extract `tf.train.Example` protocol buffer messages from a
TFRecord-format file (written, for example, using
`tf.python_io.TFRecordWriter`). Each `tf.train.Example` record contains one or
more "features", and the input pipeline typically converts these features into
tensors.
```python
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
```
#### Decoding image data and resizing it
When training a neural network on real-world image data, it is often necessary
to convert images of different sizes to a common size, so that they may be
batched into a fixed size.
```python
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
```
#### Applying arbitrary Python logic with `tf.py_func()`
For performance reasons, we encourage you to use TensorFlow operations for
preprocessing your data whenever possible. However, it is sometimes useful to
be able to call upon external Python libraries when parsing your input data,
and you can do this by invoking the `tf.py_func()` operation in a
`Dataset.map()` transformation.
```python
import cv2
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
return image_decoded, label
# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
image_decoded.set_shape([None, None, None])
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype]))
dataset = dataset.map(_resize_function)
```
<!--
TODO(mrry): Add this section.
#### Handling text data with unusual sizes
-->
### Batching dataset elements
#### Simple batching
The simplest form of batching stacks `n` consecutive elements of a dataset into
a single element. The `Dataset.batch()` transformation does exactly this, with
the same constraints as the `tf.stack()` operator, applied to each component
of the elements: i.e. for each component *i*, all elements must have a tensor
of the exact same shape.
```python
inc_dataset = tf.contrib.data.Dataset.range(100)
dec_dataset = tf.contrib.data.Dataset.range(0, -100, -1)
dataset = tf.contrib.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
```
#### Batching tensors with padding
The above recipe works for tensors that all have the same size. However, many
models (e.g. sequence models) work with input data that can have varying size
(e.g. sequences of different lengths). To handle this case, the
`Dataset.padded_batch()` transformation enables you to batch tensors of
different shape by specifying one or more dimensions in which they may be
padded.
```python
dataset = tf.contrib.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
# [5, 5, 5, 5, 5, 0, 0],
# [6, 6, 6, 6, 6, 6, 0],
# [7, 7, 7, 7, 7, 7, 7]]
```
The `Dataset.padded_batch()` transformation allows you to set different padding
for each dimension of each component, and each dimension may be variable-length
(signified by `None` in the example above) or constant-length. It is also
possible to override the padding value, which defaults to 0.
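For example, the following sketch overrides the padding value using the
`padding_values` argument (exercised in the unit tests added in this change);
the padding value must match the structure and dtype of the elements, so an
`int64` constant is used here:
```python
dataset = tf.contrib.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
# Pad the variable-length dimension with -1 instead of the default 0.
dataset = dataset.padded_batch(4, padded_shapes=[None],
                               padding_values=tf.constant(-1, dtype=tf.int64))
```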
<!--
TODO(mrry): Add this section.
#### Dense ragged -> tf.SparseTensor
-->
### Training workflows
#### Processing multiple epochs
The `Dataset` API offers two main ways to process multiple epochs of the same
data.
The simplest way to iterate over a dataset in multiple epochs is to use the
`Dataset.repeat()` transformation. For example, to create a dataset that repeats
its input for 10 epochs:
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
```
Applying the `Dataset.repeat()` transformation with no arguments will repeat
the input indefinitely. Note that `Dataset.repeat()` concatenates the repeated
epochs, without signaling the end of one epoch and the beginning of the next.
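For example, a minimal variant of the pipeline above that repeats its input
indefinitely:
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
```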
If you want to receive a signal at the end of each epoch, you can write a
training loop that catches the `tf.errors.OutOfRangeError` at the end of a
dataset. At that point you might collect some statistics (e.g. the validation
error) for the epoch.
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Compute for 100 epochs.
for _ in range(100):
sess.run(iterator.initializer)
while True:
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
break
# [Perform end-of-epoch calculations here.]
```
#### Randomly shuffling input data
The `Dataset.shuffle()` transformation randomly shuffles the input dataset
using a similar algorithm to `tf.RandomShuffleQueue`: it maintains a fixed-size
buffer and chooses the next element uniformly at random from that buffer.
```python
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
```
View File
@ -0,0 +1,42 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`tf.contrib.data.Dataset` API for input pipelines.
@@Dataset
@@Iterator
@@TFRecordDataset
@@FixedLengthRecordDataset
@@TextLineDataset
@@read_batch_features
@@rejection_resample
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
from tensorflow.contrib.data.python.ops.dataset_ops import Iterator
from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample
from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset
from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
View File
@ -0,0 +1,41 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "function",
srcs = ["function.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework",
],
)
py_test(
name = "function_test",
size = "medium",
srcs = ["function_test.py"],
srcs_version = "PY2AND3",
deps = [
":function",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
View File
@ -0,0 +1,267 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An experimental fork of the Python TensorFlow-function library.
NOTE: functions are currently experimental and subject to change!
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import tf_inspect
# NOTE(mrry): This is an experimental extension of a core class that wasn't
# designed to be extended, so we disable protected access checks for the
# whole file.
# pylint: disable=protected-access
class _ExperimentalFuncGraph(function._FuncGraph):
"""A helper for construction a function (supporting capture-by-value).
_ExperimentalFuncGraph overrides ops.Graph's create_op() so that we can keep
track of every inputs into every op created inside the function. If
any input is from other graphs, we keep track of it in self.capture
and substitue the input with a place holder.
Each captured input's corresponding place holder is converted into a
function argument and the caller passes in the captured tensor.
"""
def __init__(self, capture_by_value, *args, **kwargs):
super(_ExperimentalFuncGraph, self).__init__(*args, **kwargs)
self._capture_by_value = capture_by_value
self._building_function = True
self._outer_graph = ops.get_default_graph()
self._vscope = vs.get_variable_scope()
self._old_custom_getter = self._vscope.custom_getter
self._captured = {}
self.extra_inputs = []
self.extra_args = []
self.extra_vars = []
def create_op(self, op_type, inputs, data_types, **kwargs):
for i, x in enumerate(inputs):
if x.graph is not self:
# Referring to a tensor from other graph.
if x in self._captured:
# Captured already.
inputs[i] = self._captured[x]
elif self._capture_by_value:
inputs[i] = self._add_tensor_and_parents(x)
else:
# Substitute with a placeholder.
self.extra_inputs.append(x)
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
# pylint: disable=protected-access
ph._handle_shape = x._handle_shape
ph._handle_dtype = x._handle_dtype
# pylint: enable=protected-access
inputs[i] = ph
self._captured[x] = ph
self.extra_args.append(ph)
return super(_ExperimentalFuncGraph, self).create_op(op_type, inputs,
data_types, **kwargs)
def _add_tensor_and_parents(self, tensor):
op = self._add_op_and_parents(tensor.op)
return op.outputs[tensor.value_index]
def _add_op_and_parents(self, op):
op_def = function._get_op_def(op)
if op_def.is_stateful:
raise ValueError("Cannot capture a stateful node by value.")
elif op.type in ("Placeholder", "PlaceholderV2"):
raise ValueError("Cannot capture a placeholder by value.")
captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
captured_op = self.create_op(op.type, captured_inputs,
[o.dtype for o in op.outputs],
name=op.name, attrs=op.node_def.attr,
op_def=op_def)
for t, captured_t in zip(op.outputs, captured_op.outputs):
self._captured[t] = captured_t
return captured_op
class _ExperimentalDefinedFunction(function._DefinedFunction):
"""Overrides _DefinedFunction with support for capture-by-value."""
def __init__(self,
func,
argnames,
input_types,
func_name=None,
grad_func=None,
python_grad_func=None,
out_names=None,
shape_func=None,
capture_by_value=False,
**kwargs):
"""Creates an _ExperimentalDefinedFunction.
Args:
func: A python callable which constructs a tf function body.
argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple or list of
        tf data types.
      func_name: The function name. Defaults to None, in which case it is
        derived from 'func'.
grad_func: This function's gradient function, if not None. Defaults
to None.
python_grad_func: A python callable implementing the gradient of
the function python-side.
out_names: An optional list of strings for the function return value
names.
shape_func: An optional function mapping an op to a list of static
output shapes.
capture_by_value: Boolean (defaults to False). If True, captured values
will be copied into the function body.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
Raises:
ValueError: The function definition is invalid.
"""
super(_ExperimentalDefinedFunction, self).__init__(
func, argnames, input_types, func_name, grad_func, python_grad_func,
out_names, shape_func, **kwargs)
self._capture_by_value = capture_by_value
def _create_definition_if_needed(self):
"""Creates the function definition if it's not created yet."""
if self._definition is not None:
return
# Create the func_def object.
temp_graph = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
with temp_graph.as_default():
# List of placeholders for the function_def.
inputs = []
for (argname, argtype) in self._args:
argholder = array_ops.placeholder(argtype, name=argname)
inputs.append(argholder)
# Call func and gather the output tensors.
with vs.variable_scope("", custom_getter=temp_graph.getvar):
outputs = self._func(*inputs)
# If func only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
if any([_ is None for _ in outputs]):
raise ValueError("Function can not return None.")
# Ensures each output is a Tensor.
outputs = [ops.convert_to_tensor(_) for _ in outputs]
self._extra_inputs = temp_graph.extra_inputs
inputs.extend(temp_graph.extra_args)
self._sub_functions = temp_graph._functions
# Build the FunctionDef
self._definition = function._graph_to_function_def(
temp_graph, temp_graph.get_operations(), inputs, outputs,
out_names=self._out_names)
# Extra kwargs are treated as attrs on the function def.
sig_pre_func_name = self._func_name or function._get_func_name(self._func)
kwargs_attr = function._parse_kwargs_as_attrs(
sig_pre_func_name, **self._extra_kwargs)
for k in kwargs_attr:
self._definition.attr[k].CopyFrom(kwargs_attr[k])
# Hash the definition and its dependencies.
self._hash_str = self._create_hash_str(
self._definition.signature.input_arg,
self._definition.signature.output_arg,
self._definition.node_def)
# Finally, we decide the function name to use. If not specified,
# make up something which is almost certainly unique (but deterministic).
if not self._func_name:
self._func_name = "_".join([function._get_func_name(self._func),
self._hash_str])
self._definition.signature.name = self._func_name
if self._func.__doc__:
self._definition.signature.description = self._func.__doc__
class Defun(function.Defun):
"""Experimental version of Defun supporting capture-by-value."""
def __init__(self, *input_types, **kwargs):
"""Create an experimental `Defun` decorator.
Args:
*input_types: A list of `tf.DType`
**kwargs: Optional keyword arguments (see `function.Defun`) plus:
capture_by_value - Boolean (defaults to False). If True, captured values
will be copied into the function body.
"""
super(Defun, self).__init__(*input_types, **kwargs)
def __call__(self, func):
# Various sanity checks on the callable func.
if not callable(func):
raise ValueError("func %s must be callable" % func)
# Func should not use kwargs and defaults.
argspec = tf_inspect.getargspec(func)
if argspec.keywords or argspec.defaults:
raise ValueError("Functions with argument defaults or keyword "
"arguments are not supported.")
# Computes how many arguments 'func' has.
min_args = len(argspec.args)
max_args = min_args
if argspec.varargs:
max_args = 1000000
argnames = argspec.args
if tf_inspect.ismethod(func):
# 1st argument is the "class" type.
min_args -= 1
argnames = argnames[1:]
if self._input_types:
# If Defun is given a list of types for the inputs, the number
# of input types should be compatible with 'func'.
num = len(self._input_types)
if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types is incompatible with the "
            "number of arguments that 'func' accepts.")
return _ExperimentalDefinedFunction(
func, argnames, self._input_types, self._func_name, self._grad_func,
self._python_grad_func, out_names=self._out_names,
**self._extra_kwargs)
# 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0:
return _ExperimentalDefinedFunction(
func, [], [], self._func_name, self._grad_func,
self._python_grad_func, out_names=self._out_names,
**self._extra_kwargs)
# Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called.
return function._OverloadedFunction(
func, argnames, self._func_name, self._grad_func,
self._python_grad_func, out_names=self._out_names, **self._extra_kwargs)
View File
@ -0,0 +1,59 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for experimental capture-by-value feature in TF functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.framework import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class FunctionTest(test.TestCase):
def testCaptureByValue(self):
g = ops.Graph()
with g.as_default():
w = constant_op.constant([[1.0]])
b = constant_op.constant([2.0])
# Foo() captures w and b.
@function.Defun(dtypes.float32, capture_by_value=True)
def Foo(x):
# Plus() captures b.
@function.Defun(dtypes.float32, capture_by_value=True)
def Plus(y):
return y + b
self.assertEqual(0, len(Plus.captured_inputs))
return Plus(math_ops.matmul(w, x))
y = Foo(constant_op.constant([[10.]]))
self.assertEqual(0, len(Foo.captured_inputs))
with self.test_session(graph=g):
self.assertAllEqual(y.eval(), [[12.0]])
if __name__ == "__main__":
test.main()
View File
@ -0,0 +1,219 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "iterator_ops_test",
size = "small",
srcs = ["iterator_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
],
)
py_test(
name = "batch_dataset_op_test",
size = "small",
srcs = ["batch_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:string_ops",
],
)
py_test(
name = "bucketing_test",
size = "small",
srcs = ["bucketing_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:client_testlib",
"//tensorflow/python:string_ops",
],
)
py_test(
name = "dataset_constructor_op_test",
size = "small",
srcs = ["dataset_constructor_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "filter_dataset_op_test",
size = "small",
srcs = ["filter_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "flat_map_dataset_op_test",
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)
py_test(
name = "map_dataset_op_test",
size = "small",
srcs = ["map_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:string_ops",
],
)
py_test(
name = "range_dataset_op_test",
size = "small",
srcs = ["range_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "reader_dataset_ops_test",
size = "small",
srcs = ["reader_dataset_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "resample_test",
size = "medium",
srcs = ["resample_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:string_ops",
"//tensorflow/python:variables",
],
)
py_test(
name = "sequence_dataset_op_test",
size = "small",
srcs = ["sequence_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "shuffle_dataset_op_test",
size = "small",
srcs = ["shuffle_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "zip_dataset_op_test",
size = "small",
srcs = ["zip_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
View File
@ -0,0 +1,276 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class BatchDatasetTest(test.TestCase):
def testBatchDataset(self):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count) -> BatchDataset(batch_size).
components = [np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
count = array_ops.placeholder(dtypes.int64, shape=[])
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
.repeat(count).batch(batch_size).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
with self.test_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
num_batches = (28 * 7) // 14
for i in range(num_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(14):
self.assertAllEqual(component[(i*14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Batch of a finite input, where the batch_size does not
# divide the total number of elements.
sess.run(init_op, feed_dict={count: 14, batch_size: 8})
# We expect (num_batches - 1) full-sized batches.
num_batches = int(math.ceil((14 * 7) / 8))
for i in range(num_batches - 1):
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range(8):
self.assertAllEqual(component[(i*8 + j) % 7]**2,
result_component[j])
result = sess.run(get_next)
for component, result_component in zip(components, result):
for j in range((14 * 7) % 8):
self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Batch of an empty input should fail straight away.
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Empty batch should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
def testPaddedBatchDataset(self):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
.map(lambda x: array_ops.fill([x], x)).padded_batch(
4,
padded_shapes=padded_shape).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(init_op, feed_dict={padded_shape: [-1],
seq_lens: random_seq_lens})
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result)
self.assertEqual((4, padded_len), result.shape)
for j in range(4):
seq_len = random_seq_lens[(i*4)+j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test with random sequence lengths, and constant padding.
sess.run(init_op, feed_dict={padded_shape: [25],
seq_lens: random_seq_lens})
for i in range(8):
result = sess.run(get_next)
self.assertEqual((4, 25), result.shape)
for j in range(4):
seq_len = random_seq_lens[(i*4)+j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test correct handling of empty tensors.
sess.run(init_op, feed_dict={padded_shape: [-1],
seq_lens: [0, 0, 0, 0]})
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test error handling with constant sequence lengths, and
# too-short padding.
sess.run(init_op, feed_dict={padded_shape: [5],
seq_lens: [6, 5, 5, 5]})
with self.assertRaises(errors.DataLossError):
result = sess.run(get_next)
def testPaddedBatchDatasetNonDefaultPadding(self):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
def fill_tuple(x):
filled = array_ops.fill([x], x)
return (filled, string_ops.as_string(filled))
iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
.padded_batch(
4,
padded_shapes=(padded_shape, padded_shape),
padding_values=(-1, "<end>")).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
# Test with random sequence lengths, and max padding.
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(init_op, feed_dict={padded_shape: [-1],
seq_lens: random_seq_lens})
for i in range(8):
result = sess.run(get_next)
padded_len = np.max(result[0])
self.assertEqual((4, padded_len), result[0].shape)
self.assertEqual((4, padded_len), result[1].shape)
for j in range(4):
seq_len = random_seq_lens[(i*4)+j]
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
self.assertAllEqual(result[0][j, seq_len:],
[-1] * (padded_len - seq_len))
self.assertAllEqual(result[1][j, :seq_len],
[compat.as_bytes(str(seq_len))] * seq_len)
self.assertAllEqual(result[1][j, seq_len:],
[b"<end>"] * (padded_len - seq_len))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testPaddedBatchDatasetShapeSpecifications(self):
int_placeholder = array_ops.placeholder(dtypes.int32)
float_placeholder = array_ops.placeholder(dtypes.float32)
string_placeholder = array_ops.placeholder(dtypes.string)
input_dataset = dataset_ops.Dataset.from_tensors(
(int_placeholder, float_placeholder, string_placeholder))
# Test different ways of specifying the `padded_shapes` argument.
dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
32,
padded_shapes=(tensor_shape.TensorShape([None]),
tensor_shape.TensorShape([None, None]),
tensor_shape.TensorShape([37])))
dynamic_padding_from_lists = input_dataset.padded_batch(
32, padded_shapes=([None], [None, None], [37]))
dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
32, padded_shapes=([-1], [-1, -1], [37]))
dynamic_padding_from_tensors = input_dataset.padded_batch(
32,
padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
constant_op.constant([-1, -1], dtype=dtypes.int64),
constant_op.constant([37], dtype=dtypes.int64)))
for dataset in [dynamic_padding_from_tensor_shapes,
dynamic_padding_from_lists,
dynamic_padding_from_lists_with_minus_one,
dynamic_padding_from_tensors]:
self.assertEqual([None, None], dataset.output_shapes[0].as_list())
self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([x], x)).dense_to_sparse_batch(
4, [12]).make_initializable_iterator())
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
with self.test_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
results = sess.run(get_next)
self.assertAllEqual(
[[i, j] for i, c in enumerate(components[start:start+4])
for j in range(c)], results.indices)
self.assertAllEqual(
[c for c in components[start:start+4] for _ in range(c)],
results.values)
self.assertAllEqual(
[min(4, len(components) - start), 12], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
iterator = (dataset_ops.Dataset.from_tensors(input_tensor)
.dense_to_sparse_batch(4, [12]).make_initializable_iterator())
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
with self.test_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"incompatible with the row shape"):
sess.run(get_next)
# Initialize with an input tensor that is larger than `row_shape`.
sess.run(init_op, feed_dict={input_tensor: range(13)})
with self.assertRaisesRegexp(errors.DataLossError,
"larger than the row shape"):
sess.run(get_next)
if __name__ == "__main__":
test.main()
View File
@ -0,0 +1,292 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class BucketingTest(test.TestCase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
result = sess.run(get_next)
          self.assertTrue(
              all(x % 2 == 0 for x in result) or
              all(x % 2 == 1 for x in result))
counts.append(result.shape[0])
self.assertEqual(len(components), sum(counts))
num_full_batches = len([c for c in counts if c == 4])
self.assertGreaterEqual(num_full_batches, 23)
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
def testImmediateOutput(self):
components = np.array(
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1)
.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for _ in range(3):
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
def testSmallGroups(self):
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
# The small outputs at the end are deterministically produced in key
# order.
self.assertAllEqual([0, 0, 0], sess.run(get_next))
self.assertAllEqual([1], sess.run(get_next))
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
def reduce_func(_, xs):
# Introduce an incorrect padded shape that cannot (currently) be
# detected at graph construction time.
return xs.padded_batch(
4,
padded_shapes=(tensor_shape.TensorShape([]),
constant_op.constant([5], dtype=dtypes.int64) * -1))
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: (x, ops.convert_to_tensor([x * x])))
.group_by_window(lambda x, _: x % 2, reduce_func, 32))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64)
def reduce_func(key, window):
# Apply two different kinds of padding to the input: tight
# padding, and quantized (to a multiple of 10) padding.
return dataset_ops.Dataset.zip((window.padded_batch(
4,
padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch(
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),))
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
.group_by_window(
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
reduce_func, 4))
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
while True:
tight_result, multiple_of_10_result = sess.run(get_next)
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
self.assertAllEqual(tight_result,
multiple_of_10_result[:, :tight_result.shape[1]])
counts.append(tight_result.shape[0])
self.assertEqual(len(components), sum(counts))
# NOTE(mrry): These tests are based on the tests in
# bucket_ops_test.py. Currently, different batch sizes for each key
# are not supported, although this would be possible to add to
# `Dataset.group_by_window()`.
class BucketTest(test.TestCase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
# generic form of padded_batch that pads every component
# dynamically and does not rely on static shape information about
# the arguments.
return dataset_ops.Dataset.zip(
(dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape([None]),
tensor_shape.TensorShape([3])))))
def testSingleBucket(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32),
32)
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
self.assertEqual(0, which_bucket)
expected_scalar_int = np.arange(32, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
for i in range(32):
expected_unk_int64[i, :i] = i
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values[0])
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
def testEvenOddBuckets(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even, bucketed_values_even = sess.run(get_next)
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
# Count number of bucket_tensors.
self.assertEqual(3, len(bucketed_values_even))
self.assertEqual(3, len(bucketed_values_odd))
# Ensure bucket 0 was used for all minibatch entries.
self.assertAllEqual(0, which_bucket_even)
self.assertAllEqual(1, which_bucket_odd)
      # Test the first bucket that is output: the evens, starting at 0.
expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
for i in range(0, 32):
expected_unk_int64[i, :2 * i] = 2 * i
expected_vec3_str = np.vstack(
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
      # Test the second bucket that is output: the odds, starting at 1.
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
for i in range(0, 32):
expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
expected_vec3_str = np.vstack(
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
def testEvenOddBucketsFilterOutAllOdd(self):
def _map_fn(v):
return (v, array_ops.fill([v], v),
array_ops.fill([3], string_ops.as_string(v)))
input_dataset = (
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
.filter(lambda x, y, z: math_ops.equal(x % 2, 0)))
bucketed_dataset = input_dataset.group_by_window(
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0, bucketed_values_even0 = sess.run(get_next)
which_bucket1, bucketed_values_even1 = sess.run(get_next)
# Ensure that bucket 1 was completely filtered out
self.assertAllEqual(0, which_bucket0)
self.assertAllEqual(0, which_bucket1)
self.assertAllEqual(
np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0[0])
self.assertAllEqual(
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1[0])
if __name__ == "__main__":
test.main()
View File
@ -0,0 +1,239 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
class DatasetConstructorTest(test.TestCase):
def testTensorDataset(self):
"""Test an dataset that represents a single tuple of tensors."""
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
iterator = (dataset_ops.Dataset.from_tensors(components)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testTensorSliceDataset(self):
"""Test an dataset that represents the slices from a tuple of tensors."""
components = [
np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0])
]
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
sess.run(init_op)
for i in range(4):
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSparseTensorSliceDataset(self):
"""Test a dataset based on slices of a `tf.SparseTensor`."""
st = array_ops.sparse_placeholder(dtypes.float64)
iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
with self.test_session() as sess:
slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
# Test with sparse tensor in the appropriate order.
indices = np.array(
[[i, j] for i in range(len(slices)) for j in range(len(slices[i]))])
values = np.array([val for s in slices for val in s])
dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1])
sparse_feed = sparse_tensor.SparseTensorValue(indices, values,
dense_shape)
sess.run(init_op, feed_dict={st: sparse_feed})
for i, s in enumerate(slices):
results = sess.run(get_next)
self.assertAllEqual(s, results.values)
expected_indices = np.array(
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
self.assertAllEqual(expected_indices, results.indices)
self.assertAllEqual(dense_shape[1:], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test with sparse tensor in the reverse order, which is not
# currently supported.
reverse_order_indices = indices[::-1, :]
reverse_order_values = values[::-1]
sparse_feed = sparse_tensor.SparseTensorValue(
reverse_order_indices, reverse_order_values, dense_shape)
with self.assertRaises(errors.UnimplementedError):
sess.run(init_op, feed_dict={st: sparse_feed})
# Test with an empty sparse tensor.
empty_indices = np.empty((0, 4), dtype=np.int64)
empty_values = np.empty((0,), dtype=np.float64)
empty_dense_shape = [0, 4, 37, 9]
sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
empty_dense_shape)
sess.run(init_op, feed_dict={st: sparse_feed})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# pylint: disable=g-long-lambda,unnecessary-lambda
def testNestedStructure(self):
components = (np.array([1, 2, 3]), (np.array([4., 5.]), np.array([6., 7.])),
np.array([8, 9, 10]))
dataset = dataset_ops.Dataset.from_tensors(components)
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
dtypes.int64), dataset.output_types)
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
dataset = dataset.shuffle(10, 10)
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
dtypes.int64), dataset.output_types)
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
dataset = dataset.repeat(-1)
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
dtypes.int64), dataset.output_types)
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
dataset = dataset.filter(lambda x, y, z: True)
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
dtypes.int64), dataset.output_types)
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
dataset = dataset.take(5)
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
dtypes.int64), dataset.output_types)
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
self.assertEquals(((dtypes.int64, dtypes.int64),
(dtypes.float64, dtypes.float64)), dataset.output_types)
self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
dataset = dataset.flat_map(
lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]),
(y[0], y[1])))
)
self.assertEquals(((dtypes.int64, dtypes.int64),
(dtypes.float64, dtypes.float64)), dataset.output_types)
self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
dataset = dataset.batch(32)
self.assertEquals(((dtypes.int64, dtypes.int64),
(dtypes.float64, dtypes.float64)), dataset.output_types)
self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])),
nest.pack_sequence_as(dataset.output_shapes, [
s.as_list()
for s in nest.flatten(dataset.output_shapes)
]))
iterator = dataset.make_one_shot_iterator()
(w, x), (y, z) = iterator.get_next()
self.assertEquals(dtypes.int64, w.dtype)
self.assertEquals(dtypes.int64, x.dtype)
self.assertEquals(dtypes.float64, y.dtype)
self.assertEquals(dtypes.float64, z.dtype)
self.assertEquals([None, 3], w.shape.as_list())
self.assertEquals([None, 3], x.shape.as_list())
self.assertEquals([None, 2], y.shape.as_list())
self.assertEquals([None, 2], z.shape.as_list())
iterator = dataset.make_initializable_iterator()
(w, x), (y, z) = iterator.get_next()
self.assertEquals(dtypes.int64, w.dtype)
self.assertEquals(dtypes.int64, x.dtype)
self.assertEquals(dtypes.float64, y.dtype)
self.assertEquals(dtypes.float64, z.dtype)
self.assertEquals([None, 3], w.shape.as_list())
self.assertEquals([None, 3], x.shape.as_list())
self.assertEquals([None, 2], y.shape.as_list())
self.assertEquals([None, 2], z.shape.as_list())
# Define a separate set of components with matching leading
# dimension for the from-slices constructor.
components_for_slices = (np.array([1, 2, 3]), (np.array(
[4., 5., 6.]), np.array([7., 8., 9.])), np.array([10, 11, 12]))
dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
dtypes.int64), dataset.output_types)
self.assertEquals(([], ([], []), []), dataset.output_shapes)
def testNonSequenceNestedStructure(self):
components = np.array([1, 2, 3])
dataset = dataset_ops.Dataset.from_tensors(components)
self.assertEquals(dtypes.int64, dataset.output_types)
self.assertEquals([3], dataset.output_shapes)
dataset = dataset.filter(
lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
self.assertEquals(dtypes.int64, dataset.output_types)
self.assertEquals([3], dataset.output_shapes)
dataset = dataset.map(lambda x: array_ops.stack([x, x]))
self.assertEquals(dtypes.int64, dataset.output_types)
self.assertEquals([2, 3], dataset.output_shapes)
dataset = dataset.flat_map(
lambda x: dataset_ops.Dataset.from_tensor_slices(x))
self.assertEquals(dtypes.int64, dataset.output_types)
self.assertEquals([3], dataset.output_shapes)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
self.assertEquals(dtypes.int64, get_next.dtype)
self.assertEquals([3], get_next.shape)
if __name__ == "__main__":
test.main()
View File
@ -0,0 +1,77 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class FilterDatasetTest(test.TestCase):
def testFilterDataset(self):
components = [
np.arange(7, dtype=np.int64),
np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
7, dtype=np.int64)[:, np.newaxis],
np.array(37.0, dtype=np.float64) * np.arange(7)
]
count = array_ops.placeholder(dtypes.int64, shape=[])
modulus = array_ops.placeholder(dtypes.int64)
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
.repeat(count)
.filter(lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
# Test that we can dynamically feed a different modulus value for each
# iterator.
def do_test(count_val, modulus_val):
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
for _ in range(count_val):
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
do_test(14, 2)
do_test(4, 18)
# Test an empty dataset.
do_test(0, 1)
if __name__ == "__main__":
test.main()
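# A minimal sketch of Dataset.filter outside the test harness, assuming the
# contrib.data API in this change: keep only the even values of range(10).
# Illustrative only; not invoked by the tests.
def _example_filter_usage():
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.ops import math_ops
  dataset = dataset_ops.Dataset.range(10).filter(
      lambda x: math_ops.equal(math_ops.mod(x, 2), 0))
  get_next = dataset.make_one_shot_iterator().get_next()
  with session.Session() as sess:
    return [sess.run(get_next) for _ in range(5)]  # [0, 2, 4, 6, 8]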

View File

@ -0,0 +1,108 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
class FlatMapDatasetTest(test.TestCase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
repeats = [1, 2, 3, 4, 5, 0, 1]
components = [np.array(repeats, dtype=np.int64)]
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
.make_initializable_iterator())
init_op = iterator.initializer
get_next, = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for i in repeats:
for _ in range(i):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testNestedFlatMapDataset(self):
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
components = [np.array(repeats, dtype=np.int64)]
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
.repeat(y))).make_initializable_iterator())
init_op = iterator.initializer
get_next, = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for row in repeats:
for i in row:
for _ in range(i):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSharedResourceNestedFlatMapDataset(self):
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
components = [np.array(repeats, dtype=np.int64)]
iterator = (
dataset_ops.Dataset.from_tensor_slices(components)
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
.repeat(y))).make_initializable_iterator(
shared_name="shared_flat_map_iterator"))
init_op = iterator.initializer
get_next, = iterator.get_next()
# Create two concurrent sessions that share the same iterator
# resource on the same server, and verify that a random
# interleaving of `Session.run(get_next)` calls on the two
# sessions yields the expected result.
server = server_lib.Server.create_local_server()
with session.Session(server.target) as sess1:
with session.Session(server.target) as sess2:
for _ in range(3):
sess = random.choice([sess1, sess2])
sess.run(init_op)
for row in repeats:
for i in row:
for _ in range(i):
sess = random.choice([sess1, sess2])
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess = random.choice([sess1, sess2])
sess.run(get_next)
# pylint: enable=g-long-lambda
if __name__ == "__main__":
test.main()
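# A minimal sketch of Dataset.flat_map outside the test harness, assuming the
# contrib.data API in this change: each value x is expanded into a dataset
# that repeats x, x times. Illustrative only; not invoked by the tests.
def _example_flat_map_usage():
  import numpy as np
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.framework import errors
  repeats = np.array([1, 2, 3], dtype=np.int64)
  dataset = dataset_ops.Dataset.from_tensor_slices(repeats).flat_map(
      lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x))
  get_next = dataset.make_one_shot_iterator().get_next()
  values = []
  with session.Session() as sess:
    try:
      while True:
        values.append(sess.run(get_next))
    except errors.OutOfRangeError:
      pass
  return values  # [1, 2, 2, 3, 3, 3]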

View File

@ -0,0 +1,252 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
class IteratorTest(test.TestCase):
def testOneShotIterator(self):
components = [np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
.repeat(14).make_one_shot_iterator())
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testOneShotIteratorCaptureByValue(self):
components = [np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
tensor_components = [ops.convert_to_tensor(c) for c in components]
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components)
.map(_map_fn).repeat(14).make_one_shot_iterator())
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testOneShotIteratorInsideContainer(self):
components = [np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
def within_container():
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.map(_map_fn).repeat(14).make_one_shot_iterator())
return iterator.get_next()
server = server_lib.Server.create_local_server()
# Create two iterators within unique containers, and run them to
# make sure that the resources aren't shared.
#
# The test below would fail if cname were the same across both
# sessions.
for i in range(2):
with session.Session(server.target) as sess:
cname = "iteration%d" % i
with ops.container(cname):
get_next = within_container()
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSimpleSharedResource(self):
components = [
np.array(1, dtype=np.int64),
np.array([1, 2, 3], dtype=np.int64),
np.array(37.0, dtype=np.float64)
]
server = server_lib.Server.create_local_server()
# Create two non-overlapping sessions that share the same iterator
# resource on the same server, and verify that an action of the
# first session (initializing the iterator) is visible in the
# second session.
with ops.Graph().as_default():
iterator = (dataset_ops.Dataset.from_tensors(components)
.map(lambda x, y, z: (x, y, z)).make_initializable_iterator(
shared_name="shared_iterator"))
init_op = iterator.initializer
get_next = iterator.get_next()
with session.Session(server.target) as sess:
sess.run(init_op)
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Re-initialize the iterator in the first session.
sess.run(init_op)
with ops.Graph().as_default():
# Re-define the iterator manually, without defining any of the
# functions in this graph, to ensure that we are not
# accidentally redefining functions with the same names in the
# new graph.
iterator = dataset_ops.Iterator.from_structure(
shared_name="shared_iterator",
output_types=[dtypes.int64, dtypes.int64, dtypes.float64],
output_shapes=[[], [3], []])
get_next = iterator.get_next()
with session.Session(server.target) as sess:
# Use the iterator without re-initializing in the second session.
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testNotInitializedError(self):
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
iterator = (dataset_ops.Dataset.from_tensors(components)
.make_initializable_iterator())
get_next = iterator.get_next()
with self.test_session() as sess:
with self.assertRaisesRegexp(errors.FailedPreconditionError,
"iterator has not been initialized"):
sess.run(get_next)
def testReinitializableIterator(self):
dataset_3 = dataset_ops.Dataset.from_tensors(
constant_op.constant([1, 2, 3]))
dataset_4 = dataset_ops.Dataset.from_tensors(
constant_op.constant([4, 5, 6, 7]))
iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types,
[None])
dataset_3_init_op = iterator.make_initializer(dataset_3)
dataset_4_init_op = iterator.make_initializer(dataset_4)
get_next = iterator.get_next()
self.assertEqual(dataset_3.output_types, iterator.output_types)
self.assertEqual(dataset_4.output_types, iterator.output_types)
self.assertEqual([None], iterator.output_shapes.as_list())
with self.test_session() as sess:
# The iterator is initially uninitialized.
with self.assertRaises(errors.FailedPreconditionError):
sess.run(get_next)
# Initialize with one dataset.
sess.run(dataset_3_init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Initialize with a different dataset.
sess.run(dataset_4_init_op)
self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Reinitialize with the first dataset.
sess.run(dataset_3_init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testReinitializableIteratorStaticErrors(self):
# Non-matching structure for types and shapes.
with self.assertRaises(TypeError):
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64), [None])
# Test validation of dataset argument.
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
dtypes.float64))
# Incompatible structure.
with self.assertRaises(ValueError):
iterator.make_initializer(
dataset_ops.Dataset.from_tensors(((constant_op.constant(
[1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
[4., 5., 6., 7.], dtype=dtypes.float64),))))
# Incompatible types.
with self.assertRaises(TypeError):
iterator.make_initializer(
dataset_ops.Dataset.from_tensors((constant_op.constant(
[1, 2, 3], dtype=dtypes.int32), constant_op.constant(
[4., 5., 6., 7.], dtype=dtypes.float32))))
# Incompatible shapes.
iterator = dataset_ops.Iterator.from_structure(
(dtypes.int64, dtypes.float64), ([None], []))
with self.assertRaises(TypeError):
iterator.make_initializer(
dataset_ops.Dataset.from_tensors((constant_op.constant(
[1, 2, 3], dtype=dtypes.int64), constant_op.constant(
[4., 5., 6., 7.], dtype=dtypes.float64))))
if __name__ == "__main__":
test.main()
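# A minimal sketch of a reinitializable iterator, mirroring the structure used
# above and assuming the contrib.data API in this change: one Iterator shared
# by two datasets and switched via make_initializer. Illustrative only.
def _example_reinitializable_iterator():
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.framework import constant_op
  dataset_a = dataset_ops.Dataset.from_tensors(constant_op.constant([1, 2, 3]))
  dataset_b = dataset_ops.Dataset.from_tensors(
      constant_op.constant([4, 5, 6, 7]))
  iterator = dataset_ops.Iterator.from_structure(dataset_a.output_types,
                                                 [None])
  init_a = iterator.make_initializer(dataset_a)
  init_b = iterator.make_initializer(dataset_b)
  get_next = iterator.get_next()
  with session.Session() as sess:
    sess.run(init_a)
    first = sess.run(get_next)   # [1, 2, 3]
    sess.run(init_b)
    second = sess.run(get_next)  # [4, 5, 6, 7]
  return first, second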

View File

@ -0,0 +1,330 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
class MapDatasetTest(test.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
.repeat(count))
def testMapDataset(self):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count).
components = [np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
count = array_ops.placeholder(dtypes.int64, shape=[])
dataset = self._buildMapDataset(components, count)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
# Test single-threaded access to the iterator.
sess.run(init_op, feed_dict={count: 14})
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test multi-threaded access to the same iterator.
sess.run(init_op, feed_dict={count: 18})
results = []
def iterator_thread():
while True:
try:
results.append(sess.run(get_next))
except errors.OutOfRangeError:
return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()
# `results` will contain the same elements components**2
# repeated 18 times, but in a non-deterministic order. Sort the
# results, and assert that each element of components**2 is
# produced 18 times.
results.sort(key=lambda x: x[0])
for i in range(7):
for j in range(18):
for component, result_component in zip(components,
results[i * 18 + j]):
self.assertAllEqual(component[i]**2, result_component)
def _buildParallelMapDataset(self, components, count, num_threads,
output_buffer_size):
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
return (dataset_ops.Dataset.from_tensor_slices(components).map(
_map_fn, num_threads=num_threads, output_buffer_size=output_buffer_size)
.repeat(count))
def testParallelMapDataset(self):
"""Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
# RepeatDataset(count).
components = [np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7)]
count = array_ops.placeholder(dtypes.int64, shape=[])
num_threads = array_ops.placeholder(dtypes.int32, shape=[])
output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
dataset = self._buildParallelMapDataset(components, count, num_threads,
output_buffer_size)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
def do_test(num_threads_val, output_buffer_size_val):
# Test single-threaded access to the iterator.
sess.run(init_op, feed_dict={
count: 14,
num_threads: num_threads_val,
output_buffer_size: output_buffer_size_val})
for _ in range(14):
for i in range(7):
result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test multi-threaded access to the same iterator.
sess.run(init_op, feed_dict={
count: 18,
num_threads: num_threads_val,
output_buffer_size: output_buffer_size_val})
results = []
def iterator_thread():
while True:
try:
results.append(sess.run(get_next))
except errors.OutOfRangeError:
return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
for t in threads:
t.start()
for t in threads:
t.join()
# `results` will contain the same elements components**2
# repeated 18 times, but in a non-deterministic order. Sort the
# results, and assert that each element of components**2 is
# produced 18 times.
results.sort(key=lambda x: x[0])
for i in range(7):
for j in range(18):
for component, result_component in zip(components,
results[i * 18 + j]):
self.assertAllEqual(component[i]**2, result_component)
for num_threads_val, output_buffer_size_val in [
(1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]:
do_test(num_threads_val, output_buffer_size_val)
def _testDisposeParallelMapDataset(self, explicit_dispose):
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(1000).
components = [np.arange(1000),
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
np.array(37.0) * np.arange(1000)]
dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
if explicit_dispose:
dispose_op = iterator.dispose_op()
with self.test_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
if explicit_dispose:
sess.run(dispose_op)
def testExplicitDisposeParallelMapDataset(self):
self._testDisposeParallelMapDataset(True)
def testImplicitDisposeParallelMapDataset(self):
self._testDisposeParallelMapDataset(False)
def testParallelMapError(self):
components = [np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)]
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
.map(lambda x: array_ops.check_numerics(x, "message")))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for _ in range(3):
sess.run(get_next)
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
sess.run(get_next)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testCaptureHashTable(self):
# NOTE(mrry): We must use the V2 variants of `HashTable`
# etc. because these produce a `tf.resource`-typed output that is
# compatible with the in-graph function implementation.
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.HashTable(
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
input_sentences = dataset_ops.Dataset.from_tensor_slices(
constant_op.constant([
"brain brain tank salad surgery",
"surgery brain",
]))
iterator = (input_sentences
.map(lambda x: string_ops.string_split([x]).values)
.map(table.lookup)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(table.init)
sess.run(init_op)
print(sess.run(get_next))
print(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testCaptureQueue(self):
elements = np.random.randint(100, size=[200])
queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
enqueue_op = queue.enqueue_many(elements)
close_op = queue.close()
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1)
.map(lambda _: queue.dequeue()).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(enqueue_op)
sess.run(close_op)
sess.run(init_op)
for element in elements:
self.assertEqual(element, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testCaptureVariable(self):
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
.map(lambda _: counter_var.assign_add(1))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(counter_var.initializer)
sess.run(init_op)
for i in range(10):
self.assertEqual(i, sess.run(counter_var))
self.assertEqual(i + 1, sess.run(get_next))
self.assertEqual(10, sess.run(counter_var))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual(10, sess.run(counter_var))
def testCaptureUninitializedVariableError(self):
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
.map(lambda _: counter_var.assign_add(1))
.make_initializable_iterator())
init_op = iterator.initializer
with self.test_session() as sess:
with self.assertRaisesRegexp(errors.FailedPreconditionError,
"Failed to capture resource"):
sess.run(init_op)
def testSeededStatefulOperatorIsProperlyStateful(self):
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
.map(lambda _: random_ops.random_uniform((), seed=11)).batch(2)
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
random_values = []
with self.assertRaises(errors.OutOfRangeError):
while True:
random_values.extend(sess.run(get_next))
self.assertEqual(10, len(random_values))
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
sess.run(init_op)
random_values_2 = []
with self.assertRaises(errors.OutOfRangeError):
while True:
random_values_2.extend(sess.run(get_next))
      # Randomness is repeatable given the same seed.
self.assertAllClose(random_values, random_values_2)
if __name__ == "__main__":
test.main()
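# A minimal sketch of the parallel map arguments exercised above, assuming the
# contrib.data API in this change (`num_threads` and `output_buffer_size`
# keyword arguments to Dataset.map). Illustrative only; not invoked here.
def _example_parallel_map_usage():
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.ops import math_ops
  dataset = dataset_ops.Dataset.range(8).map(
      math_ops.square, num_threads=4, output_buffer_size=8)
  get_next = dataset.make_one_shot_iterator().get_next()
  with session.Session() as sess:
    return [sess.run(get_next) for _ in range(8)]  # squares of 0..7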

View File

@ -0,0 +1,182 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test RangeDataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class RangeDatasetTest(test.TestCase):
def testStop(self):
stop = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={stop: 5})
for i in range(5):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testStartStop(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start,
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 5})
for i in range(2, 5):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testStartStopStep(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
step = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start, stop,
step).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2})
for i in range(2, 10, 2):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testZeroStep(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
step = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start, stop,
step).make_initializable_iterator()
init_op = iterator.initializer
with self.test_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0})
def testNegativeStep(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
step = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start, stop,
step).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
for i in range(2, 10, -1):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testStopLessThanStart(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start,
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
for i in range(10, 2):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testStopLessThanStartWithPositiveStep(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
step = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start, stop,
step).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2})
# This for loop is a no-op but will ensure that the implementation is
# consistent with range if it ever changes.
for i in range(10, 2, 2):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testStopLessThanStartWithNegativeStep(self):
start = array_ops.placeholder(dtypes.int64, shape=[])
stop = array_ops.placeholder(dtypes.int64, shape=[])
step = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(start, stop,
step).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1})
for i in range(10, 2, -1):
self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testEnumerateDataset(self):
components = [np.array(["a", "b"]), np.array([1, 2]), np.array([37.0, 38])]
start = constant_op.constant(20, dtype=dtypes.int64)
iterator = (dataset_ops.Dataset.from_tensor_slices(components).enumerate(
start=start).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual(dtypes.int64, get_next[0].dtype)
self.assertEqual((), get_next[0].shape)
self.assertEqual([tensor_shape.TensorShape([])] * 3,
[t.shape for t in get_next[1]])
with self.test_session() as sess:
sess.run(init_op)
self.assertEqual((20, [b"a", 1, 37.0]), sess.run(get_next))
self.assertEqual((21, [b"b", 2, 38.0]), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()
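# A minimal sketch of Dataset.range with (start, stop, step), assuming the
# contrib.data API in this change; the semantics mirror Python's range().
# Illustrative only; not invoked by the tests.
def _example_range_usage():
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  dataset = dataset_ops.Dataset.range(2, 10, 2)  # like range(2, 10, 2)
  get_next = dataset.make_one_shot_iterator().get_next()
  with session.Session() as sess:
    return [sess.run(get_next) for _ in range(4)]  # [2, 4, 6, 8]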

View File

@ -0,0 +1,500 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import zlib
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class TextLineDatasetTest(test.TestCase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
def _createFiles(self, num_files, num_lines, crlf=False):
filenames = []
for i in range(num_files):
fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
filenames.append(fn)
with open(fn, "wb") as f:
for j in range(num_lines):
f.write(self._lineText(i, j))
# Always include a newline after the record unless it is
# at the end of the file, in which case we include it sometimes.
if j + 1 != num_lines or i == 0:
f.write(b"\r\n" if crlf else b"\n")
return filenames
def testTextLineDataset(self):
test_filenames = self._createFiles(2, 5, crlf=True)
filenames = array_ops.placeholder(dtypes.string, shape=[None])
num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
repeat_dataset = dataset_ops.TextLineDataset(filenames).repeat(num_epochs)
batch_dataset = repeat_dataset.batch(batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
init_op = iterator.make_initializer(repeat_dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
with self.test_session() as sess:
# Basic test: read from file 0.
sess.run(init_op, feed_dict={filenames: [test_filenames[0]],
num_epochs: 1})
for i in range(5):
self.assertEqual(self._lineText(0, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Basic test: read from file 1.
sess.run(init_op, feed_dict={filenames: [test_filenames[1]],
num_epochs: 1})
for i in range(5):
self.assertEqual(self._lineText(1, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Basic test: read from both files.
sess.run(init_op, feed_dict={filenames: test_filenames,
num_epochs: 1})
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test repeated iteration through both files.
sess.run(init_op, feed_dict={filenames: test_filenames,
num_epochs: 10})
for _ in range(10):
for j in range(2):
for i in range(5):
self.assertEqual(self._lineText(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test batched and repeated iteration through both files.
sess.run(init_batch_op, feed_dict={filenames: test_filenames,
num_epochs: 10,
batch_size: 5})
for _ in range(10):
self.assertAllEqual([self._lineText(0, i) for i in range(5)],
sess.run(get_next))
self.assertAllEqual([self._lineText(1, i) for i in range(5)],
sess.run(get_next))
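# A minimal sketch of TextLineDataset outside the test harness, assuming the
# contrib.data API in this change; the file path is a hypothetical placeholder.
# Illustrative only; not invoked by the tests.
def _example_text_line_dataset_usage(filename="/tmp/example.txt"):
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.framework import errors
  dataset = dataset_ops.TextLineDataset([filename])  # one string per line
  get_next = dataset.make_one_shot_iterator().get_next()
  lines = []
  with session.Session() as sess:
    try:
      while True:
        lines.append(sess.run(get_next))
    except errors.OutOfRangeError:
      pass
  return lines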
class FixedLengthRecordReaderTest(test.TestCase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
self._num_files = 2
self._num_records = 7
self._header_bytes = 5
self._record_bytes = 3
self._footer_bytes = 2
def _record(self, f, r):
return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
def _createFiles(self):
filenames = []
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
filenames.append(fn)
with open(fn, "wb") as f:
f.write(b"H" * self._header_bytes)
for j in range(self._num_records):
f.write(self._record(i, j))
f.write(b"F" * self._footer_bytes)
return filenames
def testFixedLengthRecordDataset(self):
test_filenames = self._createFiles()
filenames = array_ops.placeholder(dtypes.string, shape=[None])
num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
repeat_dataset = (dataset_ops.FixedLengthRecordDataset(
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
.repeat(num_epochs))
batch_dataset = repeat_dataset.batch(batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
init_op = iterator.make_initializer(repeat_dataset)
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
with self.test_session() as sess:
# Basic test: read from file 0.
sess.run(init_op, feed_dict={filenames: [test_filenames[0]],
num_epochs: 1})
for i in range(self._num_records):
self.assertEqual(self._record(0, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Basic test: read from file 1.
sess.run(init_op, feed_dict={filenames: [test_filenames[1]],
num_epochs: 1})
for i in range(self._num_records):
self.assertEqual(self._record(1, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Basic test: read from both files.
sess.run(init_op, feed_dict={filenames: test_filenames,
num_epochs: 1})
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test repeated iteration through both files.
sess.run(init_op, feed_dict={filenames: test_filenames,
num_epochs: 10})
for _ in range(10):
for j in range(self._num_files):
for i in range(self._num_records):
self.assertEqual(self._record(j, i), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test batched and repeated iteration through both files.
sess.run(init_batch_op, feed_dict={filenames: test_filenames,
num_epochs: 10,
batch_size: self._num_records})
for _ in range(10):
for j in range(self._num_files):
self.assertAllEqual([self._record(j, i)
for i in range(self._num_records)],
sess.run(get_next))
class TFRecordDatasetTest(test.TestCase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
self._num_files = 2
self._num_records = 7
self.test_filenames = self._createFiles()
self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
self.num_epochs = array_ops.placeholder_with_default(
constant_op.constant(1, dtypes.int64), shape=[])
self.compression_type = array_ops.placeholder_with_default("", shape=[])
self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
repeat_dataset = dataset_ops.TFRecordDataset(
self.filenames, self.compression_type).repeat(self.num_epochs)
batch_dataset = repeat_dataset.batch(self.batch_size)
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
self.init_op = iterator.make_initializer(repeat_dataset)
self.init_batch_op = iterator.make_initializer(batch_dataset)
self.get_next = iterator.get_next()
def _record(self, f, r):
return compat.as_bytes("Record %d of file %d" % (r, f))
def _createFiles(self):
filenames = []
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
filenames.append(fn)
writer = python_io.TFRecordWriter(fn)
for j in range(self._num_records):
writer.write(self._record(i, j))
writer.close()
return filenames
def testReadOneEpoch(self):
with self.test_session() as sess:
# Basic test: read from file 0.
sess.run(self.init_op,
feed_dict={self.filenames: [self.test_filenames[0]],
self.num_epochs: 1})
for i in range(self._num_records):
self.assertAllEqual(self._record(0, i), sess.run(self.get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
# Basic test: read from file 1.
sess.run(self.init_op,
feed_dict={self.filenames: [self.test_filenames[1]],
self.num_epochs: 1})
for i in range(self._num_records):
self.assertAllEqual(self._record(1, i), sess.run(self.get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
# Basic test: read from both files.
sess.run(self.init_op,
feed_dict={self.filenames: self.test_filenames,
self.num_epochs: 1})
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
def testReadTenEpochs(self):
with self.test_session() as sess:
sess.run(self.init_op, feed_dict={self.filenames: self.test_filenames,
self.num_epochs: 10})
for _ in range(10):
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
def testReadTenEpochsOfBatches(self):
with self.test_session() as sess:
sess.run(self.init_batch_op,
feed_dict={self.filenames: self.test_filenames,
self.num_epochs: 10,
self.batch_size: self._num_records})
for _ in range(10):
for j in range(self._num_files):
values = sess.run(self.get_next)
self.assertAllEqual([self._record(j, i)
for i in range(self._num_records)], values)
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
def testReadZlibFiles(self):
zlib_files = []
for i, fn in enumerate(self.test_filenames):
with open(fn, "rb") as f:
cdata = zlib.compress(f.read())
zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
with open(zfn, "wb") as f:
f.write(cdata)
zlib_files.append(zfn)
with self.test_session() as sess:
sess.run(self.init_op,
feed_dict={self.filenames: zlib_files,
self.compression_type: "ZLIB"})
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
def testReadGzipFiles(self):
gzip_files = []
for i, fn in enumerate(self.test_filenames):
with open(fn, "rb") as f:
gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
with gzip.GzipFile(gzfn, "wb") as gzf:
gzf.write(f.read())
gzip_files.append(gzfn)
with self.test_session() as sess:
sess.run(self.init_op,
feed_dict={self.filenames: gzip_files,
self.compression_type: "GZIP"})
for j in range(self._num_files):
for i in range(self._num_records):
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(self.get_next)
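# A minimal sketch of TFRecordDataset outside the test harness, assuming the
# contrib.data API in this change; the file path is a hypothetical placeholder
# and "" means no compression. Illustrative only; not invoked by the tests.
def _example_tf_record_dataset_usage(filename="/tmp/example.tfrecords"):
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.framework import errors
  dataset = dataset_ops.TFRecordDataset([filename], "").repeat(2).batch(4)
  get_next = dataset.make_one_shot_iterator().get_next()
  batches = []
  with session.Session() as sess:
    try:
      while True:
        batches.append(sess.run(get_next))  # each run yields a batch of records
    except errors.OutOfRangeError:
      pass
  return batches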
class ReadBatchFeaturesTest(test.TestCase):
def setUp(self):
super(ReadBatchFeaturesTest, self).setUp()
self._num_files = 2
self._num_records = 7
self.test_filenames = self._createFiles()
def _read_batch_features(self, filenames, num_epochs, batch_size):
self.filenames = filenames
self.num_epochs = num_epochs
self.batch_size = batch_size
return dataset_ops.read_batch_features(
file_pattern=self.filenames,
batch_size=self.batch_size,
features={
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
"keywords": parsing_ops.VarLenFeature(dtypes.string)
},
reader=dataset_ops.TFRecordDataset,
randomize_input=False,
num_epochs=self.num_epochs)
def _record(self, f, r):
example = example_pb2.Example(features=feature_pb2.Features(
feature={
"file":
feature_pb2.Feature(int64_list=feature_pb2.Int64List(
value=[f])),
"record":
feature_pb2.Feature(int64_list=feature_pb2.Int64List(
value=[r])),
"keywords":
feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
value=self._get_keywords(f, r)))
}))
return example.SerializeToString()
def _get_keywords(self, f, r):
num_keywords = 1 + (f + r) % 2
keywords = []
for index in range(num_keywords):
keywords.append(compat.as_bytes("keyword%d" % index))
return keywords
def _createFiles(self):
filenames = []
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
filenames.append(fn)
writer = python_io.TFRecordWriter(fn)
for j in range(self._num_records):
writer.write(self._record(i, j))
writer.close()
return filenames
def _next_actual_batch(self, sess):
file_op = self.outputs["file"]
keywords_indices_op = self.outputs["keywords"].indices
keywords_values_op = self.outputs["keywords"].values
keywords_dense_shape_op = self.outputs["keywords"].dense_shape
record_op = self.outputs["record"]
return sess.run([
file_op, keywords_indices_op, keywords_values_op,
keywords_dense_shape_op, record_op
])
def _next_expected_batch(self, file_indices, batch_size, num_epochs):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
yield j, i
file_batch = []
keywords_batch_indices = []
keywords_batch_values = []
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
for _ in range(num_epochs):
for record in _next_record(file_indices):
f = record[0]
r = record[1]
file_batch.append(f)
record_batch.append(r)
keywords = self._get_keywords(f, r)
keywords_batch_values.extend(keywords)
keywords_batch_indices.extend([[batch_index, i]
for i in range(len(keywords))])
batch_index += 1
keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
if len(file_batch) == batch_size:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
[batch_size, keywords_batch_max_len], record_batch
]
file_batch = []
keywords_batch_indices = []
keywords_batch_values = []
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
if file_batch:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
[len(file_batch), keywords_batch_max_len], record_batch
]
def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1):
if file_index is not None:
file_indices = [file_index]
else:
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(file_indices, batch_size,
num_epochs):
actual_batch = self._next_actual_batch(sess)
for i in range(len(expected_batch)):
self.assertAllEqual(expected_batch[i], actual_batch[i])
def testRead(self):
for batch_size in [1, 2]:
for num_epochs in [1, 10]:
with ops.Graph().as_default():
with self.test_session(graph=ops.get_default_graph()) as sess:
# Basic test: read from file 0.
self.outputs = self._read_batch_features(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size)
self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
with ops.Graph().as_default():
with self.test_session(graph=ops.get_default_graph()) as sess:
# Basic test: read from file 1.
self.outputs = self._read_batch_features(
filenames=self.test_filenames[1],
num_epochs=num_epochs,
batch_size=batch_size)
self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
with ops.Graph().as_default():
with self.test_session(graph=ops.get_default_graph()) as sess:
# Basic test: read from both files.
self.outputs = self._read_batch_features(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size)
self._verify_records(sess, batch_size, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
if __name__ == "__main__":
test.main()
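# A minimal sketch of read_batch_features, mirroring the keyword arguments used
# above and assuming the contrib.data API in this change; the file pattern and
# feature keys ("label", "tokens") are hypothetical. Illustrative only.
def _example_read_batch_features_usage(
    file_pattern="/tmp/examples-*.tfrecords"):
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.framework import dtypes
  from tensorflow.python.ops import parsing_ops
  outputs = dataset_ops.read_batch_features(
      file_pattern=file_pattern,
      batch_size=32,
      features={
          "label": parsing_ops.FixedLenFeature([], dtypes.int64),
          "tokens": parsing_ops.VarLenFeature(dtypes.string),
      },
      reader=dataset_ops.TFRecordDataset,
      randomize_input=False,
      num_epochs=1)
  return outputs  # dict of batched feature tensors, e.g. outputs["label"]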

View File

@ -0,0 +1,79 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class ResampleTest(test.TestCase):
def testInitialKnownDistribution(self):
self._testDistribution(initial_known=True)
def testInitialNotKnownDistribution(self):
self._testDistribution(initial_known=False)
def _testDistribution(self, initial_known):
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
initial_dist = [0.2] * 5 if initial_known else None
iterator = dataset_ops.Iterator.from_dataset(
dataset_ops.rejection_resample(
(dataset_ops.Dataset.from_tensor_slices(classes)
.shuffle(200, seed=21)
.map(lambda c: (c, string_ops.as_string(c)))),
target_dist=target_dist,
initial_dist=initial_dist,
class_func=lambda c, _: c,
seed=27))
init_op = iterator.initializer
get_next = iterator.get_next()
variable_init_op = variables.global_variables_initializer()
with self.test_session() as sess:
sess.run(variable_init_op)
sess.run(init_op)
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
returned.append(sess.run(get_next))
returned_classes, returned_classes_and_data = zip(*returned)
_, returned_data = zip(*returned_classes_and_data)
self.assertAllEqual([compat.as_bytes(str(c))
for c in returned_classes], returned_data)
total_returned = len(returned_classes)
      # Subsampling rejects a large percentage of the initial data in
# this case.
self.assertGreater(total_returned, 20000 * 0.2)
class_counts = np.array([
len([True for v in returned_classes if v == c])
for c in range(5)])
returned_dist = class_counts / total_returned
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
if __name__ == "__main__":
test.main()
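# A minimal sketch of rejection_resample, mirroring the arguments used above
# and assuming the contrib.data API in this change: rebalance a class-labelled
# dataset toward a target distribution. Illustrative only; not invoked here.
def _example_rejection_resample_usage():
  import numpy as np
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  from tensorflow.python.framework import errors
  from tensorflow.python.ops import variables
  classes = np.random.randint(2, size=(1000,))  # roughly uniform over {0, 1}
  resampled = dataset_ops.rejection_resample(
      dataset_ops.Dataset.from_tensor_slices(classes),
      target_dist=[0.9, 0.1],
      initial_dist=None,  # estimated from the data, as in the test above
      class_func=lambda c: c,
      seed=42)
  iterator = dataset_ops.Iterator.from_dataset(resampled)
  get_next = iterator.get_next()
  drawn = []
  with session.Session() as sess:
    sess.run(variables.global_variables_initializer())
    sess.run(iterator.initializer)
    try:
      while True:
        drawn.append(sess.run(get_next))  # (class, original_element) pairs
    except errors.OutOfRangeError:
      pass
  return drawn  # now heavily skewed toward class 0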

View File

@ -0,0 +1,211 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class SequenceDatasetTest(test.TestCase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
    # This placeholder can be fed when the dataset-definition subgraph
# runs (i.e. `init_op` below) to configure the number of
# repetitions used in a particular iterator.
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensors(components)
.repeat(count_placeholder).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
# Test a finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 3})
for _ in range(3):
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test a different finite repetition.
sess.run(init_op, feed_dict={count_placeholder: 7})
for _ in range(7):
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test an empty repetition.
sess.run(init_op, feed_dict={count_placeholder: 0})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test an infinite repetition.
# NOTE(mrry): There's not a good way to test that the sequence
# actually is infinite.
sess.run(init_op, feed_dict={count_placeholder: -1})
for _ in range(17):
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
def testTakeTensorDataset(self):
components = [np.arange(10)]
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.take(count_placeholder).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
# Take fewer than input size
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4):
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Take more than input size
sess.run(init_op, feed_dict={count_placeholder: 25})
for i in range(10):
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Take all of input
sess.run(init_op, feed_dict={count_placeholder: -1})
for i in range(10):
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Take nothing
sess.run(init_op, feed_dict={count_placeholder: 0})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testSkipTensorDataset(self):
components = [np.arange(10)]
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
.skip(count_placeholder).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape[1:] for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
      # Skip fewer than the input size: we should skip
      # the first 4 elements and then read the rest.
sess.run(init_op, feed_dict={count_placeholder: 4})
for i in range(4, 10):
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Skip more than input size: get nothing.
sess.run(init_op, feed_dict={count_placeholder: 25})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Skip exactly input size.
sess.run(init_op, feed_dict={count_placeholder: 10})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Set -1 for 'count': skip the entire dataset.
sess.run(init_op, feed_dict={count_placeholder: -1})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Skip nothing
sess.run(init_op, feed_dict={count_placeholder: 0})
for i in range(0, 10):
results = sess.run(get_next)
self.assertAllEqual(results, components[0][i:i+1])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testRepeatRepeatTensorDataset(self):
"""Test the composition of repeat datasets."""
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
inner_count = array_ops.placeholder(dtypes.int64, shape=[])
outer_count = array_ops.placeholder(dtypes.int64, shape=[])
iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count)
.repeat(outer_count).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([c.shape for c in components],
[t.shape for t in get_next])
with self.test_session() as sess:
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
for _ in range(7 * 14):
results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
def testRepeatEmptyDataset(self):
"""Test that repeating an empty dataset does not hang."""
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10)
.repeat(-1).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.OutOfRangeError,
"Attempted to repeat an empty dataset infinitely."):
sess.run(get_next)
if __name__ == "__main__":
test.main()
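# A minimal sketch chaining the sequence transformations exercised above
# (skip, take, repeat), assuming the contrib.data API in this change.
# Illustrative only; not invoked by the tests.
def _example_sequence_transformations():
  from tensorflow.contrib.data.python.ops import dataset_ops
  from tensorflow.python.client import session
  # range(10) -> drop the first 2 -> keep the next 3 -> repeat that twice.
  dataset = dataset_ops.Dataset.range(10).skip(2).take(3).repeat(2)
  get_next = dataset.make_one_shot_iterator().get_next()
  with session.Session() as sess:
    return [sess.run(get_next) for _ in range(6)]  # [2, 3, 4, 2, 3, 4]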

View File

@ -0,0 +1,152 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ShuffleDatasetTest(test.TestCase):
def testShuffleDataset(self):
components = [
np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
np.array([9.0, 10.0, 11.0, 12.0])
]
count_placeholder = array_ops.placeholder_with_default(
constant_op.constant(5, dtypes.int64), shape=[])
buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
.repeat(count_placeholder))
shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
seed_placeholder)
self.assertEqual([c.shape[1:] for c in components],
shuffle_dataset.output_shapes)
# Create initialization ops for iterators without and with
# shuffling, respectively.
iterator = dataset_ops.Iterator.from_structure(
shuffle_dataset.output_types, shuffle_dataset.output_shapes)
init_fifo_op = iterator.make_initializer(repeat_dataset)
init_shuffle_op = iterator.make_initializer(shuffle_dataset)
get_next = iterator.get_next()
with self.test_session() as sess:
# First run without shuffling to collect the "ground truth".
sess.run(init_fifo_op)
unshuffled_elements = []
for _ in range(20):
unshuffled_elements.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Assert that the shuffled dataset has the same elements as the
# "ground truth".
sess.run(
init_shuffle_op,
feed_dict={buffer_size_placeholder: 100,
seed_placeholder: 37})
shuffled_elements = []
for _ in range(20):
shuffled_elements.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertAllEqual(
sorted(unshuffled_elements), sorted(shuffled_elements))
# Assert that shuffling twice with the same seeds gives the same sequence.
sess.run(
init_shuffle_op,
feed_dict={buffer_size_placeholder: 100,
seed_placeholder: 37})
reshuffled_elements_same_seed = []
for _ in range(20):
reshuffled_elements_same_seed.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)
# Assert that shuffling twice with a different seed gives a different
# permutation of the same elements.
sess.run(
init_shuffle_op,
feed_dict={buffer_size_placeholder: 100,
seed_placeholder: 1037})
reshuffled_elements_different_seed = []
for _ in range(20):
reshuffled_elements_different_seed.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
self.assertAllEqual(
sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))
# Assert that the shuffled dataset has the same elements as the
# "ground truth" when the buffer size is smaller than the input
# dataset.
sess.run(
init_shuffle_op,
feed_dict={buffer_size_placeholder: 2,
seed_placeholder: 37})
reshuffled_elements_small_buffer = []
for _ in range(20):
reshuffled_elements_small_buffer.append(sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertAllEqual(
sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))
# Test the case of shuffling an empty dataset.
sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2,
seed_placeholder: 37,
count_placeholder: 0})
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
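# A pure-Python sketch of the seed behaviour asserted above: the same seed
# reproduces the same order, while a different seed gives a permutation of
# the same elements. (This models only the property under test, not the
# buffer-based shuffle op itself.)
import random

def _shuffle_model(elements, seed):
  rng = random.Random(seed)
  out = list(elements)
  rng.shuffle(out)
  return out

_elements = list(range(20))
assert _shuffle_model(_elements, 37) == _shuffle_model(_elements, 37)
assert sorted(_shuffle_model(_elements, 37)) == sorted(_shuffle_model(_elements, 1037))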
def testDefaultArguments(self):
components = np.array([0, 1, 2, 3, 4])
iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
.repeat().make_one_shot_iterator())
get_next = iterator.get_next()
with self.test_session() as sess:
counts = collections.defaultdict(lambda: 0)
for _ in range(10):
for _ in range(5):
counts[sess.run(get_next)] += 1
for i in range(5):
self.assertEqual(10, counts[i])
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,114 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ZipDatasetTest(test.TestCase):
def testZipDataset(self):
component_placeholders = [
array_ops.placeholder(dtypes.int64),
array_ops.placeholder(dtypes.int64),
array_ops.placeholder(dtypes.float64)
]
datasets = [
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
for component_placeholder in component_placeholders
]
zipped = dataset_ops.Dataset.zip(datasets)
iterator = zipped.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
equal_length_components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0])
]
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
component_placeholders, equal_length_components)})
for i in range(4):
results = sess.run(get_next)
for component, result_component in zip(
equal_length_components, results):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
component_placeholders, variable_length_components)})
for i in range(2):
results = sess.run(get_next)
for component, result_component in zip(
variable_length_components, results):
self.assertAllEqual(component[i], result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
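# Python's built-in zip() already captures the length rule checked above:
# the zipped dataset is only as long as its shortest component.
_a, _b, _c = [1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]
assert len(list(zip(_a, _b, _c))) == 2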
def testNestedZipDataset(self):
component_placeholders = [
array_ops.placeholder(dtypes.int64, shape=[4, 20]),
array_ops.placeholder(dtypes.int64, shape=[4, 22]),
array_ops.placeholder(dtypes.float64, shape=[4])
]
datasets = [
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
for component_placeholder in component_placeholders
]
zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
iterator = zipped.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
self.assertEqual([20], get_next[0].shape)
self.assertEqual([22], get_next[1][0].shape)
self.assertEqual([], get_next[1][1].shape)
with self.test_session() as sess:
equal_length_components = [
np.tile(np.array([[1], [2], [3], [4]]), 20),
np.tile(np.array([[12], [13], [14], [15]]), 22),
np.array([37.0, 38.0, 39.0, 40.0])
]
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
component_placeholders, equal_length_components)})
for i in range(4):
result1, (result2, result3) = sess.run(get_next)
self.assertAllEqual(equal_length_components[0][i], result1)
self.assertAllEqual(equal_length_components[1][i], result2)
self.assertAllEqual(equal_length_components[2][i], result3)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
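# The nested structure above mirrors plain Python tuples: element i of
# Dataset.zip((d0, (d1, d2))) has the form (d0[i], (d1[i], d2[i])).
_d0, _d1, _d2 = [10, 20], ['a', 'b'], [0.1, 0.2]
_nested = [(x, (y, z)) for x, y, z in zip(_d0, _d1, _d2)]
assert _nested[0] == (10, ('a', 0.1))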
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,31 @@
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "dataset_ops",
srcs = ["dataset_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/data/python/framework:function",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

File diff suppressed because it is too large

View File

@ -273,7 +273,6 @@ class Layer(tf_base_layers.Layer):
# Internal methods:
build(input_shape)
_add_inbound_node(layer, index=0)
assert_input_compatibility()
"""
def __init__(self, **kwargs):
@ -381,97 +380,6 @@ class Layer(tf_base_layers.Layer):
self.constraints[weight] = constraint
return weight
def assert_input_compatibility(self, inputs):
"""Checks compatibility between the layer and provided inputs.
This checks that the tensor(s) in `inputs`
satisfy the input assumptions of the layer
(if any). If not, exceptions are raised.
Arguments:
inputs: input tensor or list of input tensors.
Raises:
ValueError: in case of mismatch between
the provided inputs and the expectations of the layer.
"""
if not self.input_spec:
return
if not isinstance(self.input_spec, (list, tuple)):
input_spec = _to_list(self.input_spec)
else:
input_spec = self.input_spec
inputs = _to_list(inputs)
if len(inputs) != len(input_spec):
raise ValueError('Layer ' + self.name + ' expects ' +
str(len(input_spec)) + ' inputs, '
'but it received ' + str(len(inputs)) +
' input tensors. Input received: ' + str(inputs))
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
if spec is None:
continue
# Check ndim.
if spec.ndim is not None:
if K.ndim(x) != spec.ndim:
raise ValueError('Input ' + str(input_index) +
' is incompatible with layer ' + self.name +
': expected ndim=' + str(spec.ndim) + ', found ndim='
+ str(K.ndim(x)))
if spec.max_ndim is not None:
ndim = K.ndim(x)
if ndim is not None and ndim > spec.max_ndim:
raise ValueError('Input ' + str(input_index) +
' is incompatible with layer ' + self.name +
': expected max_ndim=' + str(spec.max_ndim) +
', found ndim=' + str(K.ndim(x)))
if spec.min_ndim is not None:
ndim = K.ndim(x)
if ndim is not None and ndim < spec.min_ndim:
raise ValueError('Input ' + str(input_index) +
' is incompatible with layer ' + self.name +
': expected min_ndim=' + str(spec.min_ndim) +
', found ndim=' + str(K.ndim(x)))
# Check dtype.
if spec.dtype is not None:
if K.dtype(x) != spec.dtype:
raise ValueError('Input ' + str(input_index) +
' is incompatible with layer ' + self.name +
': expected dtype=' + str(spec.dtype) +
', found dtype=' + str(K.dtype(x)))
# Check specific shape axes.
if spec.axes:
try:
x_shape = K.int_shape(x)
except TypeError:
x_shape = None
if x_shape is not None:
for axis, value in spec.axes.items():
if hasattr(value, 'value'):
value = value.value
if value is not None and x_shape[int(axis)] not in {value, None}:
raise ValueError(
'Input ' + str(input_index) + ' is incompatible with layer ' +
self.name + ': expected axis ' + str(axis) +
' of input shape to have '
'value ' + str(value) + ' but got shape ' + str(x_shape))
# Check shape.
if spec.shape is not None:
try:
x_shape = K.int_shape(x)
except TypeError:
x_shape = None
if x_shape is not None:
for spec_dim, dim in zip(spec.shape, x_shape):
if hasattr(spec_dim, 'value'):
spec_dim = spec_dim.value
if spec_dim is not None and dim is not None:
if spec_dim != dim:
raise ValueError('Input ' + str(input_index) +
' is incompatible with layer ' + self.name +
': expected shape=' + str(spec.shape) +
', found shape=' + str(x_shape))
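# A condensed sketch of the ndim/axes checks the removed method performed
# (hypothetical standalone helper, written only to illustrate the rules; not
# the actual implementation the layer now relies on):
def _check_ndim_and_axes(x_shape, ndim=None, axes=None):
  if ndim is not None and len(x_shape) != ndim:
    raise ValueError('expected ndim=%d, found ndim=%d' % (ndim, len(x_shape)))
  for axis, value in (axes or {}).items():
    if value is not None and x_shape[int(axis)] not in {value, None}:
      raise ValueError('expected axis %d of input shape to have value %s, '
                       'but got shape %s' % (axis, value, x_shape))

_check_ndim_and_axes((None, 32, 3), ndim=3, axes={-1: 3})  # passes silently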
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
@ -509,11 +417,6 @@ class Layer(tf_base_layers.Layer):
if isinstance(inputs, list):
inputs = inputs[:]
# Raise exceptions in case the input is not compatible
# with the input_spec set at build time.
# TODO(fchollet): call after the layer is built, too.
self.assert_input_compatibility(inputs)
# Handle mask propagation.
previous_mask = _collect_previous_mask(inputs)
user_kwargs = copy.copy(kwargs)

View File

@ -37,233 +37,11 @@ from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling3D
# pylint: enable=unused-import
from tensorflow.contrib.keras.python.keras.utils import conv_utils
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import convolutional as tf_convolutional_layers
class _Conv(Layer):
"""Abstract nD convolution layer (private, used as implementation base).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
If `use_bias` is True, a bias vector is created and added to the outputs.
Finally, if `activation` is not `None`,
it is applied to the outputs as well.
Arguments:
rank: An integer, the rank of the convolution,
e.g. "2" for 2D convolution.
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of n integers, specifying the
dimensions of the convolution window.
strides: An integer or tuple/list of n integers,
specifying the strides of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: One of `"valid"` or `"same"` (case-insensitive).
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, ..., channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, ...)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
dilation_rate: An integer or tuple/list of n integers, specifying
the dilation rate to use for dilated convolution.
Currently, specifying any `dilation_rate` value != 1 is
incompatible with specifying any `strides` value != 1.
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation").
kernel_constraint: Constraint function applied to the kernel matrix.
bias_constraint: Constraint function applied to the bias vector.
"""
def __init__(self,
rank,
filters,
kernel_size,
strides=1,
padding='valid',
data_format=None,
dilation_rate=1,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(_Conv, self).__init__(**kwargs)
self.rank = rank
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank,
'dilation_rate')
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=self.rank + 2)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
channel_axis = 1
else:
channel_axis = -1
if input_shape[channel_axis] is None:
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (input_dim, self.filters)
self.kernel = self.add_weight(
shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
# Set input spec.
self.input_spec = InputSpec(
ndim=self.rank + 2, axes={channel_axis: input_dim})
self.built = True
def call(self, inputs):
if self.rank == 1:
outputs = K.conv1d(
inputs,
self.kernel,
strides=self.strides[0],
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate[0])
if self.rank == 2:
outputs = K.conv2d(
inputs,
self.kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate)
if self.rank == 3:
outputs = K.conv3d(
inputs,
self.kernel,
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilation_rate=self.dilation_rate)
if self.use_bias:
outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_last':
space = input_shape[1:-1]
new_space = []
for i in range(len(space)):
new_dim = conv_utils.conv_output_length(
space[i],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
new_space.append(new_dim)
return tensor_shape.TensorShape([input_shape[0]] + new_space +
[self.filters])
else:
space = input_shape[2:]
new_space = []
for i in range(len(space)):
new_dim = conv_utils.conv_output_length(
space[i],
self.kernel_size[i],
padding=self.padding,
stride=self.strides[i],
dilation=self.dilation_rate[i])
new_space.append(new_dim)
return tensor_shape.TensorShape([input_shape[0], self.filters] +
new_space)
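# A worked instance of the output-length rule used above. This assumes
# conv_utils.conv_output_length follows the standard formula for 'valid'
# and 'same' padding (a sketch for intuition, not the library code):
def _conv_output_length(length, kernel_size, padding, stride, dilation=1):
  k = kernel_size + (kernel_size - 1) * (dilation - 1)  # effective kernel size
  if padding == 'valid':
    out = length - k + 1
  else:  # 'same'
    out = length
  return (out + stride - 1) // stride

assert _conv_output_length(10, 3, 'valid', 1) == 8
assert _conv_output_length(10, 3, 'same', 2) == 5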
def get_config(self):
config = {
'rank':
self.rank,
'filters':
self.filters,
'kernel_size':
self.kernel_size,
'strides':
self.strides,
'padding':
self.padding,
'data_format':
self.data_format,
'dilation_rate':
self.dilation_rate,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint)
}
base_config = super(_Conv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class Conv1D(_Conv):
"""1D convolution layer (e.g.
temporal convolution).
class Conv1D(tf_convolutional_layers.Conv1D, Layer):
"""1D convolution layer (e.g. temporal convolution).
This layer creates a convolution kernel that is convolved
with the layer input over a single spatial (or temporal) dimension
@ -336,33 +114,55 @@ class Conv1D(_Conv):
bias_constraint=None,
**kwargs):
super(Conv1D, self).__init__(
rank=1,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format='channels_last',
dilation_rate=dilation_rate,
activation=activation,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
kernel_initializer=initializers.get(kernel_initializer),
bias_initializer=initializers.get(bias_initializer),
kernel_regularizer=regularizers.get(kernel_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.input_spec = InputSpec(ndim=3)
# TODO(fchollet): move weight constraint support to core layers.
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
def build(self, input_shape):
super(Conv1D, self).build(input_shape)
# TODO(fchollet): move weight constraint support to core layers.
if self.kernel_constraint:
self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
def get_config(self):
config = super(Conv1D, self).get_config()
config.pop('rank')
config.pop('data_format')
return config
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'dilation_rate': self.dilation_rate,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(Conv1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class Conv2D(_Conv):
class Conv2D(tf_convolutional_layers.Conv2D, Layer):
"""2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
@ -452,36 +252,60 @@ class Conv2D(_Conv):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
super(Conv2D, self).__init__(
rank=2,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
kernel_initializer=initializers.get(kernel_initializer),
bias_initializer=initializers.get(bias_initializer),
kernel_regularizer=regularizers.get(kernel_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.input_spec = InputSpec(ndim=4)
# TODO(fchollet): move weight constraint support to core layers.
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
def build(self, input_shape):
super(Conv2D, self).build(input_shape)
# TODO(fchollet): move weight constraint support to core layers.
if self.kernel_constraint:
self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
def get_config(self):
config = super(Conv2D, self).get_config()
config.pop('rank')
return config
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'data_format': self.data_format,
'dilation_rate': self.dilation_rate,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(Conv2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class Conv3D(_Conv):
"""3D convolution layer (e.g.
spatial convolution over volumes).
class Conv3D(tf_convolutional_layers.Conv3D, Layer):
"""3D convolution layer (e.g. spatial convolution over volumes).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of
@ -577,33 +401,59 @@ class Conv3D(_Conv):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
super(Conv3D, self).__init__(
rank=3,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
activation=activation,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
kernel_initializer=initializers.get(kernel_initializer),
bias_initializer=initializers.get(bias_initializer),
kernel_regularizer=regularizers.get(kernel_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.input_spec = InputSpec(ndim=5)
# TODO(fchollet): move weight constraint support to core layers.
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
def build(self, input_shape):
super(Conv3D, self).build(input_shape)
# TODO(fchollet): move weight constraint support to core layers.
if self.kernel_constraint:
self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
def get_config(self):
config = super(Conv3D, self).get_config()
config.pop('rank')
return config
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'data_format': self.data_format,
'dilation_rate': self.dilation_rate,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(Conv3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class Conv2DTranspose(Conv2D):
class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer):
"""Transposed convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
@ -699,121 +549,57 @@ class Conv2DTranspose(Conv2D):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
super(Conv2DTranspose, self).__init__(
filters,
kernel_size,
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
activation=activation,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
kernel_initializer=initializers.get(kernel_initializer),
bias_initializer=initializers.get(bias_initializer),
kernel_regularizer=regularizers.get(kernel_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.input_spec = InputSpec(ndim=4)
# TODO(fchollet): move weight constraint support to core layers.
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if len(input_shape) != 4:
raise ValueError(
'Inputs should have rank ' + str(4) + '; Received input shape:',
str(input_shape))
if self.data_format == 'channels_first':
channel_axis = 1
else:
channel_axis = -1
if input_shape[channel_axis] is None:
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (self.filters, input_dim)
self.kernel = self.add_weight(
shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
# Set input spec.
self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
self.built = True
def call(self, inputs):
input_shape = K.shape(inputs)
batch_size = input_shape[0]
if self.data_format == 'channels_first':
h_axis, w_axis = 2, 3
else:
h_axis, w_axis = 1, 2
height, width = input_shape[h_axis], input_shape[w_axis]
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
# Infer the dynamic output shape:
out_height = conv_utils.deconv_length(height, stride_h, kernel_h,
self.padding)
out_width = conv_utils.deconv_length(width, stride_w, kernel_w,
self.padding)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
else:
output_shape = (batch_size, out_height, out_width, self.filters)
outputs = K.conv2d_transpose(
inputs,
self.kernel,
output_shape,
self.strides,
padding=self.padding,
data_format=self.data_format)
if self.bias:
outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
output_shape = list(input_shape)
if self.data_format == 'channels_first':
c_axis, h_axis, w_axis = 1, 2, 3
else:
c_axis, h_axis, w_axis = 3, 1, 2
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
output_shape[c_axis] = self.filters
output_shape[h_axis] = conv_utils.deconv_length(
output_shape[h_axis], stride_h, kernel_h, self.padding)
output_shape[w_axis] = conv_utils.deconv_length(
output_shape[w_axis], stride_w, kernel_w, self.padding)
return tensor_shape.TensorShape(output_shape)
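# The dynamic output size above comes from conv_utils.deconv_length; a sketch
# of the usual transposed-convolution rule (stated as an assumption: the
# forward conv output-length formula run in reverse):
def _deconv_length(dim_size, stride, kernel_size, padding):
  if padding == 'same':
    return dim_size * stride
  return dim_size * stride + max(kernel_size - stride, 0)  # 'valid'

assert _deconv_length(4, 2, 3, 'same') == 8
assert _deconv_length(4, 2, 3, 'valid') == 9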
super(Conv2DTranspose, self).build(input_shape)
# TODO(fchollet): move weight constraint support to core layers.
if self.kernel_constraint:
self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
def get_config(self):
config = super(Conv2DTranspose, self).get_config()
config.pop('dilation_rate')
return config
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'data_format': self.data_format,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(Conv2DTranspose, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class SeparableConv2D(Conv2D):
class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
"""Depthwise separable 2D convolution.
Separable convolutions consist of first performing
@ -909,126 +695,68 @@ class SeparableConv2D(Conv2D):
pointwise_constraint=None,
bias_constraint=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
super(SeparableConv2D, self).__init__(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
activation=activation,
activation=activations.get(activation),
use_bias=use_bias,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
bias_constraint=bias_constraint,
depthwise_initializer=initializers.get(depthwise_initializer),
pointwise_initializer=initializers.get(pointwise_initializer),
bias_initializer=initializers.get(bias_initializer),
depthwise_regularizer=regularizers.get(depthwise_regularizer),
pointwise_regularizer=regularizers.get(pointwise_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.depth_multiplier = depth_multiplier
self.depthwise_initializer = initializers.get(depthwise_initializer)
self.pointwise_initializer = initializers.get(pointwise_initializer)
self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
# TODO(fchollet): move weight constraint support to core layers.
self.depthwise_constraint = constraints.get(depthwise_constraint)
self.pointwise_constraint = constraints.get(pointwise_constraint)
self.bias_constraint = constraints.get(bias_constraint)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if len(input_shape) < 4:
raise ValueError('Inputs to `SeparableConv2D` should have rank 4. '
'Received input shape:', str(input_shape))
if self.data_format == 'channels_first':
channel_axis = 1
else:
channel_axis = 3
if input_shape[channel_axis] is None:
raise ValueError('The channel dimension of the inputs to '
'`SeparableConv2D` '
'should be defined. Found `None`.')
input_dim = int(input_shape[channel_axis])
depthwise_kernel_shape = (self.kernel_size[0], self.kernel_size[1],
input_dim, self.depth_multiplier)
pointwise_kernel_shape = (1, 1, self.depth_multiplier * input_dim,
self.filters)
self.depthwise_kernel = self.add_weight(
shape=depthwise_kernel_shape,
initializer=self.depthwise_initializer,
name='depthwise_kernel',
regularizer=self.depthwise_regularizer,
constraint=self.depthwise_constraint)
self.pointwise_kernel = self.add_weight(
shape=pointwise_kernel_shape,
initializer=self.pointwise_initializer,
name='pointwise_kernel',
regularizer=self.pointwise_regularizer,
constraint=self.pointwise_constraint)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
# Set input spec.
self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
self.built = True
def call(self, inputs):
outputs = K.separable_conv2d(
inputs,
self.depthwise_kernel,
self.pointwise_kernel,
data_format=self.data_format,
strides=self.strides,
padding=self.padding)
if self.bias:
outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[2]
cols = input_shape[3]
else:
rows = input_shape[1]
cols = input_shape[2]
rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
self.padding, self.strides[0])
cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
self.padding, self.strides[1])
if self.data_format == 'channels_first':
return tensor_shape.TensorShape(
[input_shape[0], self.filters, rows, cols])
else:
return tensor_shape.TensorShape(
[input_shape[0], rows, cols, self.filters])
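# Parameter-count intuition for the kernels built above: with the shapes from
# the removed build() (depthwise kernel (kh, kw, in_ch, depth_multiplier),
# pointwise kernel (1, 1, depth_multiplier * in_ch, filters)), a separable
# convolution uses far fewer weights than a full Conv2D with the same filters.
kh, kw, in_ch, depth_multiplier, filters = 3, 3, 32, 1, 64
depthwise_weights = kh * kw * in_ch * depth_multiplier          # 288
pointwise_weights = 1 * 1 * depth_multiplier * in_ch * filters  # 2048
full_conv_weights = kh * kw * in_ch * filters                   # 18432
assert depthwise_weights + pointwise_weights < full_conv_weights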
super(SeparableConv2D, self).build(input_shape)
# TODO(fchollet): move weight constraint support to core layers.
if self.depthwise_constraint:
self.constraints[self.depthwise_kernel] = self.depthwise_constraint
if self.pointwise_constraint:
self.constraints[self.pointwise_kernel] = self.pointwise_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
def get_config(self):
config = super(SeparableConv2D, self).get_config()
config.pop('kernel_initializer')
config.pop('kernel_regularizer')
config.pop('kernel_constraint')
config['depth_multiplier'] = self.depth_multiplier
config['depthwise_initializer'] = initializers.serialize(
self.depthwise_initializer)
config['pointwise_initializer'] = initializers.serialize(
self.pointwise_initializer)
config['depthwise_regularizer'] = regularizers.serialize(
self.depthwise_regularizer)
config['pointwise_regularizer'] = regularizers.serialize(
self.pointwise_regularizer)
config['depthwise_constraint'] = constraints.serialize(
self.depthwise_constraint)
config['pointwise_constraint'] = constraints.serialize(
self.pointwise_constraint)
return config
config = {
'filters': self.filters,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding,
'data_format': self.data_format,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'depthwise_initializer': initializers.serialize(
self.depthwise_initializer),
'pointwise_initializer': initializers.serialize(
self.pointwise_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'depthwise_regularizer': regularizers.serialize(
self.depthwise_regularizer),
'pointwise_regularizer': regularizers.serialize(
self.pointwise_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'depthwise_constraint': constraints.serialize(
self.depthwise_constraint),
'pointwise_constraint': constraints.serialize(
self.pointwise_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(SeparableConv2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class UpSampling1D(Layer):

View File

@ -27,24 +27,7 @@ from tensorflow.python.platform import test
class Convolution1DTest(test.TestCase):
def test_causal_dilated_conv1d(self):
# Causal:
with self.test_session():
testing_utils.layer_test(
keras.layers.Conv1D,
input_data=np.reshape(np.arange(4, dtype='float32'), (1, 4, 1)),
kwargs={
'filters': 1,
'kernel_size': 2,
'dilation_rate': 1,
'padding': 'causal',
'kernel_initializer': 'ones',
'use_bias': False,
},
expected_output=[[[0], [1], [3], [5]]])
def test_dilated_conv1d(self):
# Non-causal:
with self.test_session():
testing_utils.layer_test(
keras.layers.Conv1D,

View File

@ -85,7 +85,7 @@ class Masking(Layer):
return dict(list(base_config.items()) + list(config.items()))
class Dropout(Layer):
class Dropout(tf_core_layers.Dropout, Layer):
"""Applies Dropout to the input.
Dropout consists of randomly setting
@ -104,24 +104,18 @@ class Dropout(Layer):
"""
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
super(Dropout, self).__init__(**kwargs)
self.rate = min(1., max(0., rate))
self.noise_shape = noise_shape
self.seed = seed
self.supports_masking = True
def _get_noise_shape(self, _):
return self.noise_shape
# Inheritance call order:
# 1) tf.layers.Dropout, 2) keras.layers.Layer, 3) tf.layers.Layer
super(Dropout, self).__init__(**kwargs)
def call(self, inputs, training=None):
if 0. < self.rate < 1.:
noise_shape = self._get_noise_shape(inputs)
def dropped_inputs():
return K.dropout(inputs, self.rate, noise_shape, seed=self.seed)
return K.in_train_phase(dropped_inputs, inputs, training=training)
return inputs
if training is None:
training = K.learning_phase()
output = super(Dropout, self).call(inputs, training=training)
if training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access
return output
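# A NumPy sketch of the (inverted) dropout transform applied in the training
# branch above: kept units are scaled by 1 / (1 - rate) so inference needs no
# rescaling. (Illustration only; the layer itself delegates to the backend op.)
import numpy as np

def _dropout_model(x, rate, rng):
  keep_prob = 1.0 - rate
  mask = rng.binomial(1, keep_prob, size=x.shape)
  return x * mask / keep_prob

_y = _dropout_model(np.ones((4, 4)), rate=0.5, rng=np.random.RandomState(0))
# roughly half of the entries are 0.0, the rest are 2.0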
def get_config(self):
config = {'rate': self.rate}
@ -726,45 +720,32 @@ class Dense(tf_core_layers.Dense, Layer):
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
# TODO(fchollet): move weight constraint support to core layers.
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=2)
self.supports_masking = True
def build(self, input_shape):
assert len(input_shape) >= 2
input_dim = input_shape[-1]
super(Dense, self).build(input_shape)
self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
# TODO(fchollet): move weight constraint support to core layers.
if self.kernel_constraint:
self.constraints[self.kernel] = self.kernel_constraint
if self.use_bias and self.bias_constraint:
self.constraints[self.bias] = self.bias_constraint
self.built = True
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'units': self.units,
'activation': activations.serialize(self.activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint)
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(Dense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

View File

@ -22,12 +22,11 @@ from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras import constraints
from tensorflow.contrib.keras.python.keras import initializers
from tensorflow.contrib.keras.python.keras import regularizers
from tensorflow.contrib.keras.python.keras.engine import InputSpec
from tensorflow.contrib.keras.python.keras.engine import Layer
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import normalization as tf_normalization_layers
class BatchNormalization(Layer):
class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer):
"""Batch normalization layer (Ioffe and Szegedy, 2014).
Normalize the activations of the previous layer at each batch,
@ -86,148 +85,59 @@ class BatchNormalization(Layer):
beta_constraint=None,
gamma_constraint=None,
**kwargs):
super(BatchNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
self.momentum = momentum
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = initializers.get(beta_initializer)
self.gamma_initializer = initializers.get(gamma_initializer)
self.moving_mean_initializer = initializers.get(moving_mean_initializer)
self.moving_variance_initializer = initializers.get(
moving_variance_initializer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.gamma_regularizer = regularizers.get(gamma_regularizer)
super(BatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=initializers.get(beta_initializer),
gamma_initializer=initializers.get(gamma_initializer),
moving_mean_initializer=initializers.get(moving_mean_initializer),
moving_variance_initializer=initializers.get(
moving_variance_initializer),
beta_regularizer=regularizers.get(beta_regularizer),
gamma_regularizer=regularizers.get(gamma_regularizer),
**kwargs
)
# TODO(fchollet): move weight constraint support to core layers.
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
dim = input_shape[self.axis]
if dim is None:
raise ValueError('Axis ' + str(self.axis) + ' of '
'input tensor should have a defined dimension '
'but the layer received an input with shape ' +
str(input_shape) + '.')
self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim})
shape = (dim,)
if self.scale:
self.gamma = self.add_weight(
shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(
shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.moving_mean = self.add_weight(
shape=shape,
name='moving_mean',
initializer=self.moving_mean_initializer,
trainable=False)
self.moving_variance = self.add_weight(
shape=shape,
name='moving_variance',
initializer=self.moving_variance_initializer,
trainable=False)
self.built = True
super(BatchNormalization, self).build(input_shape)
# TODO(fchollet): move weight constraint support to core layers.
if self.center and self.beta_constraint:
self.constraints[self.beta] = self.beta_constraint
if self.scale and self.gamma_constraint:
self.constraints[self.gamma] = self.gamma_constraint
def call(self, inputs, training=None):
input_shape = inputs.get_shape().as_list()
# Prepare broadcasting shape.
ndim = len(input_shape)
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis]
# Determines whether broadcasting is needed.
needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
normed, mean, variance = K.normalize_batch_in_training(
inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)
if training in {0, False}:
return normed
else:
self.add_update([
K.moving_average_update(self.moving_mean, mean, self.momentum),
K.moving_average_update(self.moving_variance, variance, self.momentum)
], inputs)
def normalize_inference():
if needs_broadcasting:
# In this case we must explicitly broadcast all parameters.
broadcast_moving_mean = K.reshape(self.moving_mean, broadcast_shape)
broadcast_moving_variance = K.reshape(self.moving_variance,
broadcast_shape)
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
else:
broadcast_beta = None
if self.scale:
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
else:
broadcast_gamma = None
return K.batch_normalization(
inputs,
broadcast_moving_mean,
broadcast_moving_variance,
broadcast_beta,
broadcast_gamma,
epsilon=self.epsilon)
else:
return K.batch_normalization(
inputs,
self.moving_mean,
self.moving_variance,
self.beta,
self.gamma,
epsilon=self.epsilon)
# Pick the normalized form corresponding to the training phase.
return K.in_train_phase(normed, normalize_inference, training=training)
if training is None:
training = K.learning_phase()
output = super(BatchNormalization, self).call(inputs, training=training)
if training is K.learning_phase():
output._uses_learning_phase = True # pylint: disable=protected-access
return output
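# The inference transform sketched in NumPy, matching the formula applied by
# the removed normalize_inference() above:
# gamma * (x - moving_mean) / sqrt(moving_variance + epsilon) + beta.
import numpy as np

def _batch_norm_inference(x, moving_mean, moving_variance, beta, gamma,
                          epsilon=1e-3):
  return gamma * (x - moving_mean) / np.sqrt(moving_variance + epsilon) + beta

_out = _batch_norm_inference(np.array([0.0, 1.0, 2.0]),
                             moving_mean=1.0, moving_variance=4.0,
                             beta=0.0, gamma=1.0)
# _out is approximately [-0.5, 0.0, 0.5]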
def get_config(self):
config = {
'axis':
self.axis,
'momentum':
self.momentum,
'epsilon':
self.epsilon,
'center':
self.center,
'scale':
self.scale,
'beta_initializer':
initializers.serialize(self.beta_initializer),
'gamma_initializer':
initializers.serialize(self.gamma_initializer),
'axis': self.axis,
'momentum': self.momentum,
'epsilon': self.epsilon,
'center': self.center,
'scale': self.scale,
'beta_initializer': initializers.serialize(self.beta_initializer),
'gamma_initializer': initializers.serialize(self.gamma_initializer),
'moving_mean_initializer':
initializers.serialize(self.moving_mean_initializer),
'moving_variance_initializer':
initializers.serialize(self.moving_variance_initializer),
'beta_regularizer':
regularizers.serialize(self.beta_regularizer),
'gamma_regularizer':
regularizers.serialize(self.gamma_regularizer),
'beta_constraint':
constraints.serialize(self.beta_constraint),
'gamma_constraint':
constraints.serialize(self.gamma_constraint)
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint),
'gamma_constraint': constraints.serialize(self.gamma_constraint)
}
base_config = super(BatchNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

View File

@ -116,19 +116,21 @@ class NoiseLayersTest(test.TestCase):
"""
with self.test_session():
# Test single layer reuse
bn = keras.layers.BatchNormalization(input_shape=(10,))
bn = keras.layers.BatchNormalization()
x1 = keras.layers.Input(shape=(10,))
bn(x1)
_ = bn(x1)
x2 = keras.layers.Input(shape=(10,))
y2 = bn(x2)
x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
model = keras.models.Model(x2, y2)
assert len(model.updates) == 2
model.compile('sgd', 'mse')
model.train_on_batch(x, x)
assert len(model.updates) == 2
# Test model-level reuse
x3 = keras.layers.Input(shape=(10,))
y3 = model(x3)

View File

@ -23,51 +23,10 @@ from tensorflow.contrib.keras.python.keras.engine import InputSpec
from tensorflow.contrib.keras.python.keras.engine import Layer
from tensorflow.contrib.keras.python.keras.utils import conv_utils
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import pooling as tf_pooling_layers
class _Pooling1D(Layer):
"""Abstract class for different pooling 1D layers.
"""
def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
super(_Pooling1D, self).__init__(**kwargs)
if strides is None:
strides = pool_size
self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size')
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.input_spec = InputSpec(ndim=3)
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
self.padding, self.strides[0])
return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
raise NotImplementedError
def call(self, inputs):
inputs = K.expand_dims(inputs, 2) # add dummy last dimension
output = self._pooling_function(
inputs=inputs,
pool_size=self.pool_size + (1,),
strides=self.strides + (1,),
padding=self.padding,
data_format='channels_last')
return K.squeeze(output, 2) # remove dummy last dimension
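# The 1D-via-2D trick above, sketched with NumPy: insert a dummy axis, pool
# with a window of (pool_size, 1), then squeeze the dummy axis back out
# (hypothetical max-pool helper, for illustration only).
import numpy as np

def _max_pool_1d(x, pool_size, stride):
  # x has shape (batch, length, channels)
  x = np.expand_dims(x, 2)  # -> (batch, length, 1, channels)
  out_len = (x.shape[1] - pool_size) // stride + 1
  out = np.stack([x[:, i * stride:i * stride + pool_size].max(axis=1)
                  for i in range(out_len)], axis=1)
  return np.squeeze(out, 2)  # -> (batch, out_len, channels)

_y = _max_pool_1d(np.arange(8, dtype=float).reshape(1, 8, 1), pool_size=2, stride=2)
# _y[0, :, 0] == [1., 3., 5., 7.]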
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
'padding': self.padding
}
base_config = super(_Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class MaxPooling1D(_Pooling1D):
class MaxPooling1D(tf_pooling_layers.MaxPooling1D, Layer):
"""Max pooling operation for temporal data.
Arguments:
@ -85,15 +44,21 @@ class MaxPooling1D(_Pooling1D):
"""
def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
if strides is None:
strides = pool_size
super(MaxPooling1D, self).__init__(pool_size, strides, padding, **kwargs)
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
output = K.pool2d(
inputs, pool_size, strides, padding, data_format, pool_mode='max')
return output
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
'padding': self.padding
}
base_config = super(MaxPooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class AveragePooling1D(_Pooling1D):
class AveragePooling1D(tf_pooling_layers.AveragePooling1D, Layer):
"""Average pooling for temporal data.
Arguments:
@ -111,78 +76,22 @@ class AveragePooling1D(_Pooling1D):
"""
def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
if strides is None:
strides = pool_size
super(AveragePooling1D, self).__init__(pool_size, strides, padding,
**kwargs)
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
output = K.pool2d(
inputs, pool_size, strides, padding, data_format, pool_mode='avg')
return output
class _Pooling2D(Layer):
"""Abstract class for different pooling 2D layers.
"""
def __init__(self,
pool_size=(2, 2),
strides=None,
padding='valid',
data_format=None,
**kwargs):
super(_Pooling2D, self).__init__(**kwargs)
data_format = conv_utils.normalize_data_format(data_format)
if strides is None:
strides = pool_size
self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[2]
cols = input_shape[3]
else:
rows = input_shape[1]
cols = input_shape[2]
rows = conv_utils.conv_output_length(rows, self.pool_size[0], self.padding,
self.strides[0])
cols = conv_utils.conv_output_length(cols, self.pool_size[1], self.padding,
self.strides[1])
if self.data_format == 'channels_first':
return tensor_shape.TensorShape(
[input_shape[0], input_shape[1], rows, cols])
else:
return tensor_shape.TensorShape(
[input_shape[0], rows, cols, input_shape[3]])
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
raise NotImplementedError
def call(self, inputs):
output = self._pooling_function(
inputs=inputs,
pool_size=self.pool_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format)
return output
def get_config(self):
config = {
'pool_size': self.pool_size,
'padding': self.padding,
'strides': self.strides,
'data_format': self.data_format
'pool_size': self.pool_size,
'padding': self.padding
}
base_config = super(_Pooling2D, self).get_config()
base_config = super(AveragePooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class MaxPooling2D(_Pooling2D):
class MaxPooling2D(tf_pooling_layers.MaxPooling2D, Layer):
"""Max pooling operation for spatial data.
Arguments:
@ -229,16 +138,25 @@ class MaxPooling2D(_Pooling2D):
padding='valid',
data_format=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
if strides is None:
strides = pool_size
super(MaxPooling2D, self).__init__(pool_size, strides, padding, data_format,
**kwargs)
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
output = K.pool2d(
inputs, pool_size, strides, padding, data_format, pool_mode='max')
return output
def get_config(self):
config = {
'pool_size': self.pool_size,
'padding': self.padding,
'strides': self.strides,
'data_format': self.data_format
}
base_config = super(MaxPooling2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class AveragePooling2D(_Pooling2D):
class AveragePooling2D(tf_pooling_layers.AveragePooling2D, Layer):
"""Average pooling operation for spatial data.
Arguments:
@ -285,68 +203,12 @@ class AveragePooling2D(_Pooling2D):
padding='valid',
data_format=None,
**kwargs):
super(AveragePooling2D, self).__init__(pool_size, strides, padding,
data_format, **kwargs)
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
output = K.pool2d(
inputs, pool_size, strides, padding, data_format, pool_mode='avg')
return output
class _Pooling3D(Layer):
"""Abstract class for different pooling 3D layers.
"""
def __init__(self,
pool_size=(2, 2, 2),
strides=None,
padding='valid',
data_format=None,
**kwargs):
super(_Pooling3D, self).__init__(**kwargs)
if data_format is None:
data_format = K.image_data_format()
if strides is None:
strides = pool_size
self.pool_size = conv_utils.normalize_tuple(pool_size, 3, 'pool_size')
self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=5)
def _compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
len_dim1 = input_shape[2]
len_dim2 = input_shape[3]
len_dim3 = input_shape[4]
else:
len_dim1 = input_shape[1]
len_dim2 = input_shape[2]
len_dim3 = input_shape[3]
len_dim1 = conv_utils.conv_output_length(len_dim1, self.pool_size[0],
self.padding, self.strides[0])
len_dim2 = conv_utils.conv_output_length(len_dim2, self.pool_size[1],
self.padding, self.strides[1])
len_dim3 = conv_utils.conv_output_length(len_dim3, self.pool_size[2],
self.padding, self.strides[2])
if self.data_format == 'channels_first':
return tensor_shape.TensorShape(
[input_shape[0], input_shape[1], len_dim1, len_dim2, len_dim3])
else:
return tensor_shape.TensorShape(
[input_shape[0], len_dim1, len_dim2, len_dim3, input_shape[4]])
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
raise NotImplementedError
def call(self, inputs):
output = self._pooling_function(
inputs=inputs,
pool_size=self.pool_size,
strides=self.strides,
padding=self.padding,
data_format=self.data_format)
return output
super(AveragePooling2D, self).__init__(pool_size, strides, padding,
data_format, **kwargs)
def get_config(self):
config = {
@ -355,11 +217,11 @@ class _Pooling3D(Layer):
'strides': self.strides,
'data_format': self.data_format
}
base_config = super(_Pooling3D, self).get_config()
base_config = super(AveragePooling2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class MaxPooling3D(_Pooling3D):
class MaxPooling3D(tf_pooling_layers.MaxPooling3D, Layer):
"""Max pooling operation for 3D data (spatial or spatio-temporal).
Arguments:
@ -402,16 +264,25 @@ class MaxPooling3D(_Pooling3D):
padding='valid',
data_format=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
if strides is None:
strides = pool_size
super(MaxPooling3D, self).__init__(pool_size, strides, padding, data_format,
**kwargs)
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
output = K.pool3d(
inputs, pool_size, strides, padding, data_format, pool_mode='max')
return output
def get_config(self):
config = {
'pool_size': self.pool_size,
'padding': self.padding,
'strides': self.strides,
'data_format': self.data_format
}
base_config = super(MaxPooling3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class AveragePooling3D(_Pooling3D):
class AveragePooling3D(tf_pooling_layers.AveragePooling3D, Layer):
"""Average pooling operation for 3D data (spatial or spatio-temporal).
Arguments:
@ -454,13 +325,22 @@ class AveragePooling3D(_Pooling3D):
padding='valid',
data_format=None,
**kwargs):
if data_format is None:
data_format = K.image_data_format()
if strides is None:
strides = pool_size
super(AveragePooling3D, self).__init__(pool_size, strides, padding,
data_format, **kwargs)
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
output = K.pool3d(
inputs, pool_size, strides, padding, data_format, pool_mode='avg')
return output
def get_config(self):
config = {
'pool_size': self.pool_size,
'padding': self.padding,
'strides': self.strides,
'data_format': self.data_format
}
base_config = super(AveragePooling3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
class _GlobalPooling1D(Layer):
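
The replacement classes, e.g. MaxPooling2D(tf_pooling_layers.MaxPooling2D, Layer), keep the Keras Layer interface but let the tf.layers implementation do the work, because it comes first in the method resolution order. A toy sketch of that inheritance shape (all class names below are stand-ins):

```python
class KerasLayer(object):
  """Stand-in for the Keras base Layer."""

  def __init__(self, **kwargs):
    self.built = False


class TFMaxPooling2D(KerasLayer):
  """Stand-in for the tf.layers implementation."""

  def __init__(self, pool_size, strides, padding, data_format, **kwargs):
    super(TFMaxPooling2D, self).__init__(**kwargs)
    self.pool_size = pool_size


class MaxPooling2D(TFMaxPooling2D, KerasLayer):
  """Stand-in for the rewritten Keras layer."""

  def __init__(self, pool_size=(2, 2), strides=None, padding='valid',
               data_format='channels_last', **kwargs):
    if strides is None:
      strides = pool_size
    # Resolves to TFMaxPooling2D.__init__ first, per the MRO below.
    super(MaxPooling2D, self).__init__(pool_size, strides, padding,
                                       data_format, **kwargs)


print([c.__name__ for c in MaxPooling2D.__mro__])
# ['MaxPooling2D', 'TFMaxPooling2D', 'KerasLayer', 'object']
```
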



@ -71,7 +71,6 @@ class AvgPool2DTest(test.TestCase):
height, width = 3, 6
images = np.random.uniform(size=(5, 2, height, width))
output = _layers.avg_pool2d(images, [3, 3], data_format='NCHW')
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
self.assertListEqual(output.get_shape().as_list(), [5, 2, 1, 2])
def testCollectOutputs(self):
@ -2692,7 +2691,6 @@ class MaxPool2DTest(test.TestCase):
height, width = 3, 6
images = np.random.uniform(size=(5, 3, height, width)).astype(np.float32)
output = _layers.max_pool2d(images, [3, 3], data_format='NCHW')
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
self.assertListEqual(output.get_shape().as_list(), [5, 3, 1, 2])
def testCollectOutputs(self):


@ -503,6 +503,7 @@ tf_gen_op_libs(
"control_flow_ops",
"ctc_ops",
"data_flow_ops",
"dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@ -580,6 +581,7 @@ cc_library(
":control_flow_ops_op_lib",
":ctc_ops_op_lib",
":data_flow_ops_op_lib",
":dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@ -707,6 +709,7 @@ cc_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:ctc_ops",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:dataset_ops",
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:image",
@ -1408,11 +1411,6 @@ tf_cuda_library(
"framework/**/*.cc",
"util/**/*.h",
"util/**/*.cc",
] + [
"graph/edgeset.h",
"graph/edgeset.cc",
"graph/graph.h",
"graph/graph.cc",
],
exclude = [
"**/*test*",
@ -1553,6 +1551,8 @@ tf_cuda_library(
"graph/colors.cc",
"graph/control_flow.cc",
"graph/costmodel.cc",
"graph/edgeset.cc",
"graph/graph.cc",
"graph/graph_constructor.cc",
"graph/graph_def_builder.cc",
"graph/graph_partition.cc",


@ -48,9 +48,9 @@ class ConstantFoldingTest : public ::testing::Test {
TensorShape shape) {
EXPECT_TRUE(n->IsConstant());
const TensorProto* tensor_proto;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
Tensor t(dtype);
EXPECT_TRUE(t.FromProto(*tensor_proto));
test::ExpectClose(t, test::AsTensor(values, shape));
@ -61,9 +61,9 @@ class ConstantFoldingTest : public ::testing::Test {
TensorShape shape) {
EXPECT_TRUE(n->IsConstant());
const TensorProto* tensor_proto;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto));
DataType dtype;
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
Tensor t(dtype);
EXPECT_TRUE(t.FromProto(*tensor_proto));
test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));


@ -92,28 +92,31 @@ bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
}
}
}
const AttrSlice attrs = node->attrs();
string text;
const NodeDef& def = node->def();
string text = "";
if (IsSend(node)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
string recv_device;
TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
"(", tensor_name, " @", recv_device);
TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device));
text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
tensor_name, " @", recv_device);
is_transfer_node = true;
} else if (IsRecv(node)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
string send_device;
TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
"(", tensor_name, " @", send_device);
TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device));
text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
tensor_name, " @", send_device);
is_transfer_node = true;
} else {
text =
strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
str_util::Join(node->requested_inputs(), ", "), ")");
text = strings::StrCat(
memory, def.name(), " = ", def.op(), "(",
str_util::Join(
std::vector<StringPiece>(def.input().begin(), def.input().end()),
", "),
")");
}
node_stats->set_timeline_label(text);
return is_transfer_node;
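
The change above only swaps the Node accessors for NodeDef ones; the timeline label format itself is unchanged: transfer (Send/Recv) nodes are labelled with their tensor name and peer device, all other nodes with their input list. A rough Python sketch of the resulting strings, with the memory prefix omitted and made-up example inputs and device names:

```python
def timeline_label(name, op, inputs, tensor_name=None, peer_device=None):
  """Sketch of the label format built by SetTimelineLabel (memory prefix omitted)."""
  if op in ('_Send', '_Recv'):
    # Mirrors the StrCat calls above, which do not close the parenthesis
    # for transfer nodes.
    return '%s = %s(%s @%s' % (name, op, tensor_name, peer_device)
  return '%s = %s(%s)' % (name, op, ', '.join(inputs))

print(timeline_label('add', 'Add', ['x', 'y']))             # add = Add(x, y)
print(timeline_label('s', '_Send', [], 'edge_1_add', '/task:1'))
# s = _Send(edge_1_add @/task:1
```
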
@ -519,7 +522,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
EdgeInfo* dst_edge = item->output_edge_base();
for (auto e : n->out_edges()) {
dst_edge->dst_id = e->dst()->id();
CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
CHECK_LE(e->src_output(), ((int32)0x3FFFFFFF)); // Must fit in 31 bits
dst_edge->output_slot = e->src_output();
dst_edge->is_last = false;
const int output_slot = dst_edge->output_slot;
@ -637,7 +640,7 @@ Status ExecutorImpl::Initialize() {
Status s = params_.create_kernel(n->def(), &item->kernel);
if (!s.ok()) {
item->kernel = nullptr;
s = AttachDef(s, *n);
s = AttachDef(s, n->def());
LOG(ERROR) << "Executor failed to create kernel. " << s;
return s;
}
@ -665,7 +668,7 @@ Status ExecutorImpl::Initialize() {
frame_info->nodes->push_back(n);
if (IsEnter(n)) {
string enter_name;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "frame_name", &enter_name));
EnsureFrameInfo(enter_name)->input_count++;
}
}
@ -720,7 +723,7 @@ Status InferAllocAttr(const Node* n, const Node* dst,
// so these two cases are not mutually exclusive.
if (IsRecv(n)) {
string src_name;
s = GetNodeAttr(n->attrs(), "send_device", &src_name);
s = GetNodeAttr(n->def(), "send_device", &src_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_src_name;
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
@ -745,7 +748,7 @@ Status InferAllocAttr(const Node* n, const Node* dst,
}
if (IsSend(dst)) {
string dst_name;
s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_dst_name;
if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
@ -1358,7 +1361,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
if (IsEnter(curr_node)) {
// Enter a child frame.
TF_RETURN_IF_ERROR(
GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
GetNodeAttr(curr_node->def(), "frame_name", &frame_name));
parent = curr_node;
} else if (IsExit(curr_node)) {
// Exit to the parent frame.
@ -1552,7 +1555,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
if (vlog_) {
VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
<< SummarizeNode(*node) << " is dead: " << tagged_node.is_dead;
<< SummarizeNodeDef(node->def())
<< " is dead: " << tagged_node.is_dead;
}
Entry* input_tensors = GetInputTensors(input_frame, input_iter);
@ -1606,7 +1610,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
if (vlog_) {
VLOG(2) << this << " Async kernel done: "
<< SummarizeNode(*state->item->node);
<< SummarizeNodeDef(state->item->node->def());
}
if (stats) nodestats::SetOpEnd(stats);
EntryVector outputs;
@ -1807,7 +1811,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
// tensor value at i-th output.
if (!IsSwitch(node) && !IsRecv(node)) {
s.Update(errors::Internal("Missing ", i, "-th output from ",
SummarizeNode(*node)));
SummarizeNodeDef(node->def())));
}
} else {
Entry* out = &((*outputs)[i]);
@ -1874,7 +1878,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
DataTypeString(dtype),
" does not match declared output type ",
DataTypeString(item.output_type(i)),
" for node ", SummarizeNode(*node)));
" for node ", SummarizeNodeDef(node->def())));
}
}
if (!val.is_ref()) {
@ -1911,7 +1915,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
bool is_constant;
Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
Status s = GetNodeAttr(node->def(), "is_constant", &is_constant);
DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
@ -2237,7 +2241,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
FrameState** child) {
// Get the child frame name.
string enter_name;
Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
Status s = GetNodeAttr(node->def(), "frame_name", &enter_name);
DCHECK(s.ok()) << s;
const string child_name = MakeFrameName(frame, iter, enter_name);
@ -2255,7 +2259,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
if (vlog_) VLOG(2) << "Create frame: " << child_name;
int parallel_iters;
s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
s = GetNodeAttr(node->def(), "parallel_iterations", &parallel_iters);
DCHECK(s.ok()) << s;
FrameState* temp = new FrameState(impl_, parallel_iters);
temp->frame_name = child_name;


@ -150,7 +150,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
~FunctionLibraryRuntimeImpl() override;
Status Instantiate(const string& function_name, AttrSlice attrs,
Status Instantiate(const string& function_name,
const InstantiateAttrValueMap& attrs,
Handle* handle) override;
const FunctionBody* GetFunctionBody(Handle handle) override;
@ -207,7 +208,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
};
std::vector<Item*> items_;
Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
Status FunctionDefToBody(const FunctionDef& fdef,
const InstantiateAttrValueMap& attrs,
FunctionBody** fbody);
Status CreateItem(Handle handle, Item** item);
Status GetOrCreateItem(Handle handle, Item** item);
@ -322,7 +324,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
// Try to instantiate this function for the func/attr. Maybe it's
// cached already.
Handle handle;
TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));
const FunctionBody* fbody = GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
@ -353,9 +355,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
return s;
}
Status FunctionLibraryRuntimeImpl::FunctionDefToBody(const FunctionDef& fdef,
AttrSlice attrs,
FunctionBody** fbody) {
Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
const FunctionDef& fdef, const InstantiateAttrValueMap& attrs,
FunctionBody** fbody) {
// Instantiates the function template into a graph def.
InstantiationResult result;
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result));
@ -388,13 +390,11 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
// TODO(josh11b): Should filter out the attrs from func that aren't used
// by the gradient function.
TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
TF_RETURN_IF_ERROR(
FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), g_body));
TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body));
} else {
// f is a user-defined function.
Handle f_handle;
TF_RETURN_IF_ERROR(
Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle));
TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle));
const FunctionBody* f_body = GetFunctionBody(f_handle);
CHECK_NOTNULL(f_body);
*g_body = SymbolicGradient(*f_body);
@ -402,9 +402,9 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
return Status::OK();
}
Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
AttrSlice attrs,
Handle* handle) {
Status FunctionLibraryRuntimeImpl::Instantiate(
const string& function_name, const InstantiateAttrValueMap& attrs,
Handle* handle) {
const string key = Canonicalize(function_name, attrs);
{
mutex_lock l(mu_);
@ -417,7 +417,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
Status s;
FunctionBody* fbody = nullptr;
if (function_name == kGradientOp) {
const AttrValue* f = attrs.Find(kFuncAttr);
const AttrValue* f = gtl::FindOrNull(attrs, kFuncAttr);
if (f == nullptr) {
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
}
@ -427,7 +427,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
}
const string grad = lib_def_->FindGradient(func.name());
if (!grad.empty()) {
return Instantiate(grad, AttrSlice(&func.attr()), handle);
return Instantiate(grad, func.attr(), handle);
}
TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody));
} else {
@ -989,12 +989,13 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
for (Node* node : graph->nodes()) {
VLOG(3) << "Expanding " << node->DebugString();
bool noinline;
if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
if (fld->GetAttr(node->def(), kNoInlineAttr, &noinline).ok() && noinline) {
VLOG(3) << "noinline: " << node->DebugString();
continue;
}
FunctionLibraryRuntime::Handle handle;
Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
Status s =
lib->Instantiate(node->type_string(), node->def().attr(), &handle);
if (!s.ok()) {
// Either "node" is a primitive op, or the instantiation failed.
if (errors::IsNotFound(s)) {
@ -1102,7 +1103,7 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
continue;
}
int index;
TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index));
CHECK_LE(0, index);
CHECK_LT(index, node_vec->size());
(*node_vec)[index] = n;


@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
typedef FunctionDefHelper FDH;
@ -59,29 +58,13 @@ void HasError(const Status& s, const string& substr) {
<< s << ", expected substring " << substr;
}
// A helper class to make AttrSlice from initializer lists
class Attrs {
public:
Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
for (const auto& aval : attrs) {
map_.insert({aval.first, aval.second.proto});
}
}
operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
private:
AttrValueMap map_;
};
class FunctionTest : public ::testing::Test {
protected:
FunctionTest()
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")) {}
void Create(const FunctionDef& fdef, Attrs attrs) {
void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) {
exec_ = nullptr;
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
@ -168,8 +151,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
fdef_lib_ = lib_def_->ToProto();
}
Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
Status Run(const string& name, InstantiateAttrValueSlice attrs,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@ -205,7 +188,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK();
}
std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) {
std::unique_ptr<Graph> GetFuncBody(const string& name,
InstantiateAttrValueSlice attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@ -219,7 +203,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return ret;
}
std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) {
std::unique_ptr<Graph> GetGradBody(const string& func,
InstantiateAttrValueSlice attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(func, attrs, &handle);
if (!status.ok()) {
@ -630,14 +615,13 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
// Instantiating "XTimesTwo" should fail.
FunctionLibraryRuntime::Handle handle;
HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle),
HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle),
"Not found: type attr not found");
// But XTimesFour and XTimes16 instantiation should succeed. Only
// when they run do they fail, because XTimesTwo is bad.
TF_CHECK_OK(
lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle));
TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle));
TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle));
TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle));
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
@ -944,7 +928,8 @@ bool DoNothing(Graph* g) { return false; }
GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
const FunctionDef& fdef) {
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
InstantiateAttrValueMap empty;
TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
@ -1263,5 +1248,4 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
} // end namespace
} // end namespace tensorflow


@ -103,13 +103,13 @@ void Benchmark::Run(int iters) { RunWithArgs({}, {}, iters); }
string GetRendezvousKey(const Node* node) {
string send_device;
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
TF_CHECK_OK(GetNodeAttr(node->def(), "send_device", &send_device));
string recv_device;
TF_CHECK_OK(GetNodeAttr(node->attrs(), "recv_device", &recv_device));
TF_CHECK_OK(GetNodeAttr(node->def(), "recv_device", &recv_device));
string tensor_name;
TF_CHECK_OK(GetNodeAttr(node->attrs(), "tensor_name", &tensor_name));
TF_CHECK_OK(GetNodeAttr(node->def(), "tensor_name", &tensor_name));
uint64 send_device_incarnation;
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device_incarnation",
TF_CHECK_OK(GetNodeAttr(node->def(), "send_device_incarnation",
reinterpret_cast<int64*>(&send_device_incarnation)));
return Rendezvous::CreateKey(send_device, send_device_incarnation,
recv_device, tensor_name, FrameAndIter(0, 0));


@ -49,11 +49,11 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
}
for (Node* n : matches) {
AttrSlice n_attrs = n->attrs();
AttrSlice n_attrs(n->def());
auto base_make_node = [n, g, &n_attrs](const string& op,
const string& name) {
NodeBuilder node_builder(name, op);
node_builder.Device(n->requested_device());
node_builder.Device(n->def().device());
string colo;
if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
node_builder.Attr("_class", colo);


@ -55,7 +55,7 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
}
for (Node* read : matches) {
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(read->attrs(), "dtype", &dtype));
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(read->def()), "dtype", &dtype));
std::vector<Node*> in_control_edges;
std::vector<std::pair<Node*, int>> in_edges;
for (const Edge* edge : read->in_edges()) {


@ -76,7 +76,7 @@ void ColocationGroups(const Node& node,
std::vector<string> class_specs;
// TODO(vrv): We should consider adding a GetNodeAttr that returns a
// StringPiece, to avoid a copy.
if (!GetNodeAttrSimple(node.attrs(), kColocationAttrNameStringPiece,
if (!GetNodeAttrSimple(node.def(), kColocationAttrNameStringPiece,
&class_specs)) {
// No attribute value is equivalent to the empty colocation_group.
*colocation_groups = {
@ -329,7 +329,7 @@ class ColocationGraph {
AddDebugInfo(node_root, &debug_info);
DeviceNameUtils::ParsedName specified_device_name;
if (DeviceNameUtils::ParseFullName(node->requested_device(),
if (DeviceNameUtils::ParseFullName(node->def().device(),
&specified_device_name) &&
specified_device_name == members_[node_root].device_name) {
// The specified device and merged set device match, and
@ -348,27 +348,27 @@ class ColocationGraph {
std::sort(device_names.begin(), device_names.end());
return errors::InvalidArgument(
"Operation was explicitly assigned to ",
node->requested_device(), " but available devices are [ ",
"Operation was explicitly assigned to ", node->def().device(),
" but available devices are [ ",
str_util::Join(device_names, ", "), " ]. Make sure ",
"the device specification refers to a valid device.");
} else if (specified_device_name.has_type) {
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
node->requested_device(), "' because no supported kernel for ",
node->def().device(), "' because no supported kernel for ",
specified_device_name.type, " devices is available.",
debug_info);
} else {
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
node->requested_device(), debug_info);
node->def().device(), debug_info);
}
} else {
// The specified device may be a valid device but the
// merged set device is different, so print both.
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
node->requested_device(),
node->def().device(),
"' because the node was colocated with a group of nodes that "
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
@ -513,7 +513,7 @@ class ColocationGraph {
return errors::Internal("Assigned device '", node.assigned_device_name(),
"' does not have registered OpKernel support "
"for ",
node.type_string());
node.def().op());
} else {
// This node has not yet been assigned to a device, so we
// calculate any constraints due to the set of registered
@ -527,25 +527,25 @@ class ColocationGraph {
registered_device_types.insert(d->device_type());
}
return errors::InvalidArgument(
"No OpKernel was registered to support Op '", node.type_string(),
"No OpKernel was registered to support Op '", node.def().op(),
"' with these attrs. Registered devices: [",
str_util::Join(registered_device_types, ","),
"], Registered kernels:\n",
KernelsRegisteredForOp(node.type_string()));
KernelsRegisteredForOp(node.def().op()));
}
// If the NodeDef contains a device, then we interpret it as a
// (partial) device specification.
if (!node.requested_device().empty()) {
if (!node.def().device().empty()) {
// The user has specified a device in the NodeDef, try to find a
// valid device matching their specification in the set of
// devices.
// NOTE: The full name may specify a device that is not in
// n.supported_device_types(), but we check that in AssignDevice().
if (!DeviceNameUtils::ParseFullName(node.requested_device(),
if (!DeviceNameUtils::ParseFullName(node.def().device(),
&member->device_name)) {
return errors::InvalidArgument("Malformed device specification '",
node.requested_device(), "'");
node.def().device(), "'");
}
}
}
@ -644,7 +644,7 @@ Status SimplePlacer::Run() {
continue;
}
status = colocation_graph.AddNode(*node);
if (!status.ok()) return AttachDef(status, *node);
if (!status.ok()) return AttachDef(status, node->def());
}
// 2. Enumerate the constraint edges, and use them to update the disjoint
@ -707,7 +707,7 @@ Status SimplePlacer::Run() {
"be on the same device), but the two nodes "
"were assigned two different devices: ",
status.error_message()),
*node);
node->def());
}
}
}
@ -749,7 +749,7 @@ Status SimplePlacer::Run() {
return AttachDef(
errors::InvalidArgument("Cannot assign a device for operation '",
node->name(), "': ", status.error_message()),
*node);
node->def());
}
// Returns the first device in sorted devices list so we will always
@ -791,7 +791,7 @@ Status SimplePlacer::Run() {
return AttachDef(
errors::InvalidArgument("Cannot assign a device for operation '",
node->name(), "': ", status.error_message()),
*node);
node->def());
}
string assigned_device = devices[0]->name();


@ -223,16 +223,19 @@ Status DebugNodeInserter::InsertNodes(
void DebugNodeInserter::DeparallelizeWhileLoops(Graph* graph, Device* device) {
for (Node* node : graph->nodes()) {
if (node->IsEnter()) {
const AttrValue* parallel_iterations =
node->attrs().Find("parallel_iterations");
if (parallel_iterations && parallel_iterations->i() > 1) {
LOG(INFO) << "For debugging, tfdbg is changing the "
<< "parallel_iterations attribute of the Enter/RefEnter "
<< "node \"" << node->name() << "\" on device \""
<< device->name() << "\" from " << parallel_iterations->i()
<< " to 1. (This does not affect subsequent non-debug "
<< "runs.)";
node->AddAttr<int64>("parallel_iterations", 1);
for (const auto& attr : node->def().attr()) {
if (attr.first == "parallel_iterations") {
if (attr.second.i() > 1) {
LOG(INFO) << "For debugging, tfdbg is changing the "
<< "parallel_iterations attribute of the Enter/RefEnter "
<< "node \"" << node->name() << "\" on device \""
<< device->name() << "\" from " << attr.second.i()
<< " to 1. (This does not affect subsequent non-debug "
<< "runs.)";
node->AddAttr<int64>("parallel_iterations", 1);
}
break;
}
}
}
}


@ -188,7 +188,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const RunState* run_state,
SimpleGraphExecutionState* execution_state);
string DetailText(const Node& node, const NodeExecStats& ns) {
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
int64 tot = 0;
for (auto& no : ns.output()) {
tot += no.tensor_description().allocation_description().requested_bytes();
@ -197,8 +197,12 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
if (tot >= 0.1 * 1048576.0) {
bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
}
return strings::StrCat(bytes, node.name(), " = ", node.type_string(), "(",
str_util::Join(node.requested_inputs(), ", "), ")");
return strings::StrCat(
bytes, def.name(), " = ", def.op(), "(",
str_util::Join(
std::vector<StringPiece>(def.input().begin(), def.input().end()),
", "),
")");
}
private:
@ -786,7 +790,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
if (!ns.timeline_label().empty()) {
details = ns.timeline_label();
} else if (found_node_in_graph) {
details = DetailText(*node, ns);
details = DetailText(node->def(), ns);
} else {
// Leave details string empty
}


@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -45,11 +44,12 @@ namespace {
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
bool* is_type_list, DataTypeVector* dtypes) {
Status ArgNumType(const InstantiateAttrValueMap& attrs,
const OpDef::ArgDef& arg_def, bool* is_type_list,
DataTypeVector* dtypes) {
dtypes->clear();
if (!arg_def.type_list_attr().empty()) {
const AttrValue* v = attrs.Find(arg_def.type_list_attr());
const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_list_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ",
arg_def.type_list_attr());
@ -64,7 +64,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
*is_type_list = false;
int num = 1;
if (!arg_def.number_attr().empty()) {
const AttrValue* v = attrs.Find(arg_def.number_attr());
const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ", arg_def.type_attr());
}
@ -77,7 +77,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
} else if (arg_def.type_attr().empty()) {
dtype = DT_INVALID;
} else {
const AttrValue* v = attrs.Find(arg_def.type_attr());
const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr());
if (v == nullptr) {
return errors::NotFound("type attr not found: ", arg_def.type_attr());
}
@ -92,17 +92,18 @@ void AddAttr(const string& name, const T& val, NodeDef* ndef) {
SetAttrValue(val, &((*ndef->mutable_attr())[name]));
}
Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
Status ValidateSignatureWithAttrs(const OpDef& sig,
const InstantiateAttrValueMap& attr_values) {
// attr_values should specify all attrs defined in fdef.
for (const auto& a : sig.attr()) {
const AttrValue* v = attr_values.Find(a.name());
if (!v) {
auto const iter = attr_values.find(a.name());
if (iter == attr_values.end()) {
return errors::NotFound("Attr ", a.name(), " is not found from ",
SummarizeOpDef(sig));
}
Status status = AttrValueHasType(*v, a.type());
Status status = AttrValueHasType(iter->second, a.type());
if (!status.ok()) {
errors::AppendToMessage(&status, "for attr '", a.name(), "'");
errors::AppendToMessage(&status, "for attr '", iter->first, "'");
return status;
}
}
@ -145,7 +146,7 @@ class FunctionInstantiationHelper {
// Builds index for nodes that can be used as node's input arguments.
Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
AttrSlice attr_values) {
const InstantiateAttrValueMap& attr_values) {
bool is_type_list;
DataTypeVector dtypes;
TF_RETURN_IF_ERROR(
@ -174,7 +175,8 @@ class FunctionInstantiationHelper {
return Status::OK();
}
Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
Status BuildNodeOutputIndex(const NodeDef& node,
const InstantiateAttrValueMap& attrs,
const int arg_index) {
const OpDef* node_sig = nullptr;
TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
@ -204,7 +206,8 @@ class FunctionInstantiationHelper {
return Status::OK();
}
Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
Status InstantiateNode(const NodeDef& fnode,
const InstantiateAttrValueMap& attrs) {
const OpDef* fnode_sig = nullptr;
TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
NodeDef* gnode = AddNode(fnode.name());
@ -292,7 +295,7 @@ class FunctionInstantiationHelper {
}
Status AddReturnNode(
const OpDef::ArgDef& ret_def, AttrSlice attrs,
const OpDef::ArgDef& ret_def, const InstantiateAttrValueMap& attrs,
const ::tensorflow::protobuf::Map<string, string>& ret_map,
int* ret_index) {
auto ret_iter = ret_map.find(ret_def.name());
@ -601,7 +604,7 @@ string Print(const GraphDef& gdef) {
Status AddDefaultAttrs(const string& op,
const GetFunctionSignature& get_function,
AttrValueMap* attrs) {
InstantiateAttrValueMap* attrs) {
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(get_function(op, &op_def));
AttrSlice attr_slice(attrs);
@ -617,7 +620,8 @@ Status AddDefaultAttrs(const string& op,
} // end namespace
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
Status InstantiateFunction(const FunctionDef& fdef,
const InstantiateAttrValueMap& attr_values,
GetFunctionSignature get_function,
InstantiationResult* result) {
VLOG(3) << "Instantiation Function: " << Print(fdef);
@ -635,17 +639,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
}
}
auto substitute = [attr_values](StringPiece name, AttrValue* val) {
if (const AttrValue* v = attr_values.Find(name)) {
*val = *v;
auto substitute = [&attr_values](const string& name, AttrValue* val) {
auto iter = attr_values.find(name);
if (iter == attr_values.end()) {
return false;
} else {
*val = iter->second;
return true;
}
return false;
};
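
The substitute lambda is where function instantiation binds attr placeholders: each "$name" placeholder in a node's attrs is replaced with the caller-supplied value, and instantiation later fails if any placeholder stays unbound. A small Python sketch of that binding step (plain values stand in for AttrValue protos; the function name is made up):

```python
def bind_placeholders(node_attrs, attr_values):
  """Replace '$name' placeholders with caller-supplied attr values."""
  bound = {}
  for key, val in node_attrs.items():
    if isinstance(val, str) and val.startswith('$'):
      placeholder = val[1:]
      if placeholder not in attr_values:
        raise ValueError('Failed to bind all placeholders: ' + placeholder)
      val = attr_values[placeholder]
    bound[key] = val
  return bound

print(bind_placeholders({'T': '$T', 'N': 2}, {'T': 'float'}))
# {'T': 'float', 'N': 2}
```
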
// Makes a copy of all attrs in fdef and substitutes placeholders.
// After this step, every attr is bound to a concrete value.
std::vector<AttrValueMap> node_attrs;
std::vector<InstantiateAttrValueMap> node_attrs;
node_attrs.resize(fdef.node_def_size());
for (int i = 0; i < fdef.node_def_size(); ++i) {
for (auto attr : fdef.node_def(i).attr()) {
@ -662,7 +668,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
}
for (int i = 0; i < fdef.node_def_size(); ++i) {
s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
s = helper.BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i],
result->gdef.node_size() + i);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
@ -671,7 +677,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
}
// Emits one gdef.node for each fdef.node_def.
for (int i = 0; i < fdef.node_def_size(); ++i) {
s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
s = helper.InstantiateNode(fdef.node_def(i), node_attrs[i]);
if (!s.ok()) {
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
return s;
@ -742,7 +748,8 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
return true;
}
string Canonicalize(const string& funcname, AttrSlice attrs) {
string Canonicalize(const string& funcname,
const InstantiateAttrValueMap& attrs) {
std::vector<string> entries;
entries.reserve(attrs.size());
for (auto p : attrs) {
@ -946,7 +953,8 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
// If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
// Foo's attributes.
const NameAttrList* forward_func_attrs;
if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
if (!GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs)
.ok()) {
return nullptr;
}
const string& func_name = forward_func_attrs->name();
@ -973,30 +981,34 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
return lib;
}
template <typename T>
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
const string& attr, T* value) const {
const FunctionDef* fdef = GetAttrImpl(ndef);
if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
return Status::OK();
Status InstantiateFunction(const FunctionDef& fdef,
InstantiateAttrValueSlice attr_values,
GetFunctionSignature get_function,
InstantiationResult* result) {
InstantiateAttrValueMap m;
for (const auto& aval : attr_values) {
m.insert({aval.first, aval.second.proto});
}
return errors::InvalidArgument("Attr ", attr, " is not defined.");
return InstantiateFunction(fdef, m, std::move(get_function), result);
}
template <typename T>
Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
T* value) const {
return GetAttr(node.def(), attr, value);
string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) {
InstantiateAttrValueMap m;
for (const auto& aval : attrs) {
m.insert({aval.first, aval.second.proto});
}
return Canonicalize(funcname, m);
}
#define GET_ATTR(T) \
template Status FunctionLibraryDefinition::GetAttr(const Node&, \
const string&, T*) const; \
template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \
const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
Status FunctionLibraryRuntime::Instantiate(const string& function_name,
InstantiateAttrValueSlice attrs,
Handle* handle) {
InstantiateAttrValueMap m;
for (const auto& aval : attrs) {
m.insert({aval.first, aval.second.proto});
}
return Instantiate(function_name, m, handle);
}
void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
if (val.size() >= 2 && val[0] == '$') {


@ -36,7 +36,6 @@ class CancellationManager;
class OpKernel;
class ResourceMgr;
class ScopedStepContainer;
class Node;
// FunctionDefHelper::Create is a convenient helper to construct a
// FunctionDef proto.
@ -191,6 +190,11 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
// InstantiateFunction calls "get_function" to find signatures of other
// functions and primitive ops.
// Placeholders in "fdef" are substituted based on "attr_values" here.
typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
InstantiateAttrValueSlice;
// GetFunctionSignature(func name, opdef) returns OK if the func name is found
// and opdef is filled with a pointer to the corresponding signature
// (a OpDef proto). Otherwise, returns an error.
@ -202,7 +206,12 @@ struct InstantiationResult {
DataTypeVector ret_types;
GraphDef gdef;
};
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
Status InstantiateFunction(const FunctionDef& fdef,
const InstantiateAttrValueMap& attr_values,
GetFunctionSignature get_function,
InstantiationResult* result);
Status InstantiateFunction(const FunctionDef& fdef,
InstantiateAttrValueSlice attr_values,
GetFunctionSignature get_function,
InstantiationResult* result);
@ -232,7 +241,9 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
// space. But it may change as the implementation
// evolves. Therefore, it should not be persisted or compared across
// address spaces.
string Canonicalize(const string& funcname, AttrSlice attrs);
string Canonicalize(const string& funcname,
const InstantiateAttrValueMap& attrs);
string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs);
// Represents a function call frame. I.e., the data structure used to
// pass arguments to a function and retrieve its results.
@ -319,16 +330,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// Given a node def 'ndef', inspects attributes of the callee
// function to derive the attribute 'value' for 'attr'. Returns OK
// iff the attribute is given by the function's definition.
// TODO(irving): Remove; keep only the const Node& version.
template <typename T>
Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
// Given a node, inspects attributes of the callee function to derive the
// attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
// function's definition.
template <typename T>
Status GetAttr(const Node& node, const string& attr, T* value) const;
// Returns a proto representation of the state of this function library.
FunctionDefLibrary ToProto() const;
@ -371,8 +375,11 @@ class FunctionLibraryRuntime {
// Returns OK and fills in "handle" if the instantiation succeeds.
// Otherwise returns an error and "handle" is undefined.
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
virtual Status Instantiate(const string& function_name,
const InstantiateAttrValueMap& attrs,
Handle* handle) = 0;
Status Instantiate(const string& function_name,
InstantiateAttrValueSlice attrs, Handle* handle);
// Returns the function body for the instantiated function given its
// handle 'h'. Returns nullptr if "h" is not found.
@ -499,15 +506,17 @@ bool RegisterOp(const string& op, Creator func);
Status GetOpGradientCreator(const string& op, Creator* creator);
};
// Declare explicit instantiations of GetAttr
#define GET_ATTR(T) \
extern template Status FunctionLibraryDefinition::GetAttr( \
const Node&, const string&, T*) const; \
extern template Status FunctionLibraryDefinition::GetAttr( \
const NodeDef&, const string&, T*) const;
GET_ATTR(string)
GET_ATTR(bool)
#undef GET_ATTR
// Implementation details.
template <typename T>
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
const string& attr, T* value) const {
const FunctionDef* fdef = GetAttrImpl(ndef);
if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
return Status::OK();
}
return errors::InvalidArgument("Attr ", attr, " is not defined.");
}
} // end namespace tensorflow


@ -29,24 +29,6 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// A helper class to make AttrSlice from initializer lists
class Attrs {
public:
Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
std::pair<string, FunctionDefHelper::AttrValueWrapper>>
attrs) {
for (const auto& aval : attrs) {
map_.insert({aval.first, aval.second.proto});
}
}
operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
private:
AttrValueMap map_;
};
typedef FunctionDefHelper FDH;
@ -64,6 +46,8 @@ y: A scalar in type T.
)doc");
static InstantiateAttrValueMap kNoAttrs;
TEST(TFunc, SquarePlusOne) {
auto fdef = FDH::Create(
// Name
@ -97,8 +81,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
// Instantiate one with T=float
InstantiationResult result;
TF_ASSERT_OK(
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
const char* e2 = R"P(
(x:float) -> (y:float) {
a = Square[T=float](x)
@ -143,8 +126,7 @@ ControlDep(x:int32) -> (y:int32) {
// Instantiate one with T=float
InstantiationResult result;
TF_ASSERT_OK(
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
const char* e2 = R"P(
(x:int32) -> (y:int32) {
a = Identity[T=int32](x)
@ -189,7 +171,8 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
TF_ASSERT_OK(
InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
// Should get T=float from Op's default.
const char* e2 = R"P(
() -> (a:float) {
@ -226,7 +209,7 @@ NTimesT(x:float, y:float) -> (z:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
const char* e2 = R"P(
(x:float, y:float) -> (a:float) {
a = AddN[N=2, T=float](x, y)
@ -289,8 +272,8 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
// Instantiate one with T=float
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig,
&result));
const char* e2 = R"P(
(x_0:float, x_1:float, x_2:float) -> (y:float) {
a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
@ -332,7 +315,7 @@ ControlDeps(x:float) -> () {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
const char* e2 = R"P(
(x:float) -> () {
a = One[T=float]() @ x
@ -412,7 +395,7 @@ Test(i:float) -> (o:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
const char* e2 = R"P(
(i:float) -> (o:float) {
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
@ -484,7 +467,7 @@ MySelect(x:float) -> (z:float) {
EXPECT_EQ(DebugString(fdef), e);
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
const char* e2 = R"P(
(x:float) -> (z:float) {
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
@ -505,9 +488,8 @@ TEST(InstantiateErrors, Not_Sufficient_Attrs) {
auto fdef =
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
InstantiationResult result;
HasError(
InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
"Attr T is not found from ");
HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result),
"Attr T is not found from ");
}
#if 0 // TODO(josh11b): Enable this test once having an extra attr is an error.
@ -515,7 +497,7 @@ TEST(InstantiateErrors, Too_Many_Attrs) {
auto fdef =
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
InstantiationResult result;
HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
HasError(InstantiateFunction(fdef, {{"T", DT_INT32}, {"U", DT_FLOAT}},
GetOpSig, &result),
"Attr U is not found in ");
}
@ -526,7 +508,7 @@ TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
InstantiationResult result;
HasError(
InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result),
"AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
}
@ -536,15 +518,14 @@ TEST(InstantiateErrors, Unbounded_Attr) {
{{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
});
InstantiationResult result;
HasError(
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
"Failed to bind all placeholders");
HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result),
"Failed to bind all placeholders");
}
TEST(InstantiateErrors, DupArgs) {
auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Duplicated arg name");
}
@ -555,7 +536,7 @@ TEST(InstantiateErrors, Dup_Node_Names) {
{{"y"}, "One", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Duplicated ret name");
}
@ -566,7 +547,7 @@ TEST(InstantiateErrors, Node_Arg_Notfound) {
},
{});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"input z is not found");
}
@ -576,7 +557,7 @@ TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"input x[0] expected type int32 != float, the type of x[0]");
}
@ -587,7 +568,7 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
{{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"input[2] == '^z', is not found.");
}
@ -598,7 +579,7 @@ TEST(InstantiateErrors, FuncRet_Missing) {
},
{});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y missing");
}
@ -609,7 +590,7 @@ TEST(InstantiateErrors, FuncRet_NotFound) {
},
{{"y", "z"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y -> z is not found");
}
@ -620,7 +601,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) {
},
{{"z", "x:y:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y missing");
}
@ -632,7 +613,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) {
// },
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
// InstantiationResult result;
// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
// "ret is not found");
// }
@ -642,7 +623,7 @@ TEST(InstantiateErrors, FuncRet_TypeMismatch) {
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Invalid ret types y : float vs. double\n\tIn function output y");
}
@ -668,7 +649,7 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
},
{{"y", "y:output"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"type attr not found: out_types");
}
@ -695,7 +676,7 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
},
{{"y", "y:output"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Invalid ret types");
}
@ -722,7 +703,7 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
},
{{"y", "y:output"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"input unknown is not found");
}
@ -743,7 +724,7 @@ TEST(InstantiateErrors, TooManyInputs) {
{{"z", "a:sum:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Expected input[2] == 'x' to be a control input.");
}
@ -764,7 +745,7 @@ TEST(InstantiateErrors, TooFewInputs) {
{{"z", "a:sum:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Attempt to access beyond input size: 2 >= 2");
}
@ -792,7 +773,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray1) {
{{"z", "a:sum:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Expected input[1] == 'y' to be a control input.");
}
@ -820,7 +801,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray2) {
{{"z", "a:sum:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Input a:output too long for inputs");
}
@ -841,7 +822,7 @@ TEST(InstantiateErrors, TypeMismatch) {
{{"z", "a:sum:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"input inputs[1] expected type float != int32, the type of y[0]");
}
@ -893,17 +874,17 @@ TEST(FunctionCallFrame, Float_Float_Float) {
}
TEST(Canonicalize, Basic) {
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
{"transpose_a", false},
{"transpose_b", false}})),
EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
{"transpose_a", false},
{"transpose_b", false}}),
"MatMul[T=float,transpose_a=false,transpose_b=false]");
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
{"transpose_b", false},
{"transpose_a", false}})),
EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
{"transpose_b", false},
{"transpose_a", false}}),
"MatMul[T=float,transpose_a=false,transpose_b=false]");
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE},
{"transpose_b", true},
{"transpose_a", false}})),
EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE},
{"transpose_b", true},
{"transpose_a", false}}),
"MatMul[T=double,transpose_a=false,transpose_b=true]");
}
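
These expectations pin down the instantiation cache key format: the function name followed by its attrs sorted by name. A Python sketch that reproduces the strings checked here, with value formatting simplified to the cases shown:

```python
def canonicalize(funcname, attrs):
  """Build 'Name[a1=v1,a2=v2]' with attrs sorted by attr name."""
  def fmt(value):
    if isinstance(value, bool):
      return 'true' if value else 'false'
    return str(value)
  entries = sorted('%s=%s' % (k, fmt(v)) for k, v in attrs.items())
  return '%s[%s]' % (funcname, ','.join(entries))

print(canonicalize('MatMul', {'T': 'float',
                              'transpose_a': False,
                              'transpose_b': False}))
# MatMul[T=float,transpose_a=false,transpose_b=false]
```
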
@ -1167,5 +1148,4 @@ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
}
} // end namespace
} // end namespace tensorflow


@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/scanner.h"
@ -37,23 +36,18 @@ namespace tensorflow {
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
AttrSlice::AttrSlice() : ndef_(nullptr) {
static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
attrs_ = kEmptyAttrValueMap;
}
AttrSlice::AttrSlice(const NodeDef& node_def)
: ndef_(&node_def), attrs_(&ndef_->attr()) {}
AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
string ret;
string SummarizeNodeDef(const NodeDef& node_def) {
string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
// We sort the attrs so the output is deterministic.
std::vector<string> attr_names;
attr_names.reserve(attrs.size());
for (const auto& attr : attrs) {
attr_names.reserve(node_def.attr().size());
for (const auto& attr : node_def.attr()) {
attr_names.push_back(attr.first);
}
std::sort(attr_names.begin(), attr_names.end());
@ -61,34 +55,20 @@ static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
for (const string& attr_name : attr_names) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
strings::StrAppend(&ret, attr_name, "=",
SummarizeAttrValue(*attrs.Find(attr_name)));
auto iter = node_def.attr().find(attr_name);
strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second));
}
// Consider the device to be a final attr with name "_device".
if (!device.empty()) {
if (!node_def.device().empty()) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
strings::StrAppend(&ret, "_device=\"", device, "\"");
strings::StrAppend(&ret, "_device=\"", node_def.device(), "\"");
}
return ret;
}
string AttrSlice::SummarizeNode() const {
return ndef_ ? SummarizeNodeDef(*ndef_)
: strings::StrCat(
"[", SummarizeAttrsHelper(*this, StringPiece()), "]");
}
string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
string SummarizeNodeDef(const NodeDef& node_def) {
string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
strings::StrAppend(&ret, "](");
// Output inputs, including control inputs, verbatim.
bool first = true;
first = true;
for (const string& input : node_def.input()) {
if (!first) strings::StrAppend(&ret, ", ");
first = false;
@ -129,28 +109,12 @@ Status AttrSlice::Find(StringPiece attr_name,
// Skip AttachDef for internal attrs since it is a little bit
// expensive and it is common for them to correctly not be included
// in a NodeDef.
if (!attr_name.starts_with("_") && ndef_ != nullptr) {
if (!StringPiece(attr_name).starts_with("_") && ndef_) {
s = AttachDef(s, *ndef_);
}
return s;
}
bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
if (size() != other.size()) return false;
for (const auto& attr : *other.attrs_) {
auto iter = attrs_->find(attr.first);
if (iter == attrs_->end()) return false;
// TODO(irving): Comparing AttrValues by proto is slightly buggy, since
// TensorProto is a nonunique representation of Tensor. This bug will go
// away once AttrSlice switches over to NodeInfo.
iter->second.SerializeToString(&scratch->a);
attr.second.SerializeToString(&scratch->b);
if (scratch->a != scratch->b) return false;
}
return true;
}
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
@ -377,14 +341,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
if (StringPiece(input).starts_with("^")) {
seen_control = true;
if (input.find(':') != string::npos) {
return errors::InvalidArgument(
"Control input '", input,
"' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def));
return errors::InvalidArgument("Control input '", input,
"' must not have ':' in NodeDef: ",
SummarizeNodeDef(node_def));
}
} else if (seen_control) {
return errors::InvalidArgument(
"Non-control input '", input,
"' after control input in NodeDef: ", SummarizeNodeDef(node_def));
return errors::InvalidArgument("Non-control input '", input,
"' after control input in NodeDef: ",
SummarizeNodeDef(node_def));
} else {
++num_inputs;
}
@ -394,8 +358,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
for (const auto& attr : op_def.attr()) {
if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
return errors::InvalidArgument("OpDef has duplicate attr name '",
attr.name(),
"': ", SummarizeOpDef(op_def));
attr.name(), "': ",
SummarizeOpDef(op_def));
}
}
for (const auto& attr : node_def.attr()) {
@ -419,9 +383,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
"with your GraphDef-generating binary.).");
}
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ValidateAttrValue(attr.second, *iter->second),
"; NodeDef: ", SummarizeNodeDef(node_def), "; ",
SummarizeOpDef(op_def));
ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ",
SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def));
// Keep track of which attr names have (not) been found in the NodeDef.
op_attrs.erase(iter);
}
@ -468,9 +431,9 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
} else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
*num = 1;
} else {
return errors::InvalidArgument(
"Argument '", arg_def.name(),
"' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
return errors::InvalidArgument("Argument '", arg_def.name(),
"' incorrectly specified in op definition: ",
SummarizeOpDef(op_def));
}
return Status::OK();
}
@ -502,11 +465,6 @@ Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
return Status::OK();
}
Status NameRangesForNode(const Node& node, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs) {
return NameRangesForNode(node.def(), op_def, inputs, outputs);
}
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
for (const auto& attr_def : op_def.attr()) {
AttrSlice attrs(*node_def);
@ -607,8 +565,4 @@ Status AttachDef(const Status& status, const NodeDef& node_def) {
return ret;
}
Status AttachDef(const Status& status, const Node& node) {
return AttachDef(status, node.def());
}
} // namespace tensorflow
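The node_def_util helpers above (SummarizeNodeDef, AttrSlice, GetNodeAttr) are the primary way the rest of the codebase reads attrs off a NodeDef. The following standalone sketch is not part of this commit; it assumes a TensorFlow build environment, and the node name, op, and attr values are made up for illustration.

// Hypothetical usage sketch for the NodeDef attr helpers above.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void NodeDefAttrExample() {
  NodeDef ndef;
  ndef.set_name("matmul_example");
  ndef.set_op("MatMul");
  (*ndef.mutable_attr())["transpose_a"].set_b(false);
  (*ndef.mutable_attr())["transpose_b"].set_b(true);

  // Produces the concise "name = Op[attrs](inputs)" form described above.
  LOG(INFO) << SummarizeNodeDef(ndef);

  // Typed attr lookup; AttrSlice is constructible from a NodeDef.
  bool transpose_a = true;
  TF_CHECK_OK(GetNodeAttr(AttrSlice(ndef), "transpose_a", &transpose_a));
  LOG(INFO) << "transpose_a=" << transpose_a;
}

}  // namespace tensorflow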

View File

@ -29,8 +29,6 @@ limitations under the License.
namespace tensorflow {
class Node;
// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. Desire for explicit co-location
@ -41,9 +39,8 @@ extern const char* const kColocationAttrName;
// String prefix applied to the operation name for colocation constraints.
extern const char* const kColocationGroupPrefix;
// Produce a human-readable version of a Node or NodeDef that is more concise
// Produce a human-readable version of a NodeDef that is more concise
// than a text-format proto.
string SummarizeNode(const Node& node);
string SummarizeNodeDef(const NodeDef& node_def);
typedef protobuf::Map<string, AttrValue> AttrValueMap;
@ -81,11 +78,8 @@ class AttrSlice {
public:
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
AttrSlice(); // Empty
explicit AttrSlice(const AttrValueMap* a);
int size() const { return attrs_->size(); }
// Returns the attr with attr_name if found. Otherwise, returns
// nullptr.
const AttrValue* Find(StringPiece attr_name) const;
@ -94,33 +88,6 @@ class AttrSlice {
// NotFound status.
Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
// Helper class to avoid allocations in EqualAttrs.
// TODO(irving): Will go away once NodeInfo is used.
struct Scratch {
string a;
string b;
};
// Check if all attrs and attr values match. Does not take defaults into
// account.
//
// TODO(irving): There is a bug in this routine inherited from its
// OptimizerCSE::EqualAttrs predecessor. The same tensor attr can be
// represented in more than one way as an AttrValue, since TensorProto is
// not 1-1. This bug will go away once I replace everything with NodeInfo,
// which stores a Tensor object directly. The Scratch object will also go
// away.
bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
// If this AttrSlice has an attached NodeDef, summarize it. This is for
// error messages only: we intentionally do not provide direct access to the
// NodeDef, since it is not always there.
string SummarizeNode() const;
// Iteration over all attrs
AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
AttrValueMap::const_iterator end() const { return attrs_->end(); }
private:
const NodeDef* ndef_;
const AttrValueMap* attrs_;
@ -216,12 +183,9 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
// corresponding input/output index range. For example,
// input "foo" corresponds to input indices
// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
// TODO(irving): Remove the NodeDef version; keep only the Node version.
typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);
Status NameRangesForNode(const Node& node, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);
// Adds default values to *node_def for unspecified attrs from op_def.
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
@ -242,7 +206,6 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
// Returns "status" with kernel's NodeDef attached as additional text
// in the error message.
Status AttachDef(const Status& status, const NodeDef& node_def);
Status AttachDef(const Status& status, const Node& node);
} // namespace tensorflow
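For reference, a hedged sketch of how the NodeDef overload of NameRangesForNode kept above is typically driven, with the OpDef looked up in the global registry first. The function name is illustrative; everything else uses APIs that appear in this header.

// Hypothetical sketch: print output name ranges for a NodeDef.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

Status PrintOutputRanges(const NodeDef& node_def) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));

  NameRangeMap outputs;
  TF_RETURN_IF_ERROR(
      NameRangesForNode(node_def, *op_def, /*inputs=*/nullptr, &outputs));
  for (const auto& entry : outputs) {
    // entry.second is the half-open output index range [first, second).
    LOG(INFO) << entry.first << ": [" << entry.second.first << ", "
              << entry.second.second << ")";
  }
  return Status::OK();
}

}  // namespace tensorflow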

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@ -843,10 +842,13 @@ bool InTypeList(DataType dt, const AttrValue& type_list) {
return false;
}
// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
// an error if attrs in kernel_def are not found, or have a mismatching type.
Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
// Returns whether the attrs in the NodeDef satisfy the constraints in
// the kernel_def. Returns an error if attrs in kernel_def are not
// found, or have a mismatching type.
Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
bool* match) {
*match = false;
AttrSlice attrs(node_def);
for (const auto& constraint : kernel_def.constraint()) {
if (constraint.allowed_values().list().type_size() == 0) {
return errors::Unimplemented(
@ -870,7 +872,7 @@ Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
"' that has value '", SummarizeAttrValue(*found),
"' that does not have type 'type' or 'list(type)' in NodeDef "
"'",
attrs.SummarizeNode(), "'");
SummarizeNodeDef(node_def), "'");
}
for (int t : found->list().type()) {
@ -883,7 +885,7 @@ Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
} else {
return errors::InvalidArgument(
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def),
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
}
}
@ -893,7 +895,6 @@ Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
static const StringPiece kKernelAttr("_kernel");
// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(const DeviceType& device_type,
const NodeDef& node_def,
const KernelRegistration** reg,
@ -926,16 +927,8 @@ Status FindKernelRegistration(const DeviceType& device_type,
return Status::OK();
}
Status FindKernelRegistration(const DeviceType& device_type, const Node& node,
const KernelRegistration** reg,
bool* was_attr_mismatch) {
return FindKernelRegistration(device_type, node.def(), reg,
was_attr_mismatch);
}
} // namespace
// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr;
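As a rough illustration of the NodeDef-based kernel lookup kept in this file, the sketch below queries the registry for the CPU kernel that would be chosen for a node. It assumes a TensorFlow build with the standard "Identity" op and kernel registered; the node contents are made up.

// Hypothetical sketch of FindKernelDef.
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

Status LookUpIdentityKernel() {
  NodeDef ndef;
  ndef.set_name("identity_example");
  ndef.set_op("Identity");
  (*ndef.mutable_attr())["T"].set_type(DT_FLOAT);

  const KernelDef* kdef = nullptr;
  string kernel_class_name;
  TF_RETURN_IF_ERROR(
      FindKernelDef(DeviceType(DEVICE_CPU), ndef, &kdef, &kernel_class_name));
  LOG(INFO) << "kernel class: " << kernel_class_name;
  return Status::OK();
}

}  // namespace tensorflow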

View File

@ -184,8 +184,8 @@ class InferenceContext {
}
#ifndef NDEBUG
for (int i = 0; i < num_outputs(); ++i) {
DCHECK(output(i).IsSet())
<< i << " for " << node_def_.name() << " of type " << node_def_.op();
DCHECK(output(i).IsSet()) << i << " for " << node_def().name()
<< " of type " << node_def().op();
}
#endif // NDEBUG
return s;
@ -394,6 +394,11 @@ class InferenceContext {
// the value.
Status MakeDimForScalarInput(int idx, DimensionHandle* out);
// Returns the NodeDef. The returned reference does not outlive the
// InferenceContext, and it should not be used after InferenceContext is
// destroyed.
const NodeDef& node_def() { return node_def_; }
// Look up the attr for the NodeDef being evaluated with name attr_name and
// set *value to its value. If no attr with attr_name is found in def(), or
// the attr does not have a matching type, a non-ok status will be returned.
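The node_def() accessor documented above is mainly consumed from shape functions. A hedged sketch of such a function for a hypothetical op with an integer "rank" attr (the op and attr name are made up) might look like this.

// Hypothetical shape function using node_def() and GetAttr.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

Status ExampleShapeFn(shape_inference::InferenceContext* c) {
  int32 rank = 0;
  // GetAttr reads from the NodeDef exposed via c->node_def().
  TF_RETURN_IF_ERROR(c->GetAttr("rank", &rank));
  VLOG(1) << "inferring shapes for " << c->node_def().name() << " ("
          << c->node_def().op() << ")";
  c->set_output(0, c->UnknownShapeOfRank(rank));
  return Status::OK();
}

}  // namespace tensorflow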

View File

@ -88,7 +88,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
out_info->frame = out;
out_info->parent_frame = frame;
TF_RETURN_IF_ERROR(
GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
if (out_info->frame_name.empty()) {
return errors::InvalidArgument("The Enter node ", out->name(),
" must have a frame name.");

View File

@ -78,7 +78,7 @@ string Node::DebugString() const {
} else {
strings::StrAppend(&ret, " op device:");
strings::StrAppend(&ret, "{", assigned_device_name_, "}");
strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
}
return ret;
}
@ -474,7 +474,7 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
for (size_t i = 0; i < inputs.size(); ++i) {
const Edge* edge = inputs[i];
if (edge == nullptr) {
node_def->add_input(node->requested_inputs()[i]);
node_def->add_input(node->def().input(i));
} else {
const Node* src = edge->src();
if (!src->IsOp()) continue;

View File

@ -71,7 +71,6 @@ class Node {
int cost_id() const { return cost_id_; }
const string& name() const { return props_->node_def_.name(); }
const string& type_string() const { return props_->node_def_.op(); }
// def() provides the NodeDef the user supplied, but the specifics
// of this Node may have changed due to placement, optimization, etc.
// In particular:
@ -81,7 +80,6 @@ class Node {
// * def().device() is the "user's requested device" and may not match
// the actual assigned device, see assigned_device_name() below;
// * def().attr() is authoritative.
// TODO(irving): Replace with NodeInfo.
const NodeDef& def() const { return props_->node_def_; }
const OpDef& op_def() const { return *props_->op_def_; }
@ -94,10 +92,6 @@ class Node {
DataType output_type(int32 o) const { return props_->output_types_[o]; }
const DataTypeVector& output_types() const { return props_->output_types_; }
// The device requested by the user. For the actual assigned device,
// use assigned_device_name() below.
const string& requested_device() const { return def().device(); }
// This gives the device the runtime has assigned this node to. If
// you want the device the user requested, use def().device() instead.
// TODO(josh11b): Validate that the assigned_device, if not empty:
@ -109,14 +103,6 @@ class Node {
assigned_device_name_ = device_name;
}
// Read only access to attributes
AttrSlice attrs() const { return AttrSlice(def()); }
// Inputs requested by the NodeDef. For the actual inputs, use in_edges.
const protobuf::RepeatedPtrField<string>& requested_inputs() const {
return def().input();
}
// Get the neighboring nodes via edges either in or out of this node.
gtl::iterator_range<NeighborIter> in_nodes() const;
gtl::iterator_range<NeighborIter> out_nodes() const;

View File

@ -424,7 +424,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
// For nodes with the _output_shapes attribute, override the shape.
std::vector<TensorShapeProto> shape_attrs;
const char* kAttrName = "_output_shapes";
if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
if (!GetNodeAttr(node->def(), kAttrName, &shape_attrs).ok()) {
// No _output_shapes attribute, the AddNode call above was sufficient.
return Status::OK();
}
@ -458,7 +458,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
// functions that are not critical to correct execution but
// would cause graphs to fail if imported after correcting.
//
const string& op = node->type_string();
const string& op = node->def().op();
const std::vector<string> whitelist = {
// To be removed after 2017/03/08.
"RandomShuffleQueue", "PaddingFIFOQueue", "FIFOQueue",

View File

@ -146,7 +146,7 @@ class GraphConstructorTest : public ::testing::Test {
return "";
}
std::vector<string> value;
Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value);
Status s = GetNodeAttr(n->def(), kColocationAttrName, &value);
if (!s.ok()) {
return "";
}
@ -997,7 +997,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) {
}
ASSERT_TRUE(a != nullptr);
int value = 0;
s = GetNodeAttr(a->attrs(), "default_int", &value);
s = GetNodeAttr(a->def(), "default_int", &value);
ASSERT_EQ(Status::OK(), s) << s << " -- " << a->def().DebugString();
EXPECT_EQ(31415, value);
}
@ -1201,9 +1201,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) {
// Check that t1's NodeDef is consistent with graph
Node* t1 = FindNode("t1");
ASSERT_EQ(t1->requested_inputs().size(), 2);
ASSERT_EQ(t1->requested_inputs()[0], "input:1");
ASSERT_EQ(t1->requested_inputs()[1], "input:0");
ASSERT_EQ(t1->def().input_size(), 2);
ASSERT_EQ(t1->def().input(0), "input:1");
ASSERT_EQ(t1->def().input(1), "input:0");
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
@ -1254,19 +1254,19 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
// Check that NodeDefs are consistent with graph
Node* t1 = FindNode("import/t1");
ASSERT_EQ(t1->requested_inputs().size(), 2);
EXPECT_EQ(t1->requested_inputs()[0], "input:0");
EXPECT_EQ(t1->requested_inputs()[1], "input:0");
ASSERT_EQ(t1->def().input_size(), 2);
EXPECT_EQ(t1->def().input(0), "input:0");
EXPECT_EQ(t1->def().input(1), "input:0");
Node* t2 = FindNode("import/t2");
ASSERT_EQ(t2->requested_inputs().size(), 2);
EXPECT_EQ(t2->requested_inputs()[0], "import/t1:0");
EXPECT_EQ(t2->requested_inputs()[1], "import/t1:0");
ASSERT_EQ(t2->def().input_size(), 2);
EXPECT_EQ(t2->def().input(0), "import/t1:0");
EXPECT_EQ(t2->def().input(1), "import/t1:0");
Node* t3 = FindNode("import/t3");
ASSERT_EQ(t3->requested_inputs().size(), 2);
EXPECT_EQ(t3->requested_inputs()[0], "import/unmapped_input:0");
EXPECT_EQ(t3->requested_inputs()[1], "import/unmapped_input:1");
ASSERT_EQ(t3->def().input_size(), 2);
EXPECT_EQ(t3->def().input(0), "import/unmapped_input:0");
EXPECT_EQ(t3->def().input(1), "import/unmapped_input:1");
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
@ -1795,24 +1795,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) {
// Test that node defs are consistent with graph
Node* w1 = FindNode("import/W1");
ASSERT_EQ(w1->requested_inputs().size(), 2);
EXPECT_EQ(w1->requested_inputs()[0], "^W1");
EXPECT_EQ(w1->requested_inputs()[1], "^W2");
ASSERT_EQ(w1->def().input_size(), 2);
EXPECT_EQ(w1->def().input(0), "^W1");
EXPECT_EQ(w1->def().input(1), "^W2");
Node* input = FindNode("import/input");
ASSERT_EQ(input->requested_inputs().size(), 2);
EXPECT_EQ(input->requested_inputs()[0], "^W1");
EXPECT_EQ(input->requested_inputs()[1], "^W2");
ASSERT_EQ(input->def().input_size(), 2);
EXPECT_EQ(input->def().input(0), "^W1");
EXPECT_EQ(input->def().input(1), "^W2");
Node* input2 = FindNode("import/input2");
ASSERT_EQ(input2->requested_inputs().size(), 2);
EXPECT_EQ(input2->requested_inputs()[0], "^W1");
EXPECT_EQ(input2->requested_inputs()[1], "^W2");
ASSERT_EQ(input2->def().input_size(), 2);
EXPECT_EQ(input2->def().input(0), "^W1");
EXPECT_EQ(input2->def().input(1), "^W2");
Node* t1 = FindNode("import/t1");
ASSERT_EQ(t1->requested_inputs().size(), 2);
EXPECT_EQ(t1->requested_inputs()[0], "import/input:0");
EXPECT_EQ(t1->requested_inputs()[1], "import/input:1");
ASSERT_EQ(t1->def().input_size(), 2);
EXPECT_EQ(t1->def().input(0), "import/input:0");
EXPECT_EQ(t1->def().input(1), "import/input:1");
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
@ -1856,15 +1856,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
// Test that node defs are consistent with graph
Node* merge = FindNode("merge");
ASSERT_EQ(merge->requested_inputs().size(), 3);
EXPECT_EQ(merge->requested_inputs()[0], "input:0");
EXPECT_EQ(merge->requested_inputs()[1], "t1:0");
EXPECT_EQ(merge->requested_inputs()[2], "^W1");
ASSERT_EQ(merge->def().input_size(), 3);
EXPECT_EQ(merge->def().input(0), "input:0");
EXPECT_EQ(merge->def().input(1), "t1:0");
EXPECT_EQ(merge->def().input(2), "^W1");
Node* t1 = FindNode("t1");
ASSERT_EQ(t1->requested_inputs().size(), 2);
EXPECT_EQ(t1->requested_inputs()[0], "merge:0");
EXPECT_EQ(t1->requested_inputs()[1], "merge:0");
ASSERT_EQ(t1->def().input_size(), 2);
EXPECT_EQ(t1->def().input(0), "merge:0");
EXPECT_EQ(t1->def().input(1), "merge:0");
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) {

View File

@ -356,7 +356,7 @@ string ControlLoopName(const string& name) {
}
bool IsControlLoop(const Node* node) {
const string& name = node->name();
const string& name = node->def().name();
return StringPiece(name).starts_with("_cloop");
}
@ -468,7 +468,7 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
const string& device_name = edge->dst()->assigned_device_name();
const string& frame_name = src_info.frame_name;
int parallel_iterations;
status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
status = GetNodeAttr(src_info.frame->def(), "parallel_iterations",
&parallel_iterations);
if (!status.ok()) return status;
@ -903,11 +903,11 @@ Status Partition(const PartitionOptions& opts, Graph* g,
send_start_time = opts.start_times[src->id()].value();
recv_start_time = opts.start_times[dst->id()].value();
} else {
status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
status = GetNodeAttr(src->def(), "_start_time", &send_start_time);
if (!status.ok()) {
return status;
}
status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time);
if (!status.ok()) {
return status;
}

View File

@ -318,21 +318,21 @@ TEST_F(GraphTest, AddAttr) {
n1->AddAttr("_a", "new_attr");
string attr;
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr));
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr));
EXPECT_EQ("new_attr", attr);
Node* n2 = graph_.CopyNode(n1);
n1->AddAttr("_b", "new_attr_2");
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr));
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr));
EXPECT_EQ("new_attr", attr);
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_b", &attr));
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_b", &attr));
EXPECT_EQ("new_attr_2", attr);
EXPECT_EQ(Status::OK(), GetNodeAttr(n2->attrs(), "_a", &attr));
EXPECT_EQ(Status::OK(), GetNodeAttr(n2->def(), "_a", &attr));
EXPECT_EQ("new_attr", attr);
EXPECT_NE(Status::OK(), GetNodeAttr(n2->attrs(), "_b", &attr));
EXPECT_NE(Status::OK(), GetNodeAttr(n2->def(), "_b", &attr));
}
// Convert edge iteration results into a sorted string.

View File

@ -56,9 +56,11 @@ class OptimizerCSE {
bool Optimize(const std::function<bool(const Node*)>& consider_fn);
private:
struct Scratch;
static size_t NodeHash(const Node* n);
static bool Equivalent(const Node* a, const Node* b,
AttrSlice::Scratch* scratch);
static bool Equivalent(const Node* a, const Node* b, Scratch* s);
static bool EqualAttrs(const Node* a, const Node* b, Scratch* s);
Graph* g_;
};
@ -108,7 +110,7 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
// Hash the attrs. For example, this makes sure different constants
// end up in different hash buckets.
string tmp;
for (const auto& attr : n->attrs()) {
for (const auto& attr : n->def().attr()) {
tmp = attr.first;
attr.second.AppendToString(&tmp);
// Add hashes of attrs, so the order of attrs doesn't matter.
@ -120,6 +122,28 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
return h;
}
struct OptimizerCSE::Scratch {
// For EqualAttrs():
string a;
string b;
};
bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) {
if (a->def().attr_size() != b->def().attr_size()) return false;
for (const auto& attr : b->def().attr()) {
auto iter = a->def().attr().find(attr.first);
if (iter == a->def().attr().end()) return false;
// Note: it should be safe to compare proto serializations of the attr
// values since at most one field should be set in each (indeed, it
// should be the same field).
iter->second.SerializeToString(&scratch->a);
attr.second.SerializeToString(&scratch->b);
if (scratch->a != scratch->b) return false;
}
return true;
}
static bool HasRefInput(const Node* n) {
for (auto dt : n->input_types()) {
if (IsRefType(dt)) return true;
@ -127,8 +151,7 @@ static bool HasRefInput(const Node* n) {
return false;
}
bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
AttrSlice::Scratch* scratch) {
bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
// Different op names are different
if (a->type_string() != b->type_string()) return false;
@ -141,7 +164,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
// Compare attrs. Note that equal attrs implies equal input and
// output types.
if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false;
if (!EqualAttrs(a, b, scratch)) return false;
// Compare input sources
if (a->num_inputs() != b->num_inputs()) return false;
@ -183,7 +206,7 @@ bool OptimizerCSE::Optimize(
// Scratch space for Equivalent calls. Allocated here and passed in to
// Equivalent to avoid allocation inside the loop below.
bool changed = false;
AttrSlice::Scratch scratch;
Scratch scratch;
for (Node* n : order) {
if (!n->IsOp()) continue;

View File

@ -192,9 +192,9 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
Tensor tensor_names;
Tensor shape_and_slices;
TF_RETURN_IF_ERROR(
GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
TF_RETURN_IF_ERROR(
GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names));
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value",
&shape_and_slices));
int tn_size = tensor_names.NumElements();
int var_size = added_variables.size();

View File

@ -112,15 +112,17 @@ TEST_F(QuantizeTrainingTest, SignedInput) {
TF_ASSERT_OK(
FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
&identity_q_node));
NodeDef identity_q = identity_q_node->def();
ASSERT_EQ("true",
SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input")));
SummarizeAttrValue(identity_q.attr().find("signed_input")->second));
// Quantize_and_dequantize node for relu should have signed_input==false.
Node* relu_q_node;
TF_ASSERT_OK(
FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
&relu_q_node));
NodeDef relu_q = relu_q_node->def();
ASSERT_EQ("false",
SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input")));
SummarizeAttrValue(relu_q.attr().find("signed_input")->second));
}
TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
@ -163,15 +165,17 @@ TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
TF_ASSERT_OK(
FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
&relu6_q_node));
NodeDef identity_q = relu6_q_node->def();
ASSERT_EQ("true",
SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given")));
SummarizeAttrValue(identity_q.attr().find("range_given")->second));
// Quantize_and_dequantize node for relu should have range_given==true.
Node* relu_q_node;
TF_ASSERT_OK(
FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
&relu_q_node));
NodeDef relu_q = relu_q_node->def();
ASSERT_EQ("true",
SummarizeAttrValue(*relu_q_node->attrs().Find("range_given")));
SummarizeAttrValue(relu_q.attr().find("range_given")->second));
}
TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) {

View File

@ -106,7 +106,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
// Copy the _output_shapes from the original node to the feed node,
// if any.
std::vector<PartialTensorShape> output_shapes;
if (GetNodeAttr(n->attrs(), "_output_shapes", &output_shapes).ok()) {
if (GetNodeAttr(n->def(), "_output_shapes", &output_shapes).ok()) {
if (n->num_outputs() != output_shapes.size()) {
return errors::InvalidArgument(
"FeedInputs: ", t,
@ -129,8 +129,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
if (e->src_output() == id.second) {
to_remove.emplace_back(e);
} else if (e->src_output() == Graph::kControlSlot &&
(n->type_string() == "Placeholder" ||
n->type_string() == "PlaceholderV2")) {
(n->def().op() == "Placeholder" ||
n->def().op() == "PlaceholderV2")) {
// When feeding a Placeholder node, any outgoing control edges
// will be replaced with a control edge from the replacement
// recv_node.

View File

@ -81,7 +81,7 @@ class SubgraphTest : public ::testing::Test {
for (const string& s : expected_nodes) {
Node* n = FindNode(s);
EXPECT_TRUE(n != nullptr) << s;
if (n->type_string() == "_Send" || n->type_string() == "_Recv") {
if (n->def().op() == "_Send" || n->def().op() == "_Recv") {
EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
}
}
@ -367,7 +367,7 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
for (Node* node : graph()->nodes()) {
if (node->name() == "_recv_input_1") {
std::vector<PartialTensorShape> shapes;
TF_ASSERT_OK(GetNodeAttr(node->attrs(), "_output_shapes", &shapes));
TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes));
ASSERT_EQ(1, shapes.size());
EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0]));
break;

View File

@ -48,6 +48,12 @@ class Cluster {
// of the requested resources are available.
virtual Status Provision() = 0;
// Attempts to shutdown the cluster.
// Returns OK iff there are no pending calls to the Run() method and all the
// resources used by the cluster could be released. Returns an error
// otherwise.
virtual Status Shutdown() { return Status::OK(); }
// Whether soft placement is allowed. If allow_soft_placement is true,
// an op will be placed on CPU if there's no GPU implementation for the OP
// or if no GPU devices are known or registered or if we need to co-locate
@ -58,7 +64,8 @@ class Cluster {
// before Provision().
void SetNumWarmupSteps(int num_steps);
// Disable the collection of detailed statistics.
// Disable the collection of detailed statistics. Must be called
// before Provision().
void DisableDetailedStats(bool disable);
// Return the list of TensorFlow devices that are available to execute a

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
@ -91,6 +92,31 @@ Status SingleMachine::Initialize(const GrapplerItem& item) {
return Status::OK();
}
Status SingleMachine::Shutdown() {
TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/));
// Delete the threadpool: this ensures that all the pending closures complete
// before we return. Note that if TF deadlocked on us, the closures
// will never complete, and the call to thread_pool_.reset() will never
// return: therefore we need to delete the threadpool with the background
// thread. That thread itself will also never complete, so the user should
// abort the process to avoid leaking too many resources.
auto n = std::make_shared<Notification>();
Env::Default()->SchedClosure([this, n]() {
thread_pool_.reset();
n->Notify();
});
int64 timeout_us = 1000000ll * timeout_s_;
const bool notified = WaitForNotificationWithTimeout(n.get(), timeout_us);
if (!notified) {
// Let the caller know that we can't shut down the session properly since
// there are calls to Session::Run() still running.
return errors::Unavailable("The session is still running graphs after ",
timeout_s_, " seconds");
}
return Status::OK();
}
Status SingleMachine::Run(const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch,
@ -163,10 +189,11 @@ Status SingleMachine::RunWithTimeout(
mutex_lock l(close_mu_);
CHECK(!closing_);
}
auto status = std::make_shared<Status>();
auto local_metadata = std::make_shared<RunMetadata>();
const bool executed_in_time = ExecuteWithTimeout(
[this, status, local_metadata, &feed, &fetch]() {
[this, status, local_metadata, feed, fetch]() {
*status = session_->Run(run_options_, feed, {}, fetch, nullptr,
local_metadata.get());
},
@ -230,11 +257,7 @@ Status SingleMachine::ResetSession() {
LOG(INFO) << "Cleaning up previous session";
// Make sure the session is properly closed
TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/));
// Flush all the pending closures (if any).
thread_pool_.reset(new thread::ThreadPool(
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
TF_RETURN_IF_ERROR(Shutdown());
// We need to Reset the session to ensure that all the variables are
// deleted. But first we need to delete the session since Reset()
@ -245,6 +268,10 @@ Status SingleMachine::ResetSession() {
LOG(INFO) << "Starting new session";
// Create a new threadpool
thread_pool_.reset(new thread::ThreadPool(
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
session_.reset(NewSession(options_));
CHECK(session_ != nullptr);

View File

@ -33,6 +33,8 @@ class SingleMachine : public Cluster {
~SingleMachine() override;
Status Provision() override;
Status Shutdown() override;
Status Initialize(const GrapplerItem& item) override;
Status Run(const GraphDef& item,
const std::vector<std::pair<string, Tensor>>& feed,

View File

@ -122,6 +122,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
] + if_cuda([
"//tensorflow/core:cuda",
@ -226,7 +227,6 @@ cc_library(
":utils",
":virtual_placer",
":virtual_scheduler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace grappler {

View File

@ -151,8 +151,8 @@ Status GraphProperties::InferStatically() {
if (!node->assigned_device_name().empty()) {
device_names_[node->name()] = node->assigned_device_name();
} else if (!node->requested_device().empty()) {
device_names_[node->name()] = node->requested_device();
} else if (!node->def().device().empty()) {
device_names_[node->name()] = node->def().device();
} else {
device_names_[node->name()] = "not set";
}

View File

@ -26,6 +26,9 @@ limitations under the License.
#include "cuda/include/cudnn.h"
#endif
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
@ -34,6 +37,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -64,6 +69,77 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
return tensors;
}
static void ExtractExtraProperties(
const NodeDef& node,
const std::unordered_map<string, const NodeDef*>& name_to_node,
std::vector<OpInfo::TensorProperties>* extra_inputs,
protobuf::Map<string, AttrValue>* attr_map) {
OpRegistry* op_registry = OpRegistry::Global();
const OpDef* op_def;
auto s = op_registry->LookUpOpDef(node.op(), &op_def);
if (!s.ok()) {
op_def = nullptr;
}
for (int i = 0; i < node.input_size(); ++i) {
const string input_name = node.input(i);
CHECK(!input_name.empty());
TensorId input_tensor_id = ParseTensorName(input_name);
const string input_node_name = input_tensor_id.first.ToString();
auto iter = name_to_node.find(input_node_name);
if (iter == name_to_node.end()) continue;
const NodeDef* input_node = iter->second;
// The value attribute in Const input is useful for cost prediction.
if (input_node->op() == "Const") {
auto it = input_node->attr().find("value");
if (it == input_node->attr().end()) continue;
const AttrValue& attr_value = it->second;
std::vector<TensorProto> tensors = ExtractTensors(attr_value);
if (tensors.empty()) continue;
const TensorProto& t = tensors[0];
OpInfo::TensorProperties input;
input.set_dtype(t.dtype());
*(input.mutable_shape()) = t.tensor_shape();
*(input.mutable_value()) = t;
extra_inputs->push_back(input);
// For filename input, the file size can also be useful.
if (op_def &&
op_def->input_arg(i).name().find("filename") != std::string::npos) {
Tensor tensor;
CHECK(tensor.FromProto(t));
const string filename = tensor.scalar<string>()();
Env* env = Env::Default();
FileStatistics stat;
Status s = env->Stat(filename, &stat);
if (s.ok()) {
AttrValue attr;
attr.set_i(stat.length);
string attr_key = strings::StrCat("input_", i, "_filesize");
(*attr_map)[attr_key] = attr;
}
}
}
// When the input is a handle (e.g. a lookup table handle), the information
// in the op itself is not sufficient to predict the op memory.
if (op_def &&
op_def->input_arg(i).name().find("handle") != std::string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
(*attr_map)[new_key] = attr;
// TODO(yuefengz): Only parent node's op name is copied. Copy inputs
// and attributes when necessary.
}
}
}
std::vector<OpInfo::TensorProperties> FindInputFeatures(
const NodeDef& node,
const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost,
@ -80,35 +156,6 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
continue;
}
auto iter = name_to_node.find(input_node_name);
if (iter != name_to_node.end()) {
const NodeDef* node = iter->second;
if (node->op() == "Const") {
auto it = node->attr().find("value");
if (it == node->attr().end()) {
inputs.push_back(UnknownInput());
continue;
}
const AttrValue& attr_value = it->second;
std::vector<TensorProto> tensors = ExtractTensors(attr_value);
if (tensors.empty()) {
inputs.push_back(UnknownInput());
continue;
}
for (const auto& t : tensors) {
OpInfo::TensorProperties input;
input.set_dtype(t.dtype());
*(input.mutable_shape()) = t.tensor_shape();
*(input.mutable_value()) = t;
inputs.push_back(input);
}
continue;
}
}
auto it = name_to_cost.find(input_node_name);
if (it == name_to_cost.end() || output_index < 0) {
inputs.push_back(UnknownInput());
@ -126,9 +173,9 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
return inputs;
}
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
DeviceProperties GetDeviceInfo(const string& device_str) {
DeviceNameUtils::ParsedName parsed;
if (DeviceNameUtils::ParseFullName(node.device(), &parsed)) {
if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
if (parsed.type == "GPU") {
return GetLocalGPUInfo(parsed.id);
} else if (parsed.type == "CPU") {
@ -140,5 +187,31 @@ DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
return device;
}
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
return GetDeviceInfo(node.device());
}
OpInfo BuildOpInfo(
const NodeDef& node, const string& device_str,
const std::unordered_map<string, const NodeDef*>& name_to_node,
const std::vector<OpInfo::TensorProperties>& inputs) {
OpInfo op_info;
op_info.set_op(node.op());
*op_info.mutable_attr() = node.attr();
*op_info.mutable_device() = GetDeviceInfo(device_str);
for (auto& input : inputs) {
*op_info.add_inputs() = input;
}
std::vector<OpInfo::TensorProperties> extra_inputs;
ExtractExtraProperties(node, name_to_node, &extra_inputs,
op_info.mutable_attr());
for (auto& input : extra_inputs) {
*op_info.add_inputs() = input;
}
return op_info;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/platform/types.h"
@ -42,6 +43,15 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
// Returns the DeviceProperties of the device on which 'node' runs.
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
DeviceProperties GetDeviceInfo(const string& device_str);
// Builds the OpInfo proto for node, given all nodes in the graph, the node's
// device and its input properties which are typically built by shape inference
// or calling FindInputFeatures.
OpInfo BuildOpInfo(
const NodeDef& node, const string& device_str,
const std::unordered_map<string, const NodeDef*>& name_to_node,
const std::vector<OpInfo::TensorProperties>& inputs);
} // end namespace grappler
} // end namespace tensorflow

Some files were not shown because too many files have changed in this diff.